├── .build └── .gitkeep ├── .dockerignore ├── .gitignore ├── .pre-commit-config.yaml ├── Dockerfile ├── LICENSE ├── Makefile ├── README.md ├── black.toml ├── build.py ├── build.sh ├── configs ├── _base_ │ ├── datasets │ │ ├── wheat_detection_mstrain.py │ │ ├── wheat_detection_mstrain_hard.py │ │ ├── wheat_detection_mstrain_light.py │ │ └── wheat_detection_mstrain_pseudo.py │ ├── default_runtime.py │ ├── models │ │ ├── cascade_rcnn_r50_fpn.py │ │ ├── detectors_r50_ga.py │ │ └── universe_r101_gfl.py │ └── schedules │ │ ├── schedule_1x.py │ │ ├── schedule_4x.py │ │ └── schedule_pseudo.py ├── detectors │ ├── detectors_r50_ga_mstrain_local_pseudo.py │ ├── detectors_r50_ga_mstrain_private_pseudo.py │ ├── detectors_r50_ga_mstrain_public_pseudo.py │ ├── detectors_r50_ga_mstrain_stage0.py │ ├── detectors_r50_ga_mstrain_stage1.py │ └── detectors_r50_ga_mstrain_stage2.py └── universe_r101_gfl │ ├── universe_r101_gfl_mstrain_local_pseudo.py │ ├── universe_r101_gfl_mstrain_private_pseudo.py │ ├── universe_r101_gfl_mstrain_public_pseudo.py │ ├── universe_r101_gfl_mstrain_stage0.py │ ├── universe_r101_gfl_mstrain_stage1.py │ └── universe_r101_gfl_mstrain_stage2.py ├── gwd ├── __init__.py ├── assigners │ ├── __init__.py │ └── atss_assigner.py ├── backbones │ ├── __init__.py │ ├── res2net.py │ └── sac │ │ ├── __init__.py │ │ ├── conv_aws.py │ │ ├── resnet.py │ │ ├── resnext.py │ │ └── saconv.py ├── colorization │ ├── __init__.py │ ├── generate.py │ └── models.py ├── converters │ ├── __init__.py │ ├── coco2crop.py │ ├── images2coco.py │ ├── kaggle2coco.py │ └── spike2kaggle.py ├── datasets │ ├── __init__.py │ ├── evaluation.py │ ├── pipelines │ │ ├── __init__.py │ │ ├── albumentations.py │ │ ├── loading.py │ │ ├── test_aug.py │ │ ├── transforms.py │ │ └── utils.py │ ├── source_balanced_dataset.py │ └── wheat_detection.py ├── dense_heads │ ├── __init__.py │ └── gfl_head.py ├── detectors │ ├── __init__.py │ ├── atss.py │ └── rfp.py ├── eda │ ├── __init__.py │ ├── coco.py │ ├── kmeans.py │ ├── pipeline.py │ ├── sources.py │ ├── submission.py │ └── visualization.py ├── jigsaw │ ├── __init__.py │ ├── calculate_distance.py │ ├── collect_bboxes.py │ ├── collect_images.py │ └── crop.py ├── losses │ ├── __init__.py │ └── cross_entropy_loss.py ├── misc │ ├── __init__.py │ └── logging.py ├── necks │ ├── __init__.py │ ├── sepc.py │ └── sepc_dconv.py ├── patches.py ├── prepare_pseudo.py ├── select_anchors.py ├── split_folds.py ├── stylize │ ├── __init__.py │ ├── function.py │ ├── net.py │ └── run.py ├── submit.py ├── test.py ├── train.py ├── wbf.py └── weights │ ├── __init__.py │ ├── prepare_weights.py │ ├── rm_optimizer.py │ └── upgrade_model_version.py ├── requirements.in ├── requirements.txt ├── script_template.py ├── scripts ├── colorization.sh ├── kaggle2coco.sh ├── seach_thresholds.sh ├── stylize.sh ├── test.sh ├── test_crops.sh ├── train.sh ├── train_detectors.sh └── train_universenet.sh ├── setup.cfg ├── setup.py └── tests ├── __init__.py └── test_evaluation.py /.build/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amirassov/kaggle-global-wheat-detection/b26295ea257f73089f1a067b70b4a7ee638f6b83/.build/.gitkeep -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | # custom: 2 | 3 | .git/* 4 | data/* 5 | ipynb/* 6 | .idea/* 7 | 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | 
*$py.class 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # celery beat schedule file 87 | celerybeat-schedule 88 | 89 | # SageMath parsed files 90 | *.sage.py 91 | 92 | # Environments 93 | .env 94 | .venv 95 | env/ 96 | venv/ 97 | ENV/ 98 | env.bak/ 99 | venv.bak/ 100 | 101 | # Spyder project settings 102 | .spyderproject 103 | .spyproject 104 | 105 | # Rope project settings 106 | .ropeproject 107 | 108 | # mkdocs documentation 109 | /site 110 | 111 | # mypy 112 | .mypy_cache/ 113 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # custom: 2 | artifacts 3 | *.pth 4 | *.tflite 5 | 6 | *.DS_Store* 7 | 8 | data/* 9 | images/* 10 | ipynb/* 11 | .idea/* 12 | inference/gwd/data/* 13 | 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | *.py[cod] 17 | *$py.class 18 | 19 | # C extensions 20 | *.so 21 | 22 | # Distribution / packaging 23 | .Python 24 | build/ 25 | develop-eggs/ 26 | dist/ 27 | downloads/ 28 | eggs/ 29 | .eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | MANIFEST 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # celery beat schedule file 92 | celerybeat-schedule 93 | 94 | # SageMath parsed files 95 | *.sage.py 96 | 97 | # Environments 98 | .env 99 | .venv 100 | env/ 101 | venv/ 102 | ENV/ 103 | env.bak/ 104 | venv.bak/ 105 | 106 | # Spyder project settings 107 | .spyderproject 108 | .spyproject 109 | 110 | # Rope project settings 111 | .ropeproject 112 | 113 | # mkdocs documentation 114 | /site 115 | 116 | # mypy 117 | .mypy_cache/ 118 | 119 | node_modules/ 120 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v2.3.0 4 | hooks: 5 | - id: trailing-whitespace 6 | - id: check-yaml 7 | - id: end-of-file-fixer 8 | - repo: https://gitlab.com/pycqa/flake8.git 9 | rev: 3.8.0 10 | hooks: 11 | - id: flake8 12 | - repo: https://github.com/psf/black 13 | rev: 19.3b0 14 | hooks: 15 | - id: black 16 | exclude: gwd/__init__.py 17 | args: [--config=black.toml] 18 | - repo: https://github.com/timothycrosley/isort 19 | rev: 4.3.21 20 | hooks: 21 | - id: isort 22 | exclude: gwd/__init__.py 23 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | ARG PYTORCH="1.5" 2 | ARG CUDA="10.1" 3 | ARG CUDNN="7" 4 | 5 | FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel 6 | 7 | ARG DEBIAN_FRONTEND=noninteractive 8 | 9 | RUN apt-get update && apt-get install -y \ 10 | git \ 11 | ninja-build \ 12 | libglib2.0-0 \ 13 | libsm6 \ 14 | libxrender-dev \ 15 | libxext6 \ 16 | locales \ 17 | wget \ 18 | && apt-get clean \ 19 | && rm -rf /var/lib/apt/lists/* 20 | 21 | RUN locale-gen en_US.UTF-8 22 | ENV LANG en_US.UTF-8 23 | ENV LANGUAGE en_US:en 24 | ENV LC_ALL en_US.UTF-8 25 | 26 | RUN pip install --no-cache-dir --upgrade pip 27 | 28 | COPY ./requirements.txt /requirements.txt 29 | RUN pip install --no-cache-dir -r /requirements.txt 30 | 31 | # Install mmdetection 32 | ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0+PTX" 33 | ENV TORCH_NVCC_FLAGS="-Xfatbin -compress-all" 34 | ENV CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" 35 | ENV FORCE_CUDA="1" 36 | RUN git clone https://github.com/open-mmlab/mmdetection /tmp/mmdetection \ 37 | && cd /tmp/mmdetection \ 38 | && git checkout 38dfa875c048207fd46b8cd2b7ccafd5239b4a4e \ 39 | && pip install --no-cache-dir "git+https://github.com/open-mmlab/cocoapi.git#subdirectory=pycocotools" \ 40 | && pip install --no-cache-dir /tmp/mmdetection \ 41 | && rm -r /tmp/mmdetection 42 | 43 | RUN pip install --no-cache-dir mmcv==0.6.2 44 | 45 | ENV PROJECT_ROOT /global-wheat-detection 46 | ENV PYTHONPATH "${PYTHONPATH}:${PROJECT_ROOT}" 47 | WORKDIR 
/global-wheat-detection 48 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | APP_NAME=amirassov/gwd 2 | CONTAINER_NAME=gwd 3 | PROJECT_NAME=/global-wheat-detection 4 | 5 | help: ## This help. 6 | @awk 'BEGIN {FS = ":.*?## "} /^[a-zA-Z_-]+:.*?## / {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' $(MAKEFILE_LIST) 7 | 8 | build: ## Build the container 9 | nvidia-docker build -t ${APP_NAME}$(MODE) -f Dockerfile$(MODE) . 10 | 11 | run-extreme: ## Run container in extreme 12 | nvidia-docker run \ 13 | -e DISPLAY=unix${DISPLAY} -v /tmp/.X11-unix:/tmp/.X11-unix --privileged \ 14 | --ipc=host \ 15 | -itd \ 16 | --name=${CONTAINER_NAME}$(MODE) \ 17 | -v $(shell pwd):${PROJECT_NAME} \ 18 | -v /home/amirassov/global_wheat_data:/data_old \ 19 | -v /home/amirassov/global_wheat_data_test:/data \ 20 | -v /home/amirassov/global_wheat_dumps:/dumps ${APP_NAME}$(MODE) bash 21 | 22 | run-dgx: ## Run container in dgx 23 | nvidia-docker run \ 24 | --ipc=host \ 25 | -itd \ 26 | --name=${CONTAINER_NAME}$(MODE) \ 27 | -v $(shell pwd):${PROJECT_NAME} \ 28 | -v /raid/data_share/amirassov/global_wheat_data:/data \ 29 | -v /raid/data_share/amirassov/global_wheat_dumps:/dumps ${APP_NAME}$(MODE) bash 30 | 31 | exec: ## Run a bash in a running container 32 | nvidia-docker exec -it ${CONTAINER_NAME}$(MODE) bash 33 | 34 | stop: ## Stop and remove a running container 35 | docker stop ${CONTAINER_NAME}$(MODE); docker rm ${CONTAINER_NAME}$(MODE) 36 | -------------------------------------------------------------------------------- /black.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 120 3 | target-version = ['py37'] 4 | -------------------------------------------------------------------------------- /build.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import gzip 3 | from pathlib import Path 4 | 5 | 6 | def encode_file(path: Path) -> str: 7 | compressed = gzip.compress(path.read_bytes(), compresslevel=9) 8 | return base64.b64encode(compressed).decode("utf-8") 9 | 10 | 11 | def build_script(): 12 | to_encode = list(Path("gwd").rglob("*.py")) + [Path("setup.py")] + list(Path("configs").rglob("*.py")) 13 | file_data = {str(path): encode_file(path) for path in to_encode} 14 | template = Path("script_template.py").read_text("utf8") 15 | Path(".build/script.py").write_text(template.replace("{file_data}", str(file_data)), encoding="utf8") 16 | 17 | 18 | if __name__ == "__main__": 19 | build_script() 20 | -------------------------------------------------------------------------------- /build.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | 4 | python build.py 5 | cat .build/script.py 6 | -------------------------------------------------------------------------------- /configs/_base_/datasets/wheat_detection_mstrain.py: -------------------------------------------------------------------------------- 1 | dataset_type = "WheatDataset" 2 | data_root = "/data/" 3 | 4 | img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 5 | 6 | albu_train_transforms = [ 7 | dict(type="RandomRotate90", p=1.0), 8 | dict( 9 | type="OneOf", 10 | transforms=[ 11 | dict(type="HueSaturationValue", hue_shift_limit=10, sat_shift_limit=35, val_shift_limit=25), 12 | dict(type="RandomGamma"), 
13 | dict(type="CLAHE"), 14 | ], 15 | p=0.5, 16 | ), 17 | dict( 18 | type="OneOf", 19 | transforms=[ 20 | dict(type="RandomBrightnessContrast", brightness_limit=0.25, contrast_limit=0.25), 21 | dict(type="RGBShift", r_shift_limit=15, g_shift_limit=15, b_shift_limit=15), 22 | ], 23 | p=0.5, 24 | ), 25 | dict( 26 | type="OneOf", 27 | transforms=[ 28 | dict(type="Blur"), 29 | dict(type="MotionBlur"), 30 | dict(type="GaussNoise"), 31 | dict(type="ImageCompression", quality_lower=75), 32 | ], 33 | p=0.4, 34 | ), 35 | dict( 36 | type="CoarseDropout", 37 | max_holes=30, 38 | max_height=30, 39 | max_width=30, 40 | min_holes=5, 41 | min_height=10, 42 | min_width=10, 43 | fill_value=img_norm_cfg["mean"][::-1], 44 | p=0.4, 45 | ), 46 | dict( 47 | type="ModifiedShiftScaleRotate", 48 | shift_limit=0.3, 49 | rotate_limit=5, 50 | scale_limit=(-0.3, 0.75), 51 | border_mode=0, 52 | value=img_norm_cfg["mean"][::-1], 53 | ), 54 | dict(type="RandomBBoxesSafeCrop", num_rate=(0.5, 1.0), erosion_rate=0.2), 55 | ] 56 | 57 | train_pipeline = [ 58 | dict(type="LoadImageFromFile"), 59 | dict(type="LoadAnnotations", with_bbox=True), 60 | dict( 61 | type="Albumentations", 62 | transforms=albu_train_transforms, 63 | keymap=dict(img="image", gt_masks="masks", gt_bboxes="bboxes"), 64 | update_pad_shape=False, 65 | skip_img_without_anno=True, 66 | bbox_params=dict(type="BboxParams", format="pascal_voc", label_fields=["labels"]), 67 | min_visibility=0.3, 68 | min_size=4, 69 | max_aspect_ratio=15, 70 | ), 71 | dict(type="RandomFlip", flip_ratio=0.5), 72 | dict( 73 | type="Resize", 74 | img_scale=[(768 + 32 * i, 768 + 32 * i) for i in range(25)], 75 | multiscale_mode="value", 76 | keep_ratio=True, 77 | ), 78 | dict(type="Normalize", **img_norm_cfg), 79 | dict(type="Pad", size_divisor=32), 80 | dict(type="DefaultFormatBundle"), 81 | dict(type="Collect", keys=["img", "gt_bboxes", "gt_labels"]), 82 | ] 83 | 84 | val_pipeline = [ 85 | dict(type="LoadImageFromFile"), 86 | dict( 87 | type="MultiScaleFlipAug", 88 | img_scale=(1024, 1024), 89 | flip=False, 90 | transforms=[ 91 | dict(type="Resize", keep_ratio=True), 92 | dict(type="RandomFlip"), 93 | dict(type="Normalize", **img_norm_cfg), 94 | dict(type="Pad", size_divisor=32), 95 | dict(type="ImageToTensor", keys=["img"]), 96 | dict(type="Collect", keys=["img"]), 97 | ], 98 | ), 99 | ] 100 | 101 | test_pipeline = [ 102 | dict(type="LoadImageFromFile"), 103 | dict( 104 | type="ModifiedMultiScaleFlipAug", 105 | img_scale=[(1408, 1408), (1536, 1536)], 106 | flip=True, 107 | flip_direction=["horizontal", "vertical"], 108 | transforms=[ 109 | dict(type="Resize", keep_ratio=True), 110 | dict(type="RandomFlip"), 111 | dict(type="Normalize", **img_norm_cfg), 112 | dict(type="Pad", size_divisor=32), 113 | dict(type="ImageToTensor", keys=["img"]), 114 | dict(type="Collect", keys=["img"]), 115 | ], 116 | ), 117 | ] 118 | 119 | data = dict( 120 | samples_per_gpu=4, 121 | workers_per_gpu=4, 122 | train=dict( 123 | type=dataset_type, 124 | ann_file=data_root + "folds_v2/{fold}/coco_tile_train.json", 125 | img_prefix=data_root + "train/", 126 | pipeline=train_pipeline, 127 | ), 128 | val=dict( 129 | type=dataset_type, 130 | ann_file=data_root + "folds_v2/{fold}/coco_tile_val.json", 131 | img_prefix=data_root + "train/", 132 | pipeline=val_pipeline, 133 | ), 134 | test=dict( 135 | type=dataset_type, 136 | ann_file=data_root + "folds_v2/{fold}/coco_tile_val.json", 137 | img_prefix=data_root + "train/", 138 | pipeline=test_pipeline, 139 | ), 140 | ) 141 | evaluation = dict(interval=1, 
metric="bbox") 142 | -------------------------------------------------------------------------------- /configs/_base_/datasets/wheat_detection_mstrain_hard.py: -------------------------------------------------------------------------------- 1 | _base_ = "./wheat_detection_mstrain_light.py" 2 | 3 | data_root = "/data/" 4 | data = dict( 5 | train=dict( 6 | ann_file=[ 7 | data_root + "folds_v2/{fold}/coco_tile_train.json", 8 | data_root + "folds_v2/{fold}/coco_pseudo_train.json", 9 | data_root + "coco_spike.json", 10 | ], 11 | img_prefix=[ 12 | dict( 13 | roots=[ 14 | data_root + "train/", 15 | data_root + "colored_train/", 16 | data_root + "stylized_train/", 17 | data_root + "stylized_by_test_v1/", 18 | data_root + "stylized_by_test_v2/", 19 | data_root + "stylized_by_test_v3/", 20 | data_root + "stylized_by_test_v4/", 21 | ], 22 | probabilities=[0.4, 0.3, 0.3 / 5, 0.3 / 5, 0.3 / 5, 0.3 / 5, 0.3 / 5], 23 | ), 24 | dict( 25 | roots=[ 26 | data_root + "crops_fold0/", 27 | data_root + "colored_crops_fold0/", 28 | data_root + "stylized_pseudo_by_test_v1/", 29 | data_root + "stylized_pseudo_by_test_v2/", 30 | data_root + "stylized_pseudo_by_test_v3/", 31 | data_root + "stylized_pseudo_by_test_v4/", 32 | ], 33 | probabilities=[0.5, 0.3, 0.2 / 4, 0.2 / 4, 0.2 / 4, 0.2 / 4], 34 | ), 35 | dict( 36 | roots=[ 37 | data_root + "SPIKE_images/", 38 | data_root + "stylized_SPIKE_images_v1/", 39 | data_root + "stylized_SPIKE_images_v2/", 40 | data_root + "stylized_SPIKE_images_v3/", 41 | data_root + "stylized_SPIKE_images_v4/", 42 | ], 43 | probabilities=[0.7, 0.3 / 4, 0.3 / 4, 0.3 / 4, 0.3 / 4], 44 | ), 45 | ], 46 | ) 47 | ) 48 | -------------------------------------------------------------------------------- /configs/_base_/datasets/wheat_detection_mstrain_light.py: -------------------------------------------------------------------------------- 1 | _base_ = "./wheat_detection_mstrain.py" 2 | 3 | img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 4 | 5 | albu_train_transforms = [ 6 | dict(type="RandomRotate90", p=1.0), 7 | dict( 8 | type="OneOf", 9 | transforms=[ 10 | dict(type="HueSaturationValue", hue_shift_limit=10, sat_shift_limit=35, val_shift_limit=25), 11 | dict(type="RandomGamma"), 12 | dict(type="CLAHE"), 13 | ], 14 | p=0.5, 15 | ), 16 | dict( 17 | type="OneOf", 18 | transforms=[ 19 | dict(type="RandomBrightnessContrast", brightness_limit=0.25, contrast_limit=0.25), 20 | dict(type="RGBShift", r_shift_limit=15, g_shift_limit=15, b_shift_limit=15), 21 | ], 22 | p=0.5, 23 | ), 24 | dict( 25 | type="OneOf", 26 | transforms=[ 27 | dict(type="Blur"), 28 | dict(type="MotionBlur"), 29 | dict(type="GaussNoise"), 30 | dict(type="ImageCompression", quality_lower=75), 31 | ], 32 | p=0.4, 33 | ), 34 | dict( 35 | type="CoarseDropout", 36 | max_holes=30, 37 | max_height=30, 38 | max_width=30, 39 | min_holes=5, 40 | min_height=10, 41 | min_width=10, 42 | fill_value=img_norm_cfg["mean"][::-1], 43 | p=0.4, 44 | ), 45 | dict( 46 | type="ModifiedShiftScaleRotate", 47 | shift_limit=0.3, 48 | rotate_limit=5, 49 | scale_limit=(-0.3, 0.75), 50 | border_mode=0, 51 | value=img_norm_cfg["mean"][::-1], 52 | ), 53 | dict(type="RandomBBoxesSafeCrop", num_rate=(0.5, 1.0), erosion_rate=0.2), 54 | ] 55 | 56 | train_pipeline = [ 57 | dict(type="MultipleLoadImageFromFile"), 58 | dict(type="LoadAnnotations", with_bbox=True), 59 | dict(type="Mosaic", p=0.25, min_buffer_size=4, pad_val=img_norm_cfg["mean"][::-1]), 60 | dict( 61 | type="Albumentations", 62 | transforms=albu_train_transforms, 
63 | keymap=dict(img="image", gt_masks="masks", gt_bboxes="bboxes"), 64 | update_pad_shape=False, 65 | skip_img_without_anno=True, 66 | bbox_params=dict(type="BboxParams", format="pascal_voc", label_fields=["labels"]), 67 | min_visibility=0.3, 68 | min_size=4, 69 | max_aspect_ratio=15, 70 | ), 71 | dict(type="Mixup", p=0.25, min_buffer_size=2, pad_val=img_norm_cfg["mean"][::-1]), 72 | dict(type="RandomFlip", flip_ratio=0.5), 73 | dict( 74 | type="Resize", 75 | img_scale=[(768 + 32 * i, 768 + 32 * i) for i in range(25)], 76 | multiscale_mode="value", 77 | keep_ratio=True, 78 | ), 79 | dict(type="Normalize", **img_norm_cfg), 80 | dict(type="Pad", size_divisor=32), 81 | dict(type="DefaultFormatBundle"), 82 | dict(type="Collect", keys=["img", "gt_bboxes", "gt_labels"]), 83 | ] 84 | 85 | data_root = "/data/" 86 | data = dict( 87 | train=dict( 88 | pipeline=train_pipeline, 89 | ann_file=[ 90 | data_root + "folds_v2/{fold}/coco_tile_train.json", 91 | data_root + "folds_v2/{fold}/coco_pseudo_train.json", 92 | ], 93 | img_prefix=[ 94 | dict( 95 | roots=[ 96 | data_root + "train/", 97 | data_root + "colored_train/", 98 | data_root + "stylized_train/", 99 | data_root + "stylized_by_test_v1/", 100 | data_root + "stylized_by_test_v2/", 101 | data_root + "stylized_by_test_v3/", 102 | data_root + "stylized_by_test_v4/", 103 | ], 104 | probabilities=[0.75, 0.15, 0.1 / 5, 0.1 / 5, 0.1 / 5, 0.1 / 5, 0.1 / 5], 105 | ), 106 | dict( 107 | roots=[ 108 | data_root + "crops_fold0/", 109 | data_root + "colored_crops_fold0/", 110 | data_root + "stylized_pseudo_by_test_v1/", 111 | data_root + "stylized_pseudo_by_test_v2/", 112 | data_root + "stylized_pseudo_by_test_v3/", 113 | data_root + "stylized_pseudo_by_test_v4/", 114 | ], 115 | probabilities=[0.75, 0.15, 0.1 / 4, 0.1 / 4, 0.1 / 4, 0.1 / 4], 116 | ), 117 | ], 118 | ) 119 | ) 120 | -------------------------------------------------------------------------------- /configs/_base_/datasets/wheat_detection_mstrain_pseudo.py: -------------------------------------------------------------------------------- 1 | _base_ = "./wheat_detection_mstrain.py" 2 | 3 | data_root = "/data/" 4 | data = dict( 5 | samples_per_gpu=1, 6 | train=dict( 7 | ann_file=[ 8 | data_root + "coco_train.json", 9 | data_root + "coco_pseudo_test.json", 10 | data_root + "coco_pseudo_test.json", 11 | data_root + "coco_pseudo_test.json", 12 | ], 13 | img_prefix=[data_root + "train/", data_root + "test/", data_root + "test/", data_root + "test/"], 14 | ), 15 | ) 16 | -------------------------------------------------------------------------------- /configs/_base_/default_runtime.py: -------------------------------------------------------------------------------- 1 | checkpoint_config = dict(interval=4) 2 | # yapf:disable 3 | log_config = dict(interval=10, hooks=[dict(type="TextLoggerHook")]) 4 | # yapf:enable 5 | dist_params = dict(backend="nccl") 6 | log_level = "INFO" 7 | load_from = None 8 | resume_from = None 9 | workflow = [("train", 1)] 10 | -------------------------------------------------------------------------------- /configs/_base_/models/cascade_rcnn_r50_fpn.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | model = dict( 3 | type="CascadeRCNN", 4 | pretrained=None, 5 | backbone=dict( 6 | type="ResNet", 7 | depth=50, 8 | num_stages=4, 9 | out_indices=(0, 1, 2, 3), 10 | frozen_stages=1, 11 | norm_cfg=dict(type="BN", requires_grad=True), 12 | norm_eval=True, 13 | style="pytorch", 14 | ), 15 | neck=dict(type="FPN", in_channels=[256, 512, 
1024, 2048], out_channels=256, num_outs=5), 16 | rpn_head=dict( 17 | type="RPNHead", 18 | in_channels=256, 19 | feat_channels=256, 20 | anchor_generator=dict(type="AnchorGenerator", scales=[8], ratios=[0.5, 1.0, 2.0], strides=[4, 8, 16, 32, 64]), 21 | bbox_coder=dict(type="DeltaXYWHBBoxCoder", target_means=[0.0, 0.0, 0.0, 0.0], target_stds=[1.0, 1.0, 1.0, 1.0]), 22 | loss_cls=dict(type="CrossEntropyLoss", use_sigmoid=True, loss_weight=1.0), 23 | loss_bbox=dict(type="SmoothL1Loss", beta=1.0 / 9.0, loss_weight=1.0), 24 | ), 25 | roi_head=dict( 26 | type="CascadeRoIHead", 27 | num_stages=3, 28 | stage_loss_weights=[1, 0.5, 0.25], 29 | bbox_roi_extractor=dict( 30 | type="SingleRoIExtractor", 31 | roi_layer=dict(type="RoIAlign", out_size=7, sample_num=0), 32 | out_channels=256, 33 | featmap_strides=[4, 8, 16, 32], 34 | ), 35 | bbox_head=[ 36 | dict( 37 | type="Shared2FCBBoxHead", 38 | in_channels=256, 39 | fc_out_channels=1024, 40 | roi_feat_size=7, 41 | num_classes=1, 42 | bbox_coder=dict( 43 | type="DeltaXYWHBBoxCoder", target_means=[0.0, 0.0, 0.0, 0.0], target_stds=[0.1, 0.1, 0.2, 0.2] 44 | ), 45 | reg_class_agnostic=True, 46 | loss_cls=dict(type="LabelSmoothCrossEntropyLoss", use_sigmoid=False, loss_weight=1.0, label_smooth=0.1), 47 | loss_bbox=dict(type="SmoothL1Loss", beta=1.0, loss_weight=1.0), 48 | ), 49 | dict( 50 | type="Shared2FCBBoxHead", 51 | in_channels=256, 52 | fc_out_channels=1024, 53 | roi_feat_size=7, 54 | num_classes=1, 55 | bbox_coder=dict( 56 | type="DeltaXYWHBBoxCoder", target_means=[0.0, 0.0, 0.0, 0.0], target_stds=[0.05, 0.05, 0.1, 0.1] 57 | ), 58 | reg_class_agnostic=True, 59 | loss_cls=dict(type="LabelSmoothCrossEntropyLoss", use_sigmoid=False, loss_weight=1.0, label_smooth=0.1), 60 | loss_bbox=dict(type="SmoothL1Loss", beta=1.0, loss_weight=1.0), 61 | ), 62 | dict( 63 | type="Shared2FCBBoxHead", 64 | in_channels=256, 65 | fc_out_channels=1024, 66 | roi_feat_size=7, 67 | num_classes=1, 68 | bbox_coder=dict( 69 | type="DeltaXYWHBBoxCoder", 70 | target_means=[0.0, 0.0, 0.0, 0.0], 71 | target_stds=[0.033, 0.033, 0.067, 0.067], 72 | ), 73 | reg_class_agnostic=True, 74 | loss_cls=dict(type="LabelSmoothCrossEntropyLoss", use_sigmoid=False, loss_weight=1.0, label_smooth=0.1), 75 | loss_bbox=dict(type="SmoothL1Loss", beta=1.0, loss_weight=1.0), 76 | ), 77 | ], 78 | ), 79 | ) 80 | # model training and testing settings 81 | train_cfg = dict( 82 | rpn=dict( 83 | assigner=dict( 84 | type="MaxIoUAssigner", 85 | pos_iou_thr=0.7, 86 | neg_iou_thr=0.3, 87 | min_pos_iou=0.3, 88 | match_low_quality=True, 89 | ignore_iof_thr=-1, 90 | ), 91 | sampler=dict(type="RandomSampler", num=256, pos_fraction=0.5, neg_pos_ub=-1, add_gt_as_proposals=False), 92 | allowed_border=0, 93 | pos_weight=-1, 94 | debug=False, 95 | ), 96 | rpn_proposal=dict(nms_across_levels=False, nms_pre=2000, nms_post=2000, max_num=2000, nms_thr=0.7, min_bbox_size=0), 97 | rcnn=[ 98 | dict( 99 | assigner=dict( 100 | type="MaxIoUAssigner", 101 | pos_iou_thr=0.5, 102 | neg_iou_thr=0.5, 103 | min_pos_iou=0.5, 104 | match_low_quality=False, 105 | ignore_iof_thr=-1, 106 | ), 107 | sampler=dict(type="RandomSampler", num=512, pos_fraction=0.25, neg_pos_ub=-1, add_gt_as_proposals=True), 108 | pos_weight=-1, 109 | debug=False, 110 | ), 111 | dict( 112 | assigner=dict( 113 | type="MaxIoUAssigner", 114 | pos_iou_thr=0.6, 115 | neg_iou_thr=0.6, 116 | min_pos_iou=0.6, 117 | match_low_quality=False, 118 | ignore_iof_thr=-1, 119 | ), 120 | sampler=dict(type="RandomSampler", num=512, pos_fraction=0.25, neg_pos_ub=-1, 
add_gt_as_proposals=True), 121 | pos_weight=-1, 122 | debug=False, 123 | ), 124 | dict( 125 | assigner=dict( 126 | type="MaxIoUAssigner", 127 | pos_iou_thr=0.7, 128 | neg_iou_thr=0.7, 129 | min_pos_iou=0.7, 130 | match_low_quality=False, 131 | ignore_iof_thr=-1, 132 | ), 133 | sampler=dict(type="RandomSampler", num=512, pos_fraction=0.25, neg_pos_ub=-1, add_gt_as_proposals=True), 134 | pos_weight=-1, 135 | debug=False, 136 | ), 137 | ], 138 | ) 139 | test_cfg = dict( 140 | rpn=dict(nms_across_levels=False, nms_pre=1000, nms_post=1000, max_num=1000, nms_thr=0.7, min_bbox_size=0), 141 | rcnn=dict(score_thr=0.45, nms=dict(type="nms", iou_thr=0.45), max_per_img=200), 142 | ) 143 | -------------------------------------------------------------------------------- /configs/_base_/models/detectors_r50_ga.py: -------------------------------------------------------------------------------- 1 | _base_ = "./cascade_rcnn_r50_fpn.py" 2 | 3 | conv_cfg = dict(type="ConvAWSV1") 4 | model = dict( 5 | type="RecursiveFeaturePyramid", 6 | rfp_steps=2, 7 | rfp_sharing=False, 8 | stage_with_rfp=(False, True, True, True), 9 | backbone=dict( 10 | _delete_=True, 11 | type="DetectoRS_ResNetV1", 12 | depth=50, 13 | num_stages=4, 14 | out_indices=(0, 1, 2, 3), 15 | frozen_stages=1, 16 | conv_cfg=conv_cfg, 17 | sac=dict(type="SACV1", use_deform=True), 18 | stage_with_sac=(False, True, True, True), 19 | norm_cfg=dict(type="BN", requires_grad=True), 20 | style="pytorch", 21 | gen_attention=dict(spatial_range=-1, num_heads=8, attention_type="0010", kv_stride=2), 22 | stage_with_gen_attention=[[], [], [0, 1, 2, 3, 4, 5], [0, 1, 2]], 23 | ), 24 | ) 25 | 26 | test_cfg = dict(rcnn=dict(score_thr=0.5, nms=dict(iou_thr=0.5))) 27 | -------------------------------------------------------------------------------- /configs/_base_/models/universe_r101_gfl.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | model = dict( 3 | type="ModifiedATSS", 4 | pretrained=None, 5 | backbone=dict( 6 | type="FixedRes2Net", 7 | depth=101, 8 | scales=4, 9 | base_width=26, 10 | num_stages=4, 11 | out_indices=(0, 1, 2, 3), 12 | frozen_stages=1, 13 | norm_cfg=dict(type="BN", requires_grad=True), 14 | norm_eval=True, 15 | style="pytorch", 16 | dcn=dict(type="DCN", deformable_groups=1, fallback_on_stride=False), 17 | stage_with_dcn=(False, True, True, True), 18 | ), 19 | neck=[ 20 | dict( 21 | type="FPN", 22 | in_channels=[256, 512, 1024, 2048], 23 | out_channels=256, 24 | start_level=1, 25 | add_extra_convs=True, 26 | extra_convs_on_inputs=False, 27 | num_outs=5, 28 | ), 29 | dict( 30 | type="SEPC", 31 | out_channels=256, 32 | stacked_convs=4, 33 | pconv_deform=True, 34 | lcconv_deform=True, 35 | ibn=False, # please set imgs/gpu >= 4 36 | lcconv_padding=1, 37 | ), 38 | ], 39 | bbox_head=dict( 40 | type="GFLSEPCHead", 41 | num_classes=1, 42 | in_channels=256, 43 | stacked_convs=0, 44 | feat_channels=256, 45 | anchor_generator=dict( 46 | type="AnchorGenerator", ratios=[1.0], octave_base_scale=8, scales_per_octave=1, strides=[8, 16, 32, 64, 128] 47 | ), 48 | loss_cls=dict(type="QualityFocalLoss", use_sigmoid=True, beta=2.0, loss_weight=1.0), 49 | loss_dfl=dict(type="DistributionFocalLoss", loss_weight=0.25), 50 | reg_max=16, 51 | loss_bbox=dict(type="GIoULoss", loss_weight=2.0), 52 | ), 53 | ) 54 | # training and testing settings 55 | train_cfg = dict(assigner=dict(type="FixedATSSAssigner", topk=9), allowed_border=-1, pos_weight=-1, debug=False) 56 | test_cfg = dict(nms_pre=1000, min_bbox_size=0, 
score_thr=0.4, nms=dict(type="nms", iou_thr=0.45), max_per_img=200) 57 | -------------------------------------------------------------------------------- /configs/_base_/schedules/schedule_1x.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type="SGD", lr=0.04, momentum=0.9, weight_decay=0.0001) 3 | optimizer_config = dict(grad_clip=None) 4 | # learning policy 5 | lr_config = dict(policy="step", warmup="linear", warmup_iters=100, warmup_ratio=0.001, step=[7, 11]) 6 | total_epochs = 12 7 | -------------------------------------------------------------------------------- /configs/_base_/schedules/schedule_4x.py: -------------------------------------------------------------------------------- 1 | _base_ = "schedule_1x.py" 2 | 3 | # learning policy 4 | lr_config = dict(step=[43, 47]) 5 | total_epochs = 48 6 | -------------------------------------------------------------------------------- /configs/_base_/schedules/schedule_pseudo.py: -------------------------------------------------------------------------------- 1 | _base_ = "schedule_1x.py" 2 | 3 | # optimizer 4 | optimizer = dict(lr=0.01 / 2 / 4) 5 | 6 | # learning policy 7 | lr_config = dict(step=[1500, 3500], by_epoch=False) 8 | total_epochs = 1 9 | -------------------------------------------------------------------------------- /configs/detectors/detectors_r50_ga_mstrain_local_pseudo.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | "../_base_/models/detectors_r50_ga.py", 3 | "../_base_/datasets/wheat_detection_mstrain_pseudo.py", 4 | "../_base_/schedules/schedule_pseudo.py", 5 | "../_base_/default_runtime.py", 6 | ] 7 | 8 | img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 9 | 10 | albu_train_transforms = [ 11 | dict(type="RandomRotate90", p=1.0), 12 | dict( 13 | type="OneOf", 14 | transforms=[ 15 | dict(type="HueSaturationValue", hue_shift_limit=10, sat_shift_limit=35, val_shift_limit=25), 16 | dict(type="RandomGamma"), 17 | dict(type="CLAHE"), 18 | ], 19 | p=0.5, 20 | ), 21 | dict( 22 | type="OneOf", 23 | transforms=[ 24 | dict(type="RandomBrightnessContrast", brightness_limit=0.25, contrast_limit=0.25), 25 | dict(type="RGBShift", r_shift_limit=15, g_shift_limit=15, b_shift_limit=15), 26 | ], 27 | p=0.5, 28 | ), 29 | dict( 30 | type="OneOf", 31 | transforms=[ 32 | dict(type="Blur"), 33 | dict(type="MotionBlur"), 34 | dict(type="GaussNoise"), 35 | dict(type="ImageCompression", quality_lower=75), 36 | ], 37 | p=0.4, 38 | ), 39 | dict( 40 | type="CoarseDropout", 41 | max_holes=30, 42 | max_height=30, 43 | max_width=30, 44 | min_holes=5, 45 | min_height=10, 46 | min_width=10, 47 | fill_value=img_norm_cfg["mean"][::-1], 48 | p=0.4, 49 | ), 50 | dict( 51 | type="ModifiedShiftScaleRotate", 52 | shift_limit=0.3, 53 | rotate_limit=5, 54 | scale_limit=(-0.3, 0.75), 55 | border_mode=0, 56 | value=img_norm_cfg["mean"][::-1], 57 | ), 58 | dict(type="RandomBBoxesSafeCrop", num_rate=(0.5, 1.0), erosion_rate=0.2), 59 | ] 60 | 61 | train_pipeline = [ 62 | dict(type="LoadImageFromFile"), 63 | dict(type="LoadAnnotations", with_bbox=True), 64 | dict( 65 | type="Albumentations", 66 | transforms=albu_train_transforms, 67 | keymap=dict(img="image", gt_masks="masks", gt_bboxes="bboxes"), 68 | update_pad_shape=False, 69 | skip_img_without_anno=True, 70 | bbox_params=dict(type="BboxParams", format="pascal_voc", label_fields=["labels"]), 71 | min_visibility=0.3, 72 | min_size=4, 73 | 
max_aspect_ratio=15, 74 | ), 75 | dict(type="RandomFlip", flip_ratio=0.5), 76 | dict( 77 | type="Resize", 78 | img_scale=[(768 + 32 * i, 768 + 32 * i) for i in range(20)], 79 | multiscale_mode="value", 80 | keep_ratio=True, 81 | ), 82 | dict(type="Normalize", **img_norm_cfg), 83 | dict(type="Pad", size_divisor=32), 84 | dict(type="DefaultFormatBundle"), 85 | dict(type="Collect", keys=["img", "gt_bboxes", "gt_labels"]), 86 | ] 87 | 88 | test_pipeline = [ 89 | dict(type="LoadImageFromFile"), 90 | dict( 91 | type="ModifiedMultiScaleFlipAug", 92 | img_scale=[(1184, 1184), (1376, 1376)], 93 | flip=True, 94 | flip_direction=["horizontal", "vertical"], 95 | transforms=[ 96 | dict(type="Resize", keep_ratio=True), 97 | dict(type="RandomFlip"), 98 | dict(type="Normalize", **img_norm_cfg), 99 | dict(type="Pad", size_divisor=32), 100 | dict(type="ImageToTensor", keys=["img"]), 101 | dict(type="Collect", keys=["img"]), 102 | ], 103 | ), 104 | ] 105 | 106 | data = dict(train=dict(pipeline=train_pipeline), test=dict(pipeline=test_pipeline)) 107 | 108 | checkpoint_config = dict(interval=1) 109 | -------------------------------------------------------------------------------- /configs/detectors/detectors_r50_ga_mstrain_private_pseudo.py: -------------------------------------------------------------------------------- 1 | _base_ = "./detectors_r50_ga_mstrain_local_pseudo.py" 2 | 3 | kaggle_working = "/kaggle/working/" 4 | data_root = "/kaggle/input/global-wheat-detection/" 5 | data = dict( 6 | train=dict( 7 | ann_file=[ 8 | kaggle_working + "coco_train.json", 9 | kaggle_working + "coco_pseudo_test.json", 10 | kaggle_working + "coco_pseudo_test.json", 11 | kaggle_working + "coco_pseudo_test.json", 12 | ], 13 | img_prefix=[data_root + "train/", data_root + "test/", data_root + "test/", data_root + "test/"], 14 | ) 15 | ) 16 | -------------------------------------------------------------------------------- /configs/detectors/detectors_r50_ga_mstrain_public_pseudo.py: -------------------------------------------------------------------------------- 1 | _base_ = "./detectors_r50_ga_mstrain_local_pseudo.py" 2 | 3 | kaggle_working = "/kaggle/working/" 4 | data_root = "/kaggle/input/global-wheat-detection/" 5 | data = dict(train=dict(ann_file=kaggle_working + "coco_pseudo_test.json", img_prefix=data_root + "test/")) 6 | -------------------------------------------------------------------------------- /configs/detectors/detectors_r50_ga_mstrain_stage0.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | "../_base_/models/detectors_r50_ga.py", 3 | "../_base_/datasets/wheat_detection_mstrain_hard.py", 4 | "../_base_/schedules/schedule_4x.py", 5 | "../_base_/default_runtime.py", 6 | ] 7 | 8 | data = dict(samples_per_gpu=2) 9 | optimizer = dict(lr=0.01) 10 | load_from = "/dumps/DetectoRS_R50-0f1c8080_v2_attention.pth" 11 | -------------------------------------------------------------------------------- /configs/detectors/detectors_r50_ga_mstrain_stage1.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | "../_base_/models/detectors_r50_ga.py", 3 | "../_base_/datasets/wheat_detection_mstrain_light.py", 4 | "../_base_/schedules/schedule_1x.py", 5 | "../_base_/default_runtime.py", 6 | ] 7 | 8 | data = dict(samples_per_gpu=2) 9 | optimizer = dict(lr=0.02 / 2) 10 | load_from = "/dumps/work_dirs/detectors_r50_ga_mstrain_stage0/0/epoch_48.pth" 11 | checkpoint_config = dict(interval=4) 12 | 
-------------------------------------------------------------------------------- /configs/detectors/detectors_r50_ga_mstrain_stage2.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | "../_base_/models/detectors_r50_ga.py", 3 | "../_base_/datasets/wheat_detection_mstrain.py", 4 | "../_base_/schedules/schedule_1x.py", 5 | "../_base_/default_runtime.py", 6 | ] 7 | 8 | data = dict(samples_per_gpu=2) 9 | optimizer = dict(lr=0.02 / 4) 10 | load_from = "/dumps/work_dirs/detectors_r50_ga_mstrain_stage1/0/epoch_12.pth" 11 | checkpoint_config = dict(interval=1) 12 | -------------------------------------------------------------------------------- /configs/universe_r101_gfl/universe_r101_gfl_mstrain_local_pseudo.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | "../_base_/models/universe_r101_gfl.py", 3 | "../_base_/datasets/wheat_detection_mstrain_pseudo.py", 4 | "../_base_/schedules/schedule_pseudo.py", 5 | "../_base_/default_runtime.py", 6 | ] 7 | 8 | optimizer_config = dict(_delete_=True, grad_clip=dict(max_norm=35, norm_type=2)) 9 | fp16 = dict(loss_scale=512.0) 10 | checkpoint_config = dict(interval=1) 11 | -------------------------------------------------------------------------------- /configs/universe_r101_gfl/universe_r101_gfl_mstrain_private_pseudo.py: -------------------------------------------------------------------------------- 1 | _base_ = "./universe_r101_gfl_mstrain_local_pseudo.py" 2 | 3 | kaggle_working = "/kaggle/working/" 4 | data_root = "/kaggle/input/global-wheat-detection/" 5 | data = dict( 6 | train=dict( 7 | ann_file=[ 8 | kaggle_working + "coco_train.json", 9 | kaggle_working + "coco_pseudo_test.json", 10 | kaggle_working + "coco_pseudo_test.json", 11 | kaggle_working + "coco_pseudo_test.json", 12 | ], 13 | img_prefix=[data_root + "train/", data_root + "test/", data_root + "test/", data_root + "test/"], 14 | ) 15 | ) 16 | -------------------------------------------------------------------------------- /configs/universe_r101_gfl/universe_r101_gfl_mstrain_public_pseudo.py: -------------------------------------------------------------------------------- 1 | _base_ = "./universe_r101_gfl_mstrain_local_pseudo.py" 2 | 3 | kaggle_working = "/kaggle/working/" 4 | data_root = "/kaggle/input/global-wheat-detection/" 5 | data = dict( 6 | train=dict( 7 | ann_file=[ 8 | kaggle_working + "coco_pseudo_test.json", 9 | kaggle_working + "coco_pseudo_test.json", 10 | kaggle_working + "coco_pseudo_test.json", 11 | ], 12 | img_prefix=[data_root + "test/", data_root + "test/", data_root + "test/"], 13 | ) 14 | ) 15 | -------------------------------------------------------------------------------- /configs/universe_r101_gfl/universe_r101_gfl_mstrain_stage0.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | "../_base_/models/universe_r101_gfl.py", 3 | "../_base_/datasets/wheat_detection_mstrain_hard.py", 4 | "../_base_/schedules/schedule_4x.py", 5 | "../_base_/default_runtime.py", 6 | ] 7 | 8 | data = dict(samples_per_gpu=6) 9 | optimizer = dict(lr=0.03) 10 | optimizer_config = dict(_delete_=True, grad_clip=dict(max_norm=35, norm_type=2)) 11 | load_from = "/dumps/universenet101_gfl_fp16_4x4_mstrain_480_960_2x_coco_20200716_epoch_24-1b9a1241.pth" 12 | fp16 = dict(loss_scale=512.0) 13 | -------------------------------------------------------------------------------- /configs/universe_r101_gfl/universe_r101_gfl_mstrain_stage1.py: 
-------------------------------------------------------------------------------- 1 | _base_ = [ 2 | "../_base_/models/universe_r101_gfl.py", 3 | "../_base_/datasets/wheat_detection_mstrain_light.py", 4 | "../_base_/schedules/schedule_1x.py", 5 | "../_base_/default_runtime.py", 6 | ] 7 | 8 | data = dict(samples_per_gpu=6) 9 | optimizer = dict(lr=0.03 / 2) 10 | optimizer_config = dict(_delete_=True, grad_clip=dict(max_norm=35, norm_type=2)) 11 | load_from = "/dumps/work_dirs/universe_r101_gfl_mstrain_stage0/0/epoch_48.pth" 12 | fp16 = dict(loss_scale=512.0) 13 | checkpoint_config = dict(interval=1) 14 | -------------------------------------------------------------------------------- /configs/universe_r101_gfl/universe_r101_gfl_mstrain_stage2.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | "../_base_/models/universe_r101_gfl.py", 3 | "../_base_/datasets/wheat_detection_mstrain.py", 4 | "../_base_/schedules/schedule_1x.py", 5 | "../_base_/default_runtime.py", 6 | ] 7 | 8 | data = dict(samples_per_gpu=6) 9 | optimizer = dict(lr=0.03 / 3) 10 | optimizer_config = dict(_delete_=True, grad_clip=dict(max_norm=35, norm_type=2)) 11 | load_from = "/dumps/work_dirs/universe_r101_gfl_mstrain_stage1/0/epoch_48.pth" 12 | fp16 = dict(loss_scale=512.0) 13 | checkpoint_config = dict(interval=1) 14 | -------------------------------------------------------------------------------- /gwd/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings # noqa 2 | 3 | warnings.simplefilter("ignore", UserWarning) # noqa 4 | 5 | from .assigners.atss_assigner import FixedATSSAssigner 6 | from .backbones.res2net import FixedRes2Net 7 | from .backbones.sac.conv_aws import ConvAWS2d 8 | from .backbones.sac.resnet import DetectoRS_ResNet 9 | from .backbones.sac.resnext import DetectoRS_ResNeXt 10 | from .backbones.sac.saconv import SAConv2d 11 | from .datasets.pipelines.albumentations import Albumentations, ModifiedShiftScaleRotate, RandomBBoxesSafeCrop 12 | from .datasets.pipelines.loading import MultipleLoadImageFromFile 13 | from .datasets.pipelines.test_aug import ModifiedMultiScaleFlipAug 14 | from .datasets.pipelines.transforms import ( 15 | Mixup, 16 | Mosaic, 17 | RandomCopyPasteFromFile, 18 | RandomCropVisibility, 19 | RandomRotate90, 20 | ) 21 | from .datasets.source_balanced_dataset import SourceBalancedDataset 22 | from .datasets.wheat_detection import WheatDataset 23 | from .dense_heads.gfl_head import GFLSEPCHead 24 | from .detectors.atss import ModifiedATSS 25 | from .detectors.rfp import RecursiveFeaturePyramid 26 | from .losses.cross_entropy_loss import LabelSmoothCrossEntropyLoss 27 | from .necks.sepc import SEPC 28 | from .patches import build_dataset 29 | -------------------------------------------------------------------------------- /gwd/assigners/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amirassov/kaggle-global-wheat-detection/b26295ea257f73089f1a067b70b4a7ee638f6b83/gwd/assigners/__init__.py -------------------------------------------------------------------------------- /gwd/assigners/atss_assigner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from mmdet.core.bbox.assigners.assign_result import AssignResult 4 | from mmdet.core.bbox.assigners.base_assigner import BaseAssigner 5 | from mmdet.core.bbox.builder import BBOX_ASSIGNERS 6 | from 
mmdet.core.bbox.iou_calculators import build_iou_calculator 7 | 8 | 9 | @BBOX_ASSIGNERS.register_module() 10 | class FixedATSSAssigner(BaseAssigner): 11 | """Assign a corresponding gt bbox or background to each bbox. 12 | Each proposal will be assigned `0` or a positive integer 13 | indicating the ground truth index. 14 | - 0: negative sample, no assigned gt 15 | - positive integer: positive sample, index (1-based) of assigned gt 16 | Args: 17 | topk (int): number of bboxes selected in each level 18 | """ 19 | 20 | def __init__(self, topk, iou_calculator=dict(type="BboxOverlaps2D")): 21 | self.topk = topk 22 | self.iou_calculator = build_iou_calculator(iou_calculator) 23 | 24 | # https://github.com/sfzhang15/ATSS/blob/master/atss_core/modeling/rpn/atss/loss.py 25 | 26 | def assign(self, bboxes, num_level_bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None): 27 | """Assign gt to bboxes. 28 | The assignment is done in the following steps: 29 | 1. compute iou between all bbox (bbox of all pyramid levels) and gt 30 | 2. compute center distance between all bbox and gt 31 | 3. on each pyramid level, for each gt, select k bboxes whose centers 32 | are closest to the gt center, so we select k*l bboxes in total as 33 | candidates for each gt 34 | 4. get the corresponding iou for these candidates, and compute the 35 | mean and std; set mean + std as the iou threshold 36 | 5. select the candidates whose iou is greater than or equal to 37 | the threshold as positive 38 | 6. limit the positive samples' centers to lie inside the gt 39 | Args: 40 | bboxes (Tensor): Bounding boxes to be assigned, shape (n, 4). 41 | num_level_bboxes (List): number of bboxes in each level 42 | gt_bboxes (Tensor): Ground truth boxes, shape (k, 4). 43 | gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are 44 | labelled as `ignored`, e.g., crowd boxes in COCO. 45 | gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ). 46 | Returns: 47 | :obj:`AssignResult`: The assign result.
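Toy example of the adaptive threshold in steps 4-6 (illustrative numbers only): if a gt has three candidate IoUs [0.3, 0.5, 0.7], their mean is 0.5 and their (unbiased) std is 0.2, so the threshold is 0.5 + 0.2 = 0.7 and only the 0.7 candidate is kept as positive, and only if its center also lies inside that gt box.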
48 | """ 49 | INF = 100000000 50 | bboxes = bboxes[:, :4] 51 | num_gt, num_bboxes = gt_bboxes.size(0), bboxes.size(0) 52 | 53 | # compute iou between all bbox and gt 54 | overlaps = self.iou_calculator(bboxes, gt_bboxes) 55 | 56 | # assign 0 by default 57 | assigned_gt_inds = overlaps.new_full((num_bboxes,), 0, dtype=torch.long) 58 | 59 | if num_gt == 0 or num_bboxes == 0: 60 | # No ground truth or boxes, return empty assignment 61 | max_overlaps = overlaps.new_zeros((num_bboxes,)) 62 | if num_gt == 0: 63 | # No truth, assign everything to background 64 | assigned_gt_inds[:] = 0 65 | if gt_labels is None: 66 | assigned_labels = None 67 | else: 68 | assigned_labels = overlaps.new_full((num_bboxes,), -1, dtype=torch.long) 69 | return AssignResult(num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels) 70 | 71 | # compute center distance between all bbox and gt 72 | gt_cx = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0 73 | gt_cy = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0 74 | gt_points = torch.stack((gt_cx, gt_cy), dim=1) 75 | 76 | bboxes_cx = (bboxes[:, 0] + bboxes[:, 2]) / 2.0 77 | bboxes_cy = (bboxes[:, 1] + bboxes[:, 3]) / 2.0 78 | bboxes_points = torch.stack((bboxes_cx, bboxes_cy), dim=1) 79 | 80 | distances = (bboxes_points[:, None, :] - gt_points[None, :, :]).pow(2).sum(-1).sqrt() 81 | 82 | # Selecting candidates based on the center distance 83 | candidate_idxs = [] 84 | start_idx = 0 85 | for level, bboxes_per_level in enumerate(num_level_bboxes): 86 | # on each pyramid level, for each gt, 87 | # select k bbox whose center are closest to the gt center 88 | end_idx = start_idx + bboxes_per_level 89 | distances_per_level = distances[start_idx:end_idx, :] 90 | selectable_k = min(self.topk, bboxes_per_level) 91 | _, topk_idxs_per_level = distances_per_level.topk(selectable_k, dim=0, largest=False) 92 | candidate_idxs.append(topk_idxs_per_level + start_idx) 93 | start_idx = end_idx 94 | candidate_idxs = torch.cat(candidate_idxs, dim=0) 95 | 96 | # get corresponding iou for the these candidates, and compute the 97 | # mean and std, set mean + std as the iou threshold 98 | candidate_overlaps = overlaps[candidate_idxs, torch.arange(num_gt)] 99 | overlaps_mean_per_gt = candidate_overlaps.mean(0) 100 | overlaps_std_per_gt = candidate_overlaps.std(0) 101 | overlaps_thr_per_gt = overlaps_mean_per_gt + overlaps_std_per_gt 102 | 103 | is_pos = candidate_overlaps >= overlaps_thr_per_gt[None, :] 104 | 105 | # limit the positive sample's center in gt 106 | for gt_idx in range(num_gt): 107 | candidate_idxs[:, gt_idx] += gt_idx * num_bboxes 108 | ep_bboxes_cx = bboxes_cx.view(1, -1).expand(num_gt, num_bboxes).contiguous().view(-1) 109 | ep_bboxes_cy = bboxes_cy.view(1, -1).expand(num_gt, num_bboxes).contiguous().view(-1) 110 | candidate_idxs = candidate_idxs.view(-1) 111 | 112 | # calculate the left, top, right, bottom distance between positive 113 | # bbox center and gt side 114 | l_ = ep_bboxes_cx[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 0] 115 | t_ = ep_bboxes_cy[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 1] 116 | r_ = gt_bboxes[:, 2] - ep_bboxes_cx[candidate_idxs].view(-1, num_gt) 117 | b_ = gt_bboxes[:, 3] - ep_bboxes_cy[candidate_idxs].view(-1, num_gt) 118 | is_in_gts = torch.stack([l_, t_, r_, b_], dim=1).min(dim=1)[0] > 0.01 119 | is_pos = is_pos & is_in_gts 120 | 121 | # if an anchor box is assigned to multiple gts, 122 | # the one with the highest IoU will be selected. 
123 | overlaps_inf = torch.full_like(overlaps, -INF).t().contiguous().view(-1) 124 | index = candidate_idxs.view(-1)[is_pos.view(-1)] 125 | overlaps_inf[index] = overlaps.t().contiguous().view(-1)[index] 126 | overlaps_inf = overlaps_inf.view(num_gt, -1).t() 127 | 128 | max_overlaps, argmax_overlaps = overlaps_inf.max(dim=1) 129 | assigned_gt_inds[max_overlaps != -INF] = argmax_overlaps[max_overlaps != -INF] + 1 130 | 131 | if gt_labels is not None: 132 | assigned_labels = assigned_gt_inds.new_full((num_bboxes,), -1) 133 | pos_inds = torch.nonzero(assigned_gt_inds > 0, as_tuple=False).squeeze() 134 | if pos_inds.numel() > 0: 135 | assigned_labels[pos_inds] = gt_labels[assigned_gt_inds[pos_inds] - 1] 136 | else: 137 | assigned_labels = None 138 | return AssignResult(num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels) 139 | -------------------------------------------------------------------------------- /gwd/backbones/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amirassov/kaggle-global-wheat-detection/b26295ea257f73089f1a067b70b4a7ee638f6b83/gwd/backbones/__init__.py -------------------------------------------------------------------------------- /gwd/backbones/res2net.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from mmcv.cnn import constant_init, kaiming_init 3 | from mmcv.runner import load_checkpoint 4 | from torch.nn.modules.batchnorm import _BatchNorm 5 | 6 | from mmdet.models.backbones.res2net import Res2Net 7 | from mmdet.models.backbones.resnet import BasicBlock, Bottleneck 8 | from mmdet.models.builder import BACKBONES 9 | from mmdet.utils import get_root_logger 10 | 11 | 12 | @BACKBONES.register_module() 13 | class FixedRes2Net(Res2Net): 14 | def init_weights(self, pretrained=None): 15 | if isinstance(pretrained, str): 16 | logger = get_root_logger() 17 | load_checkpoint(self, pretrained, strict=False, logger=logger) 18 | elif pretrained is None: 19 | for m in self.modules(): 20 | if isinstance(m, nn.Conv2d): 21 | kaiming_init(m) 22 | elif isinstance(m, (_BatchNorm, nn.GroupNorm)): 23 | constant_init(m, 1) 24 | 25 | if self.dcn is not None: 26 | for m in self.modules(): 27 | if isinstance(m, Bottleneck): 28 | for conv in m.convs: 29 | if hasattr(conv, "conv_offset"): 30 | constant_init(conv.conv_offset, 0) 31 | 32 | if self.zero_init_residual: 33 | for m in self.modules(): 34 | if isinstance(m, Bottleneck): 35 | constant_init(m.norm3, 0) 36 | elif isinstance(m, BasicBlock): 37 | constant_init(m.norm2, 0) 38 | else: 39 | raise TypeError("pretrained must be a str or None") 40 | -------------------------------------------------------------------------------- /gwd/backbones/sac/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amirassov/kaggle-global-wheat-detection/b26295ea257f73089f1a067b70b4a7ee638f6b83/gwd/backbones/sac/__init__.py -------------------------------------------------------------------------------- /gwd/backbones/sac/conv_aws.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from mmcv.cnn import CONV_LAYERS 4 | 5 | 6 | @CONV_LAYERS.register_module("ConvAWSV1") 7 | class ConvAWS2d(nn.Conv2d): 8 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): 9 | super().__init__( 10 | in_channels, 11 
| out_channels, 12 | kernel_size, 13 | stride=stride, 14 | padding=padding, 15 | dilation=dilation, 16 | groups=groups, 17 | bias=bias, 18 | ) 19 | self.register_buffer("weight_gamma", torch.ones(self.out_channels, 1, 1, 1)) 20 | self.register_buffer("weight_beta", torch.zeros(self.out_channels, 1, 1, 1)) 21 | 22 | def _get_weight(self, weight): 23 | weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True) 24 | weight = weight - weight_mean 25 | std = torch.sqrt(weight.view(weight.size(0), -1).var(dim=1) + 1e-5).view(-1, 1, 1, 1) 26 | weight = weight / std 27 | weight = self.weight_gamma * weight + self.weight_beta 28 | return weight 29 | 30 | def forward(self, x): 31 | weight = self._get_weight(self.weight) 32 | return super()._conv_forward(x, weight) 33 | 34 | def _load_from_state_dict( 35 | self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs 36 | ): 37 | self.weight_gamma.data.fill_(-1) 38 | super()._load_from_state_dict( 39 | state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs 40 | ) 41 | if self.weight_gamma.data.mean() > 0: 42 | return 43 | weight = self.weight.data 44 | weight_mean = weight.data.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True) 45 | self.weight_beta.data.copy_(weight_mean) 46 | std = torch.sqrt(weight.view(weight.size(0), -1).var(dim=1) + 1e-5).view(-1, 1, 1, 1) 47 | self.weight_gamma.data.copy_(std) 48 | -------------------------------------------------------------------------------- /gwd/backbones/sac/saconv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from mmcv.cnn import CONV_LAYERS 3 | 4 | from mmdet.ops.dcn import deform_conv 5 | 6 | from .conv_aws import ConvAWS2d 7 | 8 | 9 | @CONV_LAYERS.register_module("SACV1") 10 | class SAConv2d(ConvAWS2d): 11 | def __init__( 12 | self, 13 | in_channels, 14 | out_channels, 15 | kernel_size, 16 | stride=1, 17 | padding=0, 18 | dilation=1, 19 | groups=1, 20 | bias=True, 21 | use_deform=False, 22 | ): 23 | super().__init__( 24 | in_channels, 25 | out_channels, 26 | kernel_size, 27 | stride=stride, 28 | padding=padding, 29 | dilation=dilation, 30 | groups=groups, 31 | bias=bias, 32 | ) 33 | self.use_deform = use_deform 34 | self.switch = torch.nn.Conv2d(self.in_channels, 1, kernel_size=1, stride=stride, bias=True) 35 | self.switch.weight.data.fill_(0) 36 | self.switch.bias.data.fill_(1) 37 | self.weight_diff = torch.nn.Parameter(torch.Tensor(self.weight.size())) 38 | self.weight_diff.data.zero_() 39 | self.pre_context = torch.nn.Conv2d(self.in_channels, self.in_channels, kernel_size=1, bias=True) 40 | self.pre_context.weight.data.fill_(0) 41 | self.pre_context.bias.data.fill_(0) 42 | self.post_context = torch.nn.Conv2d(self.out_channels, self.out_channels, kernel_size=1, bias=True) 43 | self.post_context.weight.data.fill_(0) 44 | self.post_context.bias.data.fill_(0) 45 | if self.use_deform: 46 | self.offset_s = torch.nn.Conv2d(self.in_channels, 18, kernel_size=3, padding=1, stride=stride, bias=True) 47 | self.offset_l = torch.nn.Conv2d(self.in_channels, 18, kernel_size=3, padding=1, stride=stride, bias=True) 48 | self.offset_s.weight.data.fill_(0) 49 | self.offset_s.bias.data.fill_(0) 50 | self.offset_l.weight.data.fill_(0) 51 | self.offset_l.bias.data.fill_(0) 52 | 53 | def forward(self, x): 54 | # pre-context 55 | avg_x = torch.nn.functional.adaptive_avg_pool2d(x, output_size=1) 56 | avg_x = self.pre_context(avg_x) 57 | 
avg_x = avg_x.expand_as(x) 58 | x = x + avg_x 59 | # switch 60 | avg_x = torch.nn.functional.pad(x, pad=(2, 2, 2, 2), mode="reflect") 61 | avg_x = torch.nn.functional.avg_pool2d(avg_x, kernel_size=5, stride=1, padding=0) 62 | switch = self.switch(avg_x) 63 | # sac 64 | weight = self._get_weight(self.weight) 65 | if self.use_deform: 66 | offset = self.offset_s(avg_x) 67 | out_s = deform_conv(x, offset, weight, self.stride, self.padding, self.dilation, self.groups, 1) 68 | else: 69 | out_s = super()._conv_forward(x, weight) 70 | ori_p = self.padding 71 | ori_d = self.dilation 72 | self.padding = tuple(3 * p for p in self.padding) 73 | self.dilation = tuple(3 * d for d in self.dilation) 74 | weight = weight + self.weight_diff 75 | if self.use_deform: 76 | offset = self.offset_l(avg_x) 77 | out_l = deform_conv(x, offset, weight, self.stride, self.padding, self.dilation, self.groups, 1) 78 | else: 79 | out_l = super()._conv_forward(x, weight) 80 | out = switch * out_s + (1 - switch) * out_l 81 | self.padding = ori_p 82 | self.dilation = ori_d 83 | # post-context 84 | avg_x = torch.nn.functional.adaptive_avg_pool2d(out, output_size=1) 85 | avg_x = self.post_context(avg_x) 86 | avg_x = avg_x.expand_as(out) 87 | out = out + avg_x 88 | return out 89 | -------------------------------------------------------------------------------- /gwd/colorization/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amirassov/kaggle-global-wheat-detection/b26295ea257f73089f1a067b70b4a7ee638f6b83/gwd/colorization/__init__.py -------------------------------------------------------------------------------- /gwd/colorization/generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | from glob import glob 5 | 6 | import cv2 7 | import numpy as np 8 | import torch 9 | import torchvision.transforms as transforms 10 | from torchvision.datasets.folder import pil_loader 11 | from tqdm import tqdm 12 | 13 | from gwd.colorization.models import GeneratorUNet 14 | 15 | IMG_SIZE = 1024 16 | 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("--img_pattern", default="/data/SPIKE_Dataset/images/*jpg") 21 | parser.add_argument("--weights_path", default="/dumps/pix2pix_gen.pth") 22 | parser.add_argument("--output_root", default="/data/SPIKE_Dataset/colored_images") 23 | return parser.parse_args() 24 | 25 | 26 | def generate(model, img_path, transform): 27 | image = pil_loader(img_path) 28 | image = transform(image) 29 | with torch.no_grad(): 30 | fake_image = model(image.unsqueeze(0).cuda()) 31 | fake_image = fake_image.cpu().numpy()[0].transpose(1, 2, 0) 32 | fake_image -= fake_image.min() 33 | fake_image /= fake_image.max() 34 | fake_image *= 255 35 | fake_image = fake_image.astype(np.uint8) 36 | fake_image = cv2.UMat(fake_image).get() 37 | return fake_image 38 | 39 | 40 | def main(): 41 | args = parse_args() 42 | os.makedirs(args.output_root, exist_ok=True) 43 | img_paths = glob(args.img_pattern) 44 | model = GeneratorUNet().cuda() 45 | model.load_state_dict(torch.load(args.weights_path, map_location="cpu")) 46 | model.eval() 47 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 48 | for img_path in tqdm(img_paths): 49 | fake_image = generate(model, img_path, transform) 50 | output_path = osp.join(args.output_root, osp.basename(img_path)) 51 | 
cv2.imwrite(output_path, fake_image[..., ::-1]) 52 | 53 | 54 | if __name__ == "__main__": 55 | main() 56 | -------------------------------------------------------------------------------- /gwd/colorization/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def weights_init_normal(m): 6 | classname = m.__class__.__name__ 7 | if classname.find("Conv") != -1: 8 | torch.nn.init.normal_(m.weight.data, 0.0, 0.02) 9 | elif classname.find("BatchNorm2d") != -1: 10 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02) 11 | torch.nn.init.constant_(m.bias.data, 0.0) 12 | 13 | 14 | class UNetDown(nn.Module): 15 | def __init__(self, in_size, out_size, normalize=True, dropout=0.0): 16 | super(UNetDown, self).__init__() 17 | layers = [nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)] 18 | if normalize: 19 | layers.append(nn.InstanceNorm2d(out_size)) 20 | layers.append(nn.LeakyReLU(0.2)) 21 | if dropout: 22 | layers.append(nn.Dropout(dropout)) 23 | self.model = nn.Sequential(*layers) 24 | 25 | def forward(self, x): 26 | return self.model(x) 27 | 28 | 29 | class UNetUp(nn.Module): 30 | def __init__(self, in_size, out_size, dropout=0.0): 31 | super(UNetUp, self).__init__() 32 | layers = [ 33 | nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False), 34 | nn.InstanceNorm2d(out_size), 35 | nn.ReLU(inplace=True), 36 | ] 37 | if dropout: 38 | layers.append(nn.Dropout(dropout)) 39 | 40 | self.model = nn.Sequential(*layers) 41 | 42 | def forward(self, x, skip_input): 43 | x = self.model(x) 44 | x = torch.cat((x, skip_input), 1) 45 | 46 | return x 47 | 48 | 49 | class GeneratorUNet(nn.Module): 50 | def __init__(self, in_channels=3, out_channels=3): 51 | super(GeneratorUNet, self).__init__() 52 | 53 | self.down1 = UNetDown(in_channels, 64, normalize=False) 54 | self.down2 = UNetDown(64, 128) 55 | self.down3 = UNetDown(128, 256) 56 | self.down4 = UNetDown(256, 512, dropout=0.5) 57 | self.down5 = UNetDown(512, 512, dropout=0.5) 58 | self.down6 = UNetDown(512, 512, dropout=0.5) 59 | self.down7 = UNetDown(512, 512, dropout=0.5) 60 | self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5) 61 | 62 | self.up1 = UNetUp(512, 512, dropout=0.5) 63 | self.up2 = UNetUp(1024, 512, dropout=0.5) 64 | self.up3 = UNetUp(1024, 512, dropout=0.5) 65 | self.up4 = UNetUp(1024, 512, dropout=0.5) 66 | self.up5 = UNetUp(1024, 256) 67 | self.up6 = UNetUp(512, 128) 68 | self.up7 = UNetUp(256, 64) 69 | 70 | self.final = nn.Sequential( 71 | nn.Upsample(scale_factor=2), 72 | nn.ZeroPad2d((1, 0, 1, 0)), 73 | nn.Conv2d(128, out_channels, 4, padding=1), 74 | nn.Tanh(), 75 | ) 76 | 77 | def forward(self, x): 78 | # U-Net generator with skip connections from encoder to decoder 79 | d1 = self.down1(x) 80 | d2 = self.down2(d1) 81 | d3 = self.down3(d2) 82 | d4 = self.down4(d3) 83 | d5 = self.down5(d4) 84 | d6 = self.down6(d5) 85 | d7 = self.down7(d6) 86 | d8 = self.down8(d7) 87 | u1 = self.up1(d8, d7) 88 | u2 = self.up2(u1, d6) 89 | u3 = self.up3(u2, d5) 90 | u4 = self.up4(u3, d4) 91 | u5 = self.up5(u4, d3) 92 | u6 = self.up6(u5, d2) 93 | u7 = self.up7(u6, d1) 94 | 95 | return self.final(u7) 96 | 97 | 98 | ############################## 99 | # Discriminator 100 | ############################## 101 | 102 | 103 | class Discriminator(nn.Module): 104 | def __init__(self, in_channels=3): 105 | super(Discriminator, self).__init__() 106 | 107 | def discriminator_block(in_filters, out_filters, normalization=True): 108 | """Returns downsampling layers of each 
discriminator block""" 109 | layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)] 110 | if normalization: 111 | layers.append(nn.InstanceNorm2d(out_filters)) 112 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 113 | return layers 114 | 115 | self.model = nn.Sequential( 116 | *discriminator_block(in_channels * 2, 64, normalization=False), 117 | *discriminator_block(64, 128), 118 | *discriminator_block(128, 256), 119 | *discriminator_block(256, 512), 120 | nn.ZeroPad2d((1, 0, 1, 0)), 121 | nn.Conv2d(512, 1, 4, padding=1, bias=False), 122 | ) 123 | 124 | def forward(self, img_A, img_B): 125 | # Concatenate image and condition image by channels to produce input 126 | img_input = torch.cat((img_A, img_B), 1) 127 | return self.model(img_input) 128 | -------------------------------------------------------------------------------- /gwd/converters/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amirassov/kaggle-global-wheat-detection/b26295ea257f73089f1a067b70b4a7ee638f6b83/gwd/converters/__init__.py -------------------------------------------------------------------------------- /gwd/converters/coco2crop.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | 5 | import mmcv 6 | from pycocotools.coco import COCO 7 | from tqdm import tqdm 8 | 9 | SCORE_THRESHOLD = 0.8 10 | 11 | 12 | def parse_args(): 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--img_root", default="/data/test") 15 | parser.add_argument("--annotation_path", default="/data/rfp_r50_ga_mstrain_pseudo_stage1_test_predictions.json") 16 | parser.add_argument("--output_root", default="/data/test_wheat_crops") 17 | parser.add_argument("--from_predict", action="store_true") 18 | return parser.parse_args() 19 | 20 | 21 | def main(): 22 | args = parse_args() 23 | os.makedirs(args.output_root, exist_ok=True) 24 | dataset = COCO(args.annotation_path) 25 | for img_id, img_info in tqdm(dataset.imgs.items()): 26 | filename = img_info["file_name"] 27 | img = mmcv.imread(osp.join(args.img_root, filename)) 28 | ann_ids = dataset.getAnnIds(imgIds=[img_id]) 29 | ann_info = dataset.loadAnns(ann_ids) 30 | for i, ann in enumerate(ann_info): 31 | x1, y1, w, h = map(int, ann["bbox"]) 32 | category_id = ann["category_id"] 33 | crop = img[y1 : y1 + h, x1 : x1 + w] 34 | output_path = osp.join(args.output_root, f"{x1}_{y1}_{w}_{h}_{filename}") 35 | if w > 6 and h > 6: 36 | if args.from_predict: 37 | if category_id == 2 and ann["score"] > SCORE_THRESHOLD: # prediction 38 | mmcv.imwrite(crop, output_path) 39 | elif category_id == 1 and img_info["source"] in ["arvalis_1", "arvalis_3", "rres_1", "inrae_1"]: 40 | mmcv.imwrite(crop, output_path) 41 | 42 | 43 | if __name__ == "__main__": 44 | main() 45 | -------------------------------------------------------------------------------- /gwd/converters/images2coco.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os.path as osp 3 | from glob import glob 4 | from multiprocessing import Pool 5 | 6 | import cv2 7 | import mmcv 8 | from tqdm import tqdm 9 | 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--img_pattern", default="/data/test/*jpg") 14 | parser.add_argument("--output_path", default="/data/coco_test.json") 15 | parser.add_argument("--n_jobs", type=int, default=16) 16 | return 
parser.parse_args() 17 | 18 | 19 | def convert(id_path): 20 | img_id, img_path = id_path 21 | image = cv2.imread(img_path) 22 | h, w = image.shape[:2] 23 | return {"id": img_id, "height": h, "width": w, "file_name": osp.basename(img_path)} 24 | 25 | 26 | def main(img_pattern, output_path, n_jobs=16): 27 | img_paths = glob(img_pattern) 28 | with Pool(n_jobs) as p: 29 | coco_images = list( 30 | tqdm( 31 | iterable=p.imap_unordered(convert, enumerate(img_paths)), total=len(img_paths), desc="Images to COCO..." 32 | ) 33 | ) 34 | 35 | mmcv.dump( 36 | { 37 | "annotations": [], 38 | "images": coco_images, 39 | "categories": [{"supercategory": "wheat", "name": "wheat", "id": 1}], 40 | }, 41 | output_path, 42 | ) 43 | 44 | 45 | if __name__ == "__main__": 46 | main(**vars(parse_args())) 47 | -------------------------------------------------------------------------------- /gwd/converters/kaggle2coco.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import ast 3 | from copy import deepcopy 4 | from typing import Any, Dict, List, Tuple 5 | 6 | import mmcv 7 | import pandas as pd 8 | from tqdm import tqdm 9 | 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--annotation_path", default="/data/folds_v2/0/clean_tile_train.csv") 14 | parser.add_argument("--output_path", default="/data/folds_v2/0/coco_clean_tile_train.json") 15 | return parser.parse_args() 16 | 17 | 18 | def group2coco(image_name: str, group: pd.DataFrame) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]: 19 | coco_annotations = [] 20 | for _, row in group.iterrows(): 21 | x_min, y_min, width, height = row["bbox"] 22 | category_id = 1 23 | is_ignore = row.get("ignore", False) 24 | coco_annotations.append( 25 | { 26 | "segmentation": "", 27 | "area": float(width * height), 28 | "category_id": category_id, 29 | "bbox": [float(x_min), float(y_min), float(width), float(height)], 30 | "iscrowd": int(is_ignore), 31 | } 32 | ) 33 | 34 | return ( 35 | { 36 | "width": int(group["width"].iloc[0]), 37 | "height": int(group["height"].iloc[0]), 38 | "file_name": image_name, 39 | "source": group["source"].iloc[0], 40 | }, 41 | coco_annotations, 42 | ) 43 | 44 | 45 | def main(annotation_path, output_path, exclude_sources=None): 46 | annotations = pd.read_csv(annotation_path, converters={"bbox": ast.literal_eval}) 47 | if exclude_sources is not None: 48 | annotations = annotations[~annotations["source"].isin(exclude_sources)] 49 | annotations["image_name"] = annotations["image_id"].apply(lambda x: f"{x}.jpg") 50 | 51 | coco_annotations = [] 52 | coco_images = [] 53 | image_groups = annotations.groupby("image_id") 54 | 55 | for i, (image_name, group) in tqdm(enumerate(annotations.groupby("image_name")), total=len(image_groups)): 56 | image_info, image_annotations = group2coco(image_name, group) 57 | 58 | image_info["id"] = i 59 | for ann in image_annotations: 60 | ann["image_id"] = i 61 | coco_images.append(deepcopy(image_info)) 62 | coco_annotations.extend(deepcopy(image_annotations)) 63 | 64 | for i, ann in enumerate(coco_annotations): 65 | ann["id"] = i 66 | 67 | print(f"Length of images: {len(coco_images)}") 68 | print(f"Length of annotations: {len(coco_annotations)}") 69 | print(f"Length set image id: {len(set([x['id'] for x in coco_images]))}") 70 | print(f"Max image id: {max([x['id'] for x in coco_images])}") 71 | print(coco_images[0]) 72 | print(coco_annotations[0]) 73 | 74 | mmcv.dump( 75 | { 76 | "annotations": coco_annotations, 77 | "images": 
coco_images, 78 | "categories": [{"supercategory": "wheat", "name": "wheat", "id": 1}], 79 | }, 80 | output_path, 81 | ) 82 | 83 | 84 | if __name__ == "__main__": 85 | main(**vars(parse_args())) 86 | -------------------------------------------------------------------------------- /gwd/converters/spike2kaggle.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os.path as osp 3 | from glob import glob 4 | 5 | import cv2 6 | import pandas as pd 7 | from tqdm import tqdm 8 | 9 | from gwd.converters import kaggle2coco 10 | 11 | 12 | def parse_args(): 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--image-pattern", default="/data/SPIKE_images/*jpg") 15 | parser.add_argument("--annotation-root", default="/data/SPIKE_annotations") 16 | parser.add_argument("--kaggle_output_path", default="/data/spike.csv") 17 | parser.add_argument("--coco_output_path", default="/data/coco_spike.json") 18 | return parser.parse_args() 19 | 20 | 21 | def main(): 22 | args = parse_args() 23 | img_paths = glob(args.image_pattern) 24 | annotations = [] 25 | for img_path in tqdm(img_paths): 26 | ann_path = osp.join(args.annotation_root, (osp.basename(img_path.replace("jpg", "bboxes.tsv")))) 27 | ann = pd.read_csv(ann_path, sep="\t", names=["x_min", "y_min", "x_max", "y_max"]) 28 | h, w = cv2.imread(img_path).shape[:2] 29 | ann[["x_min", "x_max"]] = ann[["x_min", "x_max"]].clip(0, w) 30 | ann[["y_min", "y_max"]] = ann[["y_min", "y_max"]].clip(0, h) 31 | ann["height"] = h 32 | ann["width"] = w 33 | ann["bbox_width"] = ann["x_max"] - ann["x_min"] 34 | ann["bbox_height"] = ann["y_max"] - ann["y_min"] 35 | ann = ann[(ann["bbox_width"] > 0) & (ann["bbox_height"] > 0)].copy() 36 | ann["bbox"] = ann[["x_min", "y_min", "bbox_width", "bbox_height"]].values.tolist() 37 | ann["image_id"] = osp.basename(img_path).split(".")[0] 38 | annotations.append(ann) 39 | annotations = pd.concat(annotations) 40 | annotations["source"] = "spike" 41 | print(annotations.head()) 42 | annotations[["image_id", "source", "width", "height", "bbox"]].to_csv(args.kaggle_output_path, index=False) 43 | kaggle2coco.main(args.kaggle_output_path, args.coco_output_path) 44 | 45 | 46 | if __name__ == "__main__": 47 | main() 48 | -------------------------------------------------------------------------------- /gwd/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amirassov/kaggle-global-wheat-detection/b26295ea257f73089f1a067b70b4a7ee638f6b83/gwd/datasets/__init__.py -------------------------------------------------------------------------------- /gwd/datasets/evaluation.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Pool 2 | 3 | import numpy as np 4 | from mmcv.utils import print_log 5 | 6 | from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps 7 | from mmdet.core.evaluation.mean_ap import get_cls_results 8 | 9 | 10 | def calc_tpfpfn(det_bboxes, gt_bboxes, iou_thr=0.5): 11 | """Check if detected bboxes are true positive or false positive and if gt bboxes are false negative. 12 | 13 | Args: 14 | det_bboxes (ndarray): Detected bboxes of this image, of shape (m, 5). 15 | gt_bboxes (ndarray): GT bboxes of this image, of shape (n, 4). 16 | iou_thr (float): IoU threshold to be considered as matched. 17 | Default: 0.5. 18 | 19 | Returns: 20 | float: (tp, fp, fn). 
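        Example (illustrative only; values chosen so that the single detection
        matches the single ground-truth box exactly)::

            >>> dets = np.array([[0., 0., 10., 10., 0.9]])  # [x1, y1, x2, y2, score]
            >>> gts = np.array([[0., 0., 10., 10.]])
            >>> calc_tpfpfn(dets, gts, iou_thr=0.5)  # -> tp=1, fp=0, fn=0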
21 | """ 22 | num_dets = det_bboxes.shape[0] 23 | num_gts = gt_bboxes.shape[0] 24 | tp = 0 25 | fp = 0 26 | 27 | # if there is no gt bboxes in this image, then all det bboxes 28 | # within area range are false positives 29 | if num_gts == 0: 30 | fp = num_dets 31 | return tp, fp, 0 32 | 33 | ious: np.ndarray = bbox_overlaps(det_bboxes, gt_bboxes) 34 | # sort all dets in descending order by scores 35 | sort_inds = np.argsort(-det_bboxes[:, -1]) 36 | gt_covered = np.zeros(num_gts, dtype=bool) 37 | for i in sort_inds: 38 | uncovered_ious = ious[i, gt_covered == 0] 39 | if len(uncovered_ious): 40 | iou_argmax = uncovered_ious.argmax() 41 | iou_max = uncovered_ious[iou_argmax] 42 | if iou_max >= iou_thr: 43 | gt_covered[[x[iou_argmax] for x in np.where(gt_covered == 0)]] = True 44 | tp += 1 45 | else: 46 | fp += 1 47 | else: 48 | fp += 1 49 | fn = (gt_covered == 0).sum() 50 | return tp, fp, fn 51 | 52 | 53 | def kaggle_map( 54 | det_results, annotations, iou_thrs=(0.5, 0.55, 0.6, 0.65, 0.7, 0.75), logger=None, n_jobs=4, by_sample=False 55 | ): 56 | """Evaluate kaggle mAP of a dataset. 57 | 58 | Args: 59 | det_results (list[list]): [[cls1_det, cls2_det, ...], ...]. 60 | The outer list indicates images, and the inner list indicates 61 | per-class detected bboxes. 62 | annotations (list[dict]): Ground truth annotations where each item of 63 | the list indicates an image. Keys of annotations are: 64 | 65 | - `bboxes`: numpy array of shape (n, 4) 66 | - `labels`: numpy array of shape (n, ) 67 | - `bboxes_ignore` (optional): numpy array of shape (k, 4) 68 | - `labels_ignore` (optional): numpy array of shape (k, ) 69 | iou_thrs (list): IoU thresholds to be considered as matched. 70 | Default: (0.5, 0.55, 0.6, 0.65, 0.7, 0.75). 71 | logger (logging.Logger | str | None): The way to print the mAP 72 | summary. See `mmdet.utils.print_log()` for details. Default: None. 73 | n_jobs (int): Processes used for computing TP, FP and FN. 74 | Default: 4. 75 | by_sample (bool): Return AP by sample. 
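        Example (illustrative sketch with a single image and a single class; with a
        perfect prediction the returned mAP is 1.0)::

            >>> det_results = [[np.array([[0., 0., 10., 10., 0.9]])]]
            >>> annotations = [dict(bboxes=np.array([[0., 0., 10., 10.]]), labels=np.array([0]))]
            >>> mean_ap, per_class = kaggle_map(det_results, annotations, iou_thrs=(0.5,), n_jobs=1)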
76 | 77 | Returns: 78 | tuple: (mAP, [dict, dict, ...]) 79 | """ 80 | assert len(det_results) == len(annotations) 81 | 82 | num_imgs = len(det_results) 83 | num_classes = len(det_results[0]) # positive class num 84 | 85 | pool = Pool(n_jobs) 86 | eval_results = [] 87 | for i in range(num_classes): 88 | # get gt and det bboxes of this class 89 | cls_dets, cls_gts, _ = get_cls_results(det_results, annotations, i) 90 | # compute tp and fp for each image with multiple processes 91 | aps_by_thrs = [] 92 | aps_by_sample = np.zeros(num_imgs) 93 | for iou_thr in iou_thrs: 94 | tpfpfn = pool.starmap(calc_tpfpfn, zip(cls_dets, cls_gts, [iou_thr for _ in range(num_imgs)])) 95 | iou_thr_aps = np.array([tp / (tp + fp + fn) for tp, fp, fn in tpfpfn]) 96 | if by_sample: 97 | aps_by_sample += iou_thr_aps 98 | aps_by_thrs.append(np.mean(iou_thr_aps)) 99 | eval_results.append( 100 | { 101 | "num_gts": len(cls_gts), 102 | "num_dets": len(cls_dets), 103 | "ap": np.mean(aps_by_thrs), 104 | "ap_by_sample": None if not by_sample else aps_by_sample / len(iou_thrs), 105 | } 106 | ) 107 | pool.close() 108 | 109 | aps = [] 110 | for cls_result in eval_results: 111 | if cls_result["num_gts"] > 0: 112 | aps.append(cls_result["ap"]) 113 | mean_ap = np.array(aps).mean().item() if aps else 0.0 114 | 115 | print_log(f"\nKaggle mAP: {mean_ap}", logger=logger) 116 | return mean_ap, eval_results 117 | -------------------------------------------------------------------------------- /gwd/datasets/pipelines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amirassov/kaggle-global-wheat-detection/b26295ea257f73089f1a067b70b4a7ee638f6b83/gwd/datasets/pipelines/__init__.py -------------------------------------------------------------------------------- /gwd/datasets/pipelines/albumentations.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import random 3 | import sys 4 | 5 | import albumentations as A 6 | import cv2 7 | import mmcv 8 | import numpy as np 9 | from albumentations.augmentations.bbox_utils import union_of_bboxes 10 | from albumentations.augmentations.transforms import F 11 | 12 | from mmdet.datasets.builder import PIPELINES 13 | from mmdet.datasets.pipelines import Albu 14 | 15 | from .utils import calculate_area, calculate_aspect_ratios 16 | 17 | 18 | @F.preserve_channel_dim 19 | def stretch(img, w_scale, h_scale, interpolation=cv2.INTER_LINEAR): 20 | height, width = img.shape[:2] 21 | new_height, new_width = int(height * h_scale), int(width * w_scale) 22 | return F.resize(img, new_height, new_width, interpolation) 23 | 24 | 25 | class RandomStretch(A.RandomScale): 26 | def get_params(self): 27 | return { 28 | "w_scale": random.uniform(self.scale_limit[0], self.scale_limit[1]), 29 | "h_scale": random.uniform(self.scale_limit[0], self.scale_limit[1]), 30 | } 31 | 32 | def apply(self, img, w_scale=1.0, h_scale=1.0, interpolation=cv2.INTER_LINEAR, **params): 33 | return stretch(img, w_scale, h_scale, interpolation) 34 | 35 | 36 | class ModifiedShiftScaleRotate(A.ShiftScaleRotate): 37 | def get_params(self): 38 | self.params = super().get_params() 39 | return self.params 40 | 41 | 42 | class RandomBBoxesSafeCrop(A.DualTransform): 43 | def __init__(self, num_rate=(0.1, 1.0), erosion_rate=0.0, min_edge_ratio=0.5, always_apply=False, p=1.0): 44 | super(RandomBBoxesSafeCrop, self).__init__(always_apply, p) 45 | self.erosion_rate = erosion_rate 46 | self.num_rate = num_rate 47 | 
self.min_edge_ratio = min_edge_ratio 48 | 49 | def apply(self, img, crop_height=0, crop_width=0, h_start=0, w_start=0, **params): 50 | return F.random_crop(img, crop_height, crop_width, h_start, w_start) 51 | 52 | def get_params_dependent_on_targets(self, params): 53 | img_h, img_w = params["image"].shape[:2] 54 | if len(params["bboxes"]) == 0: # less likely, this class is for use with bboxes. 55 | erosive_h = int(img_h * (1.0 - self.erosion_rate)) 56 | crop_height = img_h if erosive_h >= img_h else random.randint(erosive_h, img_h) 57 | return { 58 | "h_start": random.random(), 59 | "w_start": random.random(), 60 | "crop_height": crop_height, 61 | "crop_width": int(crop_height * img_w / img_h), 62 | } 63 | x, y, x2, y2 = union_of_bboxes(width=1.0, height=1.0, bboxes=params["bboxes"]) 64 | if x2 - x >= self.min_edge_ratio and y2 - y >= self.min_edge_ratio: 65 | for i in range(50): 66 | # get union of all bboxes 67 | x, y, x2, y2 = union_of_bboxes( 68 | width=1.0, 69 | height=1.0, 70 | bboxes=random.choices( 71 | params["bboxes"], k=max(int(random.uniform(*self.num_rate) * len(params["bboxes"])), 1) 72 | ), 73 | ) 74 | # find bigger region 75 | bx, by = ( 76 | x * random.uniform(1 - self.erosion_rate, 1.0), 77 | y * random.uniform(1 - self.erosion_rate, 1.0), 78 | ) 79 | bx2, by2 = ( 80 | x2 + (1 - x2) * random.uniform(1 - self.erosion_rate, 1.0), 81 | y2 + (1 - y2) * random.uniform(1 - self.erosion_rate, 1.0), 82 | ) 83 | bw, bh = bx2 - bx, by2 - by 84 | crop_height = img_h if bh >= 1.0 else int(img_h * bh) 85 | crop_width = img_w if bw >= 1.0 else int(img_w * bw) 86 | 87 | if crop_height / crop_width < 0.5 or crop_height / crop_width > 2: 88 | continue 89 | 90 | h_start = np.clip(0.0 if bh >= 1.0 else by / (1.0 - bh), 0.0, 1.0) 91 | w_start = np.clip(0.0 if bw >= 1.0 else bx / (1.0 - bw), 0.0, 1.0) 92 | return {"h_start": h_start, "w_start": w_start, "crop_height": crop_height, "crop_width": crop_width} 93 | return {"h_start": 0, "w_start": 0, "crop_height": img_h, "crop_width": img_w} 94 | 95 | def apply_to_bbox(self, bbox, crop_height=0, crop_width=0, h_start=0, w_start=0, rows=0, cols=0, **params): 96 | return F.bbox_random_crop(bbox, crop_height, crop_width, h_start, w_start, rows, cols) 97 | 98 | @property 99 | def targets_as_params(self): 100 | return ["image", "bboxes"] 101 | 102 | def get_transform_init_args_names(self): 103 | return "erosion_rate", "num_rate", "min_edge_ratio" 104 | 105 | 106 | def albu_builder(cfg): 107 | """Import a module from albumentations. 108 | Inherits some of `build_from_cfg` logic. 109 | 110 | Args: 111 | cfg (dict): Config dict. It should at least contain the key "type". 112 | Returns: 113 | obj: The constructed object. 
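    Example (illustrative; ``HorizontalFlip`` is resolved from ``albumentations``
    itself, while ``RandomStretch`` is only found among the custom transforms
    defined above in this module)::

        >>> albu_builder(dict(type="HorizontalFlip", p=0.5))
        >>> albu_builder(dict(type="RandomStretch", scale_limit=0.2, p=0.5))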
114 | """ 115 | assert isinstance(cfg, dict) and "type" in cfg 116 | args = cfg.copy() 117 | 118 | obj_type = args.pop("type") 119 | if mmcv.is_str(obj_type): 120 | if A is None: 121 | raise RuntimeError("albumentations is not installed") 122 | if hasattr(A, obj_type): 123 | obj_cls = getattr(A, obj_type) 124 | else: 125 | obj_cls = getattr(sys.modules[__name__], obj_type) 126 | elif inspect.isclass(obj_type): 127 | obj_cls = obj_type 128 | else: 129 | raise TypeError(f"type must be a str or valid type, but got {type(obj_type)}") 130 | 131 | if "transforms" in args: 132 | args["transforms"] = [albu_builder(transform) for transform in args["transforms"]] 133 | 134 | return obj_cls(**args) 135 | 136 | 137 | @PIPELINES.register_module() 138 | class Albumentations(Albu): 139 | def __init__(self, min_visibility=0.3, min_size=4, max_aspect_ratio=10, **kwargs): 140 | super(Albumentations, self).__init__(**kwargs) 141 | self.min_visibility = min_visibility 142 | self.min_size = min_size 143 | self.max_aspect_ratio = max_aspect_ratio 144 | 145 | # Be careful: it is a dirty hack 146 | for i, t in enumerate(self.aug.transforms): 147 | if isinstance(t, ModifiedShiftScaleRotate): 148 | self.scale_index = i 149 | assert hasattr(self, "scale_index") 150 | 151 | def albu_builder(self, cfg): 152 | return albu_builder(cfg) 153 | 154 | def reset_scale_zero(self): 155 | self.aug.transforms[self.scale_index].params = {"scale": 1.0} 156 | 157 | def __call__(self, results): 158 | original_areas = calculate_area(results["gt_bboxes"]) 159 | results = self.mapper(results, self.keymap_to_albu) 160 | if isinstance(results["bboxes"], np.ndarray): 161 | results["bboxes"] = [x for x in results["bboxes"]] 162 | results["labels"] = np.arange(len(results["bboxes"])) 163 | 164 | self.reset_scale_zero() 165 | results = self.aug(**results) 166 | scale = self.aug.transforms[self.scale_index].params["scale"] 167 | if isinstance(results["bboxes"], list): 168 | results["bboxes"] = np.array(results["bboxes"], dtype=np.float32) 169 | 170 | if not len(results["bboxes"]) and self.skip_img_without_anno: 171 | return None 172 | 173 | original_areas = original_areas[results["labels"]] 174 | augmented_areas = calculate_area(results["bboxes"]) 175 | aspect_ratios = calculate_aspect_ratios(results["bboxes"]) 176 | widths = results["bboxes"][:, 2] - results["bboxes"][:, 0] 177 | heights = results["bboxes"][:, 3] - results["bboxes"][:, 1] 178 | size_mask = (widths > self.min_size) & (heights > self.min_size) 179 | area_mask = augmented_areas / (scale ** 2 * original_areas) > self.min_visibility 180 | aspect_ratio_mask = aspect_ratios < self.max_aspect_ratio 181 | mask = size_mask & area_mask & aspect_ratio_mask 182 | results["bboxes"] = results["bboxes"][mask] 183 | if not len(results["bboxes"]) and self.skip_img_without_anno: 184 | return None 185 | results["gt_labels"] = np.zeros(len(results["bboxes"])).astype(np.int64) 186 | 187 | # back to the original format 188 | results = self.mapper(results, self.keymap_back) 189 | # update final shape 190 | if self.update_pad_shape: 191 | results["pad_shape"] = results["img"].shape 192 | 193 | if results is not None: 194 | results["img_shape"] = results["img"].shape 195 | return results 196 | -------------------------------------------------------------------------------- /gwd/datasets/pipelines/loading.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import mmcv 4 | import numpy as np 5 | 6 | from mmdet.datasets.builder import 
PIPELINES 7 | from mmdet.datasets.pipelines.loading import LoadImageFromFile 8 | 9 | 10 | @PIPELINES.register_module() 11 | class MultipleLoadImageFromFile(LoadImageFromFile): 12 | def __call__(self, results): 13 | if self.file_client is None: 14 | self.file_client = mmcv.FileClient(**self.file_client_args) 15 | 16 | assert isinstance(results["img_prefix"], dict) 17 | img_prefix = np.random.choice(results["img_prefix"]["roots"], p=results["img_prefix"]["probabilities"]) 18 | filename = osp.join(img_prefix, results["img_info"]["filename"]) 19 | 20 | img_bytes = self.file_client.get(filename) 21 | img = mmcv.imfrombytes(img_bytes, flag=self.color_type) 22 | if self.to_float32: 23 | img = img.astype(np.float32) 24 | 25 | results["filename"] = filename 26 | results["ori_filename"] = results["img_info"]["filename"] 27 | results["img"] = img 28 | results["img_shape"] = img.shape 29 | results["ori_shape"] = img.shape 30 | results["img_fields"] = ["img"] 31 | return results 32 | -------------------------------------------------------------------------------- /gwd/datasets/pipelines/test_aug.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import mmcv 4 | 5 | from mmdet.datasets.builder import PIPELINES 6 | from mmdet.datasets.pipelines import MultiScaleFlipAug 7 | 8 | 9 | @PIPELINES.register_module() 10 | class ModifiedMultiScaleFlipAug(MultiScaleFlipAug): 11 | def __init__(self, rotate=False, rotate_factor=0, **kwargs): 12 | super(ModifiedMultiScaleFlipAug, self).__init__(**kwargs) 13 | self.rotate = rotate 14 | self.rotate_factor = rotate_factor if isinstance(rotate_factor, list) else [rotate_factor] 15 | assert mmcv.is_list_of(self.rotate_factor, int) 16 | if not self.rotate and self.rotate_factor != [0]: 17 | warnings.warn("rotate_factor has no effect when rotate is set to False") 18 | 19 | def __call__(self, results): 20 | aug_data = [] 21 | flip_args = [[False, None]] 22 | rotate_args = [[False, None]] 23 | if self.flip: 24 | flip_args += [[True, direction] for direction in self.flip_direction] 25 | if self.rotate: 26 | rotate_args += [[True, factor] for factor in self.rotate_factor] 27 | for scale in self.img_scale: 28 | for flip, direction in flip_args: 29 | for rotate, factor in rotate_args: 30 | _results = results.copy() 31 | _results[self.scale_key] = scale 32 | _results["flip"] = flip 33 | _results["flip_direction"] = direction 34 | _results["rotate"] = rotate 35 | _results["rotate_factor"] = factor 36 | data = self.transforms(_results) 37 | aug_data.append(data) 38 | # list of dict to dict of list 39 | aug_data_dict = {key: [] for key in aug_data[0]} 40 | for data in aug_data: 41 | for key, val in data.items(): 42 | aug_data_dict[key].append(val) 43 | return aug_data_dict 44 | -------------------------------------------------------------------------------- /gwd/datasets/pipelines/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def calculate_area(bboxes): 5 | return (bboxes[:, 2] - bboxes[:, 0]) * (bboxes[:, 3] - bboxes[:, 1]) 6 | 7 | 8 | def calculate_aspect_ratios(bboxes, eps=1e-6): 9 | return np.maximum( 10 | (bboxes[:, 2] - bboxes[:, 0]) / (bboxes[:, 3] - bboxes[:, 1] + eps), 11 | (bboxes[:, 3] - bboxes[:, 1]) / (bboxes[:, 2] - bboxes[:, 0] + eps), 12 | ) 13 | -------------------------------------------------------------------------------- /gwd/datasets/source_balanced_dataset.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | from collections import defaultdict 3 | 4 | from mmdet.datasets.builder import DATASETS 5 | from mmdet.datasets.dataset_wrappers import ClassBalancedDataset 6 | 7 | from .wheat_detection import WheatDataset 8 | 9 | 10 | @DATASETS.register_module() 11 | class SourceBalancedDataset(ClassBalancedDataset): 12 | def _get_repeat_factors(self, dataset: WheatDataset, repeat_thr: float): 13 | # 1. For each source s, compute the fraction # of images 14 | # that contain it: f(s) 15 | source_freq = defaultdict(float) 16 | num_images = len(dataset) 17 | for data_info in dataset.data_infos: 18 | source = data_info["source"] 19 | source_freq[source] += 1.0 20 | for k, v in source_freq.items(): 21 | source_freq[k] = v / num_images 22 | 23 | # 2. For each category c, compute the category-level repeat factor: 24 | # r(c) = max(1, sqrt(t/f(c))) 25 | source_repeat = { 26 | source: max(1.0, math.sqrt(repeat_thr / source_freq)) for source, source_freq in source_freq.items() 27 | } 28 | 29 | # 3. For each image I, compute the image-level repeat factor: 30 | # r(I) = max_{c in I} r(c) 31 | repeat_factors = [source_repeat[x["source"]] for x in dataset.data_infos] 32 | return repeat_factors 33 | -------------------------------------------------------------------------------- /gwd/datasets/wheat_detection.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import mmcv 4 | import numpy as np 5 | import pandas as pd 6 | 7 | from gwd.datasets.evaluation import kaggle_map 8 | from mmdet.datasets.builder import DATASETS 9 | from mmdet.datasets.coco import CocoDataset 10 | 11 | 12 | def calc_pseudo_confidence(sample_scores, pseudo_score_threshold): 13 | if len(sample_scores): 14 | return np.sum(sample_scores > pseudo_score_threshold) / len(sample_scores) 15 | else: 16 | return 0.0 17 | 18 | 19 | @DATASETS.register_module() 20 | class WheatDataset(CocoDataset): 21 | CLASSES = ("wheat",) 22 | 23 | def evaluate(self, results, logger=None, iou_thrs=(0.5, 0.55, 0.6, 0.65, 0.7, 0.75), **kwargs): 24 | annotations = [self.get_ann_info(i) for i in range(len(self))] 25 | mean_ap, _ = kaggle_map(results, annotations, iou_thrs=iou_thrs, logger=logger) 26 | return dict(mAP=mean_ap) 27 | 28 | def format_results(self, results, output_path=None, **kwargs): 29 | assert isinstance(results, list), "results must be a list" 30 | assert len(results) == len(self), "The length of results is not equal to the dataset len: {} != {}".format( 31 | len(results), len(self) 32 | ) 33 | prediction_results = [] 34 | for idx in range(len(self)): 35 | wheat_bboxes = results[idx][0] 36 | 37 | prediction_strs = [] 38 | for bbox in wheat_bboxes: 39 | x, y, w, h = self.xyxy2xywh(bbox) 40 | prediction_strs.append(f"{bbox[4]:.4f} {x} {y} {w} {h}") 41 | filename = self.data_infos[idx]["filename"] 42 | image_id = osp.splitext(osp.basename(filename))[0] 43 | prediction_results.append({"image_id": image_id, "PredictionString": " ".join(prediction_strs)}) 44 | predictions = pd.DataFrame(prediction_results) 45 | if output_path is not None: 46 | predictions.to_csv(output_path, index=False) 47 | return predictions 48 | 49 | def evaluate_by_sample(self, results, output_path, logger=None, iou_thrs=(0.5, 0.55, 0.6, 0.65, 0.7, 0.75)): 50 | annotations = [self.get_ann_info(i) for i in range(len(self))] 51 | _, eval_results = kaggle_map(results, annotations, iou_thrs=iou_thrs, logger=logger, by_sample=True) 52 | 
output_annotations = self.coco.dataset["annotations"] 53 | output_images = [] 54 | 55 | for idx in range(len(self)): 56 | wheat_bboxes = results[idx][0] 57 | data_info = self.data_infos[idx] 58 | data_info["ap"] = eval_results[0]["ap_by_sample"][idx] 59 | output_images.append(data_info) 60 | for bbox in wheat_bboxes: 61 | x, y, w, h = map(float, self.xyxy2xywh(bbox)) 62 | output_annotations.append( 63 | { 64 | "segmentation": "", 65 | "area": w * h, 66 | "image_id": data_info["id"], 67 | "category_id": 2, 68 | "bbox": [x, y, w, h], 69 | "iscrowd": 0, 70 | "score": float(bbox[-1]), 71 | } 72 | ) 73 | for i, ann in enumerate(output_annotations): 74 | ann["id"] = i 75 | outputs = { 76 | "annotations": output_annotations, 77 | "images": output_images, 78 | "categories": [ 79 | {"supercategory": "wheat", "name": "gt", "id": 1}, 80 | {"supercategory": "wheat", "name": "predict", "id": 2}, 81 | ], 82 | } 83 | mmcv.dump(outputs, output_path) 84 | return outputs 85 | 86 | def pseudo_results(self, results, output_path=None, pseudo_score_threshold=0.8, pseudo_confidence_threshold=0.65): 87 | assert isinstance(results, list), "results must be a list" 88 | assert len(results) == len(self), "The length of results is not equal to the dataset len: {} != {}".format( 89 | len(results), len(self) 90 | ) 91 | pseudo_annotations = [] 92 | pseudo_images = [] 93 | for idx in range(len(self)): 94 | wheat_bboxes = results[idx][0] 95 | scores = np.array([bbox[-1] for bbox in wheat_bboxes]) 96 | confidence = calc_pseudo_confidence(scores, pseudo_score_threshold=pseudo_score_threshold) 97 | if confidence >= pseudo_confidence_threshold: 98 | data_info = self.data_infos[idx] 99 | data_info["confidence"] = confidence 100 | pseudo_images.append(data_info) 101 | for bbox in wheat_bboxes: 102 | x, y, w, h = self.xyxy2xywh(bbox) 103 | pseudo_annotations.append( 104 | { 105 | "segmentation": "", 106 | "area": w * h, 107 | "image_id": data_info["id"], 108 | "category_id": 1, 109 | "bbox": [x, y, w, h], 110 | "iscrowd": 0, 111 | } 112 | ) 113 | for i, ann in enumerate(pseudo_annotations): 114 | ann["id"] = i 115 | print(len(pseudo_images)) 116 | mmcv.dump( 117 | { 118 | "annotations": pseudo_annotations, 119 | "images": pseudo_images, 120 | "categories": [{"supercategory": "wheat", "name": "wheat", "id": 1}], 121 | }, 122 | output_path, 123 | ) 124 | -------------------------------------------------------------------------------- /gwd/dense_heads/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amirassov/kaggle-global-wheat-detection/b26295ea257f73089f1a067b70b4a7ee638f6b83/gwd/dense_heads/__init__.py -------------------------------------------------------------------------------- /gwd/detectors/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amirassov/kaggle-global-wheat-detection/b26295ea257f73089f1a067b70b4a7ee638f6b83/gwd/detectors/__init__.py -------------------------------------------------------------------------------- /gwd/detectors/atss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from mmdet.core import bbox2result # , bbox_mapping_back 4 | from mmdet.core.bbox import bbox_flip 5 | from mmdet.models.builder import DETECTORS 6 | from mmdet.models.detectors.single_stage import SingleStageDetector 7 | from mmdet.ops.nms import batched_nms 8 | 9 | 10 | def bbox_rot90_back(bboxes, img_shape, factor=0): 
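    """Map xyxy boxes predicted on a rot90-augmented image back to the original frame.

    ``factor`` is the number of quarter-turns applied to the image at test time
    (assumed to follow the convention of the rotation transform used in the TTA
    pipeline); each branch below inverts that rotation. In the current ``aug_test``
    path the rotate arguments of ``bbox_mapping_back`` are commented out, so this
    helper is effectively unused unless rotation TTA is re-enabled.
    """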
11 | assert bboxes.shape[-1] % 4 == 0 12 | assert factor in {0, 1, 2, 3} 13 | rotated = bboxes.clone() 14 | h, w = img_shape[:2] 15 | if factor == 3: 16 | rotated[..., 0] = bboxes[..., 1] 17 | rotated[..., 1] = w - bboxes[..., 2] 18 | rotated[..., 2] = bboxes[..., 3] 19 | rotated[..., 3] = w - bboxes[..., 0] 20 | elif factor == 2: 21 | rotated[..., 0] = w - bboxes[..., 2] 22 | rotated[..., 1] = h - bboxes[..., 3] 23 | rotated[..., 2] = w - bboxes[..., 0] 24 | rotated[..., 3] = h - bboxes[..., 1] 25 | elif factor == 1: 26 | rotated[..., 0] = h - bboxes[..., 3] 27 | rotated[..., 1] = bboxes[..., 0] 28 | rotated[..., 2] = h - bboxes[..., 1] 29 | rotated[..., 3] = bboxes[..., 2] 30 | return rotated 31 | 32 | 33 | def bbox_mapping_back( 34 | bboxes, 35 | img_shape, 36 | scale_factor, 37 | flip, 38 | flip_direction="horizontal", 39 | # rotate=False, 40 | # rotate_factor=0, 41 | ): 42 | new_bboxes = bbox_flip(bboxes, img_shape, flip_direction) if flip else bboxes 43 | # new_bboxes = ( 44 | # bbox_rot90_back(new_bboxes, img_shape, rotate_factor) if rotate else new_bboxes 45 | # ) 46 | new_bboxes = new_bboxes.view(-1, 4) / new_bboxes.new_tensor(scale_factor) 47 | return new_bboxes.view(bboxes.shape) 48 | 49 | 50 | def multiclass_nms(multi_bboxes, multi_scores, score_thr, nms_cfg, max_num=-1, score_factors=None): 51 | num_classes = multi_scores.size(1) - 1 52 | # exclude background category 53 | if multi_bboxes.shape[1] > 4: 54 | bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4) 55 | else: 56 | bboxes = multi_bboxes[:, None].expand(-1, num_classes, 4) 57 | scores = multi_scores[:, :-1] 58 | # filter out boxes with low scores 59 | scaled_scores = scores * score_factors[:, None] 60 | valid_mask = scores > score_thr 61 | bboxes = bboxes[valid_mask] 62 | scores = scores[valid_mask] 63 | scaled_scores = scaled_scores[valid_mask] 64 | labels = valid_mask.nonzero()[:, 1] 65 | 66 | if bboxes.numel() == 0: 67 | bboxes = multi_bboxes.new_zeros((0, 5)) 68 | labels = multi_bboxes.new_zeros((0,), dtype=torch.long) 69 | return bboxes, labels 70 | 71 | dets, keep = batched_nms(bboxes, scaled_scores, labels, nms_cfg) 72 | if max_num > 0: 73 | dets = dets[:max_num] 74 | keep = keep[:max_num] 75 | scores = scores[keep] 76 | dets[:, -1] = scores 77 | return dets, labels[keep] 78 | 79 | 80 | @DETECTORS.register_module() 81 | class ModifiedATSS(SingleStageDetector): 82 | def __init__(self, backbone, neck, bbox_head, train_cfg=None, test_cfg=None, pretrained=None): 83 | super(ModifiedATSS, self).__init__(backbone, neck, bbox_head, train_cfg, test_cfg, pretrained) 84 | 85 | def merge_aug_results(self, aug_bboxes, aug_scores, aug_centerness, img_metas): 86 | """Merge augmented detection bboxes and scores. 87 | Args: 88 | aug_bboxes (list[Tensor]): shape (n, 4*#class) 89 | aug_scores (list[Tensor] or None): shape (n, #class) 90 | img_shapes (list[Tensor]): shape (3, ). 
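            aug_centerness (list[Tensor]): per-augmentation centerness/quality
                scores, shape (n, ); used below as NMS score factors.
            img_metas (list[list[dict]]): per-augmentation image meta; the
                ``img_shape``, ``scale_factor``, ``flip`` and ``flip_direction``
                entries of the first dict in each item are used to map boxes back.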
91 | Returns: 92 | tuple: (bboxes, scores) 93 | """ 94 | recovered_bboxes = [] 95 | for bboxes, img_info in zip(aug_bboxes, img_metas): 96 | img_shape = img_info[0]["img_shape"] 97 | scale_factor = img_info[0]["scale_factor"] 98 | flip = img_info[0]["flip"] 99 | flip_direction = img_info[0]["flip_direction"] 100 | # rotate = img_info[0]["rotate"] 101 | # rotate_factor = img_info[0]["rotate_factor"] 102 | bboxes = bbox_mapping_back( 103 | bboxes=bboxes, 104 | img_shape=img_shape, 105 | scale_factor=scale_factor, 106 | flip=flip, 107 | flip_direction=flip_direction, 108 | # rotate=rotate, 109 | # rotate_factor=rotate_factor 110 | ) 111 | recovered_bboxes.append(bboxes) 112 | bboxes = torch.cat(recovered_bboxes, dim=0) 113 | centerness = torch.cat(aug_centerness, dim=0) 114 | if aug_scores is None: 115 | return bboxes, centerness 116 | else: 117 | scores = torch.cat(aug_scores, dim=0) 118 | return bboxes, scores, centerness 119 | 120 | def aug_test(self, imgs, img_metas, rescale=False): 121 | # recompute feats to save memory 122 | feats = self.extract_feats(imgs) 123 | 124 | aug_bboxes = [] 125 | aug_scores = [] 126 | aug_centerness = [] 127 | for x, img_meta in zip(feats, img_metas): 128 | # only one image in the batch 129 | outs = self.bbox_head(x) 130 | bbox_inputs = outs + (img_meta, self.test_cfg, False, False) 131 | det_bboxes, det_scores, det_centerness = self.bbox_head.get_bboxes(*bbox_inputs)[0] 132 | aug_bboxes.append(det_bboxes) 133 | aug_scores.append(det_scores) 134 | aug_centerness.append(det_centerness) 135 | 136 | # after merging, bboxes will be rescaled to the original image size 137 | merged_bboxes, merged_scores, merged_centerness = self.merge_aug_results( 138 | aug_bboxes, aug_scores, aug_centerness, img_metas 139 | ) 140 | det_bboxes, det_labels = multiclass_nms( 141 | merged_bboxes, 142 | merged_scores, 143 | self.test_cfg.score_thr, 144 | self.test_cfg.nms, 145 | self.test_cfg.max_per_img, 146 | score_factors=merged_centerness, 147 | ) 148 | 149 | if rescale: 150 | _det_bboxes = det_bboxes 151 | else: 152 | _det_bboxes = det_bboxes.clone() 153 | _det_bboxes[:, :4] *= det_bboxes.new_tensor(img_metas[0][0]["scale_factor"]) 154 | bbox_results = bbox2result(_det_bboxes, det_labels, self.bbox_head.num_classes) 155 | return bbox_results 156 | -------------------------------------------------------------------------------- /gwd/detectors/rfp.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from mmdet.models import DETECTORS, builder 7 | from mmdet.models.detectors.htc import HybridTaskCascade 8 | 9 | 10 | class ASPP(torch.nn.Module): 11 | def __init__(self, in_channels, out_channels): 12 | super().__init__() 13 | kernel_sizes = [1, 3, 3, 1] 14 | dilations = [1, 3, 6, 1] 15 | paddings = [0, 3, 6, 0] 16 | self.aspp = torch.nn.ModuleList() 17 | for aspp_idx in range(len(kernel_sizes)): 18 | conv = torch.nn.Conv2d( 19 | in_channels, 20 | out_channels, 21 | kernel_size=kernel_sizes[aspp_idx], 22 | stride=1, 23 | dilation=dilations[aspp_idx], 24 | padding=paddings[aspp_idx], 25 | bias=True, 26 | ) 27 | self.aspp.append(conv) 28 | self.gap = torch.nn.AdaptiveAvgPool2d(1) 29 | self.aspp_num = len(kernel_sizes) 30 | for m in self.modules(): 31 | if isinstance(m, torch.nn.Conv2d): 32 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 33 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 34 | m.bias.data.fill_(0) 35 | 36 | def forward(self, x): 37 | avg_x = self.gap(x) 
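        # The loop below follows the DetectoRS-style ASPP used for the RFP feedback:
        # the first three convs run on the full-resolution feature map with increasing
        # dilation (1, 3, 6), while the last 1x1 conv runs on the globally pooled
        # features (avg_x); its output is broadcast back to the spatial size of the
        # other branches via expand_as before all four are concatenated along the
        # channel dimension.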
38 | out = [] 39 | for aspp_idx in range(self.aspp_num): 40 | inp = avg_x if (aspp_idx == self.aspp_num - 1) else x 41 | out.append(F.relu_(self.aspp[aspp_idx](inp))) 42 | out[-1] = out[-1].expand_as(out[-2]) 43 | out = torch.cat(out, dim=1) 44 | return out 45 | 46 | 47 | @DETECTORS.register_module 48 | class RecursiveFeaturePyramid(HybridTaskCascade): 49 | def __init__(self, backbone, rfp_steps=2, rfp_sharing=False, stage_with_rfp=(False, True, True, True), **kwargs): 50 | self.rfp_steps = rfp_steps 51 | self.rfp_sharing = rfp_sharing 52 | self.stage_with_rfp = stage_with_rfp 53 | backbone["rfp"] = None 54 | backbone["stage_with_rfp"] = stage_with_rfp 55 | neck_out_channels = kwargs["neck"]["out_channels"] 56 | if rfp_sharing: 57 | backbone["rfp"] = neck_out_channels 58 | super().__init__(backbone=backbone, **kwargs) 59 | if not self.rfp_sharing: 60 | backbone["rfp"] = neck_out_channels 61 | self.rfp_modules = torch.nn.ModuleList() 62 | for rfp_idx in range(1, rfp_steps): 63 | rfp_module = builder.build_backbone(backbone) 64 | rfp_module.init_weights(kwargs["pretrained"]) 65 | self.rfp_modules.append(rfp_module) 66 | self.rfp_aspp = ASPP(neck_out_channels, neck_out_channels // 4) 67 | self.rfp_weight = torch.nn.Conv2d(neck_out_channels, 1, kernel_size=1, stride=1, padding=0, bias=True) 68 | self.rfp_weight.weight.data.fill_(0) 69 | self.rfp_weight.bias.data.fill_(0) 70 | 71 | def extract_feat(self, img): 72 | x = self.backbone(img) 73 | x = self.neck(x) 74 | for rfp_idx in range(self.rfp_steps - 1): 75 | rfp_feats = tuple( 76 | self.rfp_aspp(x[i]) if self.stage_with_rfp[i] else x[i] for i in range(len(self.stage_with_rfp)) 77 | ) 78 | if self.rfp_sharing: 79 | x_idx = self.backbone.rfp_forward(img, rfp_feats) 80 | else: 81 | x_idx = self.rfp_modules[rfp_idx].rfp_forward(img, rfp_feats) 82 | x_idx = self.neck(x_idx) 83 | x_new = [] 84 | for ft_idx in range(len(x_idx)): 85 | add_weight = torch.sigmoid(self.rfp_weight(x_idx[ft_idx])) 86 | x_new.append(add_weight * x_idx[ft_idx] + (1 - add_weight) * x[ft_idx]) 87 | x = x_new 88 | return x 89 | -------------------------------------------------------------------------------- /gwd/eda/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amirassov/kaggle-global-wheat-detection/b26295ea257f73089f1a067b70b4a7ee638f6b83/gwd/eda/__init__.py -------------------------------------------------------------------------------- /gwd/eda/coco.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import tempfile 5 | 6 | import cv2 7 | import mmcv 8 | from pycocotools.coco import COCO 9 | from tqdm import tqdm 10 | 11 | from gwd.eda.visualization import draw_bounding_boxes_on_image 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--img-prefix", default="/data/train") 17 | parser.add_argument("--ann-file", default="/data/rfp_r50_ga_mstrain_pseudo_stage1_epoch_10_train_predictions.json") 18 | parser.add_argument("--prediction_path") 19 | parser.add_argument("--output_root", type=str, default="/data/eda") 20 | return parser.parse_args() 21 | 22 | 23 | def prepare_predictions(prediction_path, annotation_path, output_path): 24 | predictions = mmcv.load(prediction_path) 25 | annotations = mmcv.load(annotation_path) 26 | for i, prediction in tqdm(enumerate(predictions), total=len(predictions)): 27 | prediction["segmentation"] = "" 28 | x1, y1, w, h = 
prediction["bbox"] 29 | prediction["area"] = w * h 30 | prediction["id"] = i 31 | prediction["iscrowd"] = 0 32 | annotations["annotations"] = predictions 33 | mmcv.dump(annotations, output_path) 34 | 35 | 36 | def prepare_bboxes(ann_info): 37 | bboxes = [] 38 | labels = [] 39 | display_str_list = [] 40 | for i, ann in enumerate(ann_info): 41 | x1, y1, w, h = ann["bbox"] 42 | if ann["area"] <= 0 or w < 1 or h < 1: 43 | continue 44 | bbox = [x1, y1, x1 + w, y1 + h] 45 | bboxes.append(bbox) 46 | labels.append(ann["category_id"]) 47 | if "score" in ann: 48 | display_str_list.append(f'{ann["score"]:.2f}') 49 | else: 50 | display_str_list.append("") 51 | return bboxes, labels, display_str_list 52 | 53 | 54 | def main(): 55 | args = parse_args() 56 | 57 | os.makedirs(args.output_root, exist_ok=True) 58 | if args.prediction_path is not None: 59 | with tempfile.TemporaryDirectory() as root: 60 | tmp_ann_file = osp.join(root, "ann.json") 61 | prepare_predictions(args.prediction_path, args.ann_file, tmp_ann_file) 62 | dataset = COCO(tmp_ann_file) 63 | else: 64 | dataset = COCO(args.ann_file) 65 | 66 | for img_id, img_info in tqdm(dataset.imgs.items()): 67 | img = cv2.imread(osp.join(args.img_prefix, img_info["file_name"]))[..., ::-1] 68 | ann_ids = dataset.getAnnIds(imgIds=[img_id]) 69 | ann_info = dataset.loadAnns(ann_ids) 70 | bboxes, labels, display_str_list = prepare_bboxes(ann_info) 71 | label2colors = { 72 | -1: {"bbox": (255, 0, 0)}, 73 | 1: {"bbox": (0, 128, 255)}, 74 | 2: {"bbox": (128, 0, 0), "text": (255, 255, 255)}, 75 | } 76 | draw_bounding_boxes_on_image( 77 | img, 78 | bboxes, 79 | labels=labels, 80 | label2colors=label2colors, 81 | display_str_list=display_str_list, 82 | use_normalized_coordinates=False, 83 | thickness=4, 84 | fontsize=15, 85 | ) 86 | filename = f"{img_info['file_name']}" 87 | if "ap" in img_info: 88 | filename = f"{img_info['ap']:.2f}_{filename}" 89 | cv2.imwrite(osp.join(args.output_root, filename), img[..., ::-1]) 90 | 91 | 92 | if __name__ == "__main__": 93 | main() 94 | -------------------------------------------------------------------------------- /gwd/eda/kmeans.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def iou(box, clusters): 5 | x = np.minimum(clusters[:, 0], box[0]) 6 | y = np.minimum(clusters[:, 1], box[1]) 7 | 8 | intersection = x * y 9 | box_area = box[0] * box[1] 10 | cluster_area = clusters[:, 0] * clusters[:, 1] 11 | iou_ = intersection / (box_area + cluster_area - intersection) 12 | if np.any(np.isnan(iou_)): 13 | breakpoint() 14 | 15 | return iou_ 16 | 17 | 18 | def kmeans(boxes, k, aggregator=np.median): 19 | rows = boxes.shape[0] 20 | 21 | distances = np.empty((rows, k)) 22 | last_clusters = np.zeros((rows,)) 23 | 24 | np.random.seed(77) 25 | 26 | clusters = boxes[np.random.choice(rows, k, replace=False)] 27 | 28 | while True: 29 | for row in range(rows): 30 | distances[row] = 1 - iou(boxes[row], clusters) 31 | 32 | nearest_clusters = np.argmin(distances, axis=1) 33 | 34 | if (last_clusters == nearest_clusters).all(): 35 | break 36 | 37 | for cluster in range(k): 38 | if np.any(nearest_clusters == cluster): 39 | clusters[cluster] = aggregator(boxes[nearest_clusters == cluster], axis=0) 40 | last_clusters = nearest_clusters 41 | 42 | return clusters 43 | -------------------------------------------------------------------------------- /gwd/eda/pipeline.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 
import os.path as osp 4 | 5 | import cv2 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | from mmcv import Config 9 | from tqdm import tqdm 10 | 11 | from gwd.datasets.wheat_detection import WheatDataset 12 | from gwd.eda.kmeans import kmeans 13 | from gwd.eda.visualization import draw_bounding_boxes_on_image 14 | from mmdet.datasets.builder import build_dataset 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--config_path", default="configs/rfp_spike/rfp_r50_ga_mstrain_spike_stage1.py") 20 | parser.add_argument("--output_root", default="/data/eda") 21 | parser.add_argument("--fold", default=0) 22 | parser.add_argument("--skip_type", nargs="+", default=("DefaultFormatBundle", "Normalize", "Collect")) 23 | return parser.parse_args() 24 | 25 | 26 | def retrieve_data_cfg(config_path, fold, skip_type): 27 | cfg = Config.fromfile(config_path) 28 | train_data_cfg = cfg.data.train 29 | if "ann_file" not in cfg.data.train: 30 | train_data_cfg = train_data_cfg.dataset 31 | if isinstance(cfg.data.train.ann_file, list): 32 | cfg.data.train.ann_file = [x.format(fold=fold) for x in cfg.data.train.ann_file] 33 | elif isinstance(cfg.data.train.ann_file, str): 34 | cfg.data.train.ann_file = cfg.data.train.ann_file.format(fold=fold) 35 | train_data_cfg["pipeline"] = [x for x in train_data_cfg.pipeline if x["type"] not in skip_type] 36 | return cfg 37 | 38 | 39 | def plot_statistics(widths, heights, output_root): 40 | statistics = { 41 | "width": {"array": widths / 1024, "axis": (0, 0.5)}, 42 | "height": {"array": heights / 1024, "axis": (0, 0.5)}, 43 | "area": {"array": widths * heights / 1024 / 1024, "axis": (0, 0.1)}, 44 | "ratio": {"array": widths / heights, "axis": (0, 4)}, 45 | } 46 | axes = {} 47 | fig, ((axes["width"], axes["height"]), (axes["area"], axes["ratio"])) = plt.subplots(2, 2, figsize=(12, 12)) 48 | for name, data in statistics.items(): 49 | axes[name].hist(data["array"], 250, density=True) 50 | axes[name].set_title(f"{name}") 51 | axes[name].grid(True) 52 | axes[name].set_xlim(data["axis"]) 53 | fig.savefig(osp.join(output_root, "statistics.png")) 54 | 55 | 56 | def main(): 57 | args = parse_args() 58 | os.makedirs(args.output_root, exist_ok=True) 59 | cfg = retrieve_data_cfg(args.config_path, args.fold, args.skip_type) 60 | 61 | dataset: WheatDataset = build_dataset(cfg.data.train) 62 | from IPython import embed 63 | 64 | embed() 65 | heights = [] 66 | widths = [] 67 | for i, data in tqdm(enumerate(dataset), total=len(dataset)): 68 | image_id = osp.basename(dataset.data_infos[i]["file_name"]) 69 | image = data["img"] 70 | bboxes = data["gt_bboxes"] 71 | ignore_bboxes = data["gt_bboxes_ignore"] 72 | if len(bboxes) == 0: 73 | print(image.shape) 74 | widths.append(bboxes[:, 2] - bboxes[:, 0]) 75 | heights.append(bboxes[:, 3] - bboxes[:, 1]) 76 | draw_bounding_boxes_on_image(image, bboxes, use_normalized_coordinates=False, thickness=5) 77 | if len(ignore_bboxes): 78 | draw_bounding_boxes_on_image( 79 | image, 80 | ignore_bboxes, 81 | label2colors={None: {"bbox": (0, 255, 0)}}, 82 | use_normalized_coordinates=False, 83 | thickness=5, 84 | ) 85 | cv2.imwrite(osp.join(args.output_root, f"{i}_{image_id}"), image) 86 | widths = np.concatenate(widths) 87 | heights = np.concatenate(heights) 88 | clusters = kmeans(np.stack([heights, widths], axis=1), k=10) 89 | print(f"aspect rations: {clusters[:, 0] / clusters[:, 1]}") 90 | print(f"sizes: {np.sqrt(clusters[:, 0] * clusters[:, 1])}") 91 | plot_statistics(widths, heights, 
args.output_root) 92 | 93 | 94 | if __name__ == "__main__": 95 | main() 96 | -------------------------------------------------------------------------------- /gwd/eda/sources.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | from shutil import copyfile 5 | 6 | import mmcv 7 | from tqdm import tqdm 8 | 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--annotation_path", default="/data/coco_train.json") 13 | parser.add_argument("--image_root", default="/data/train") 14 | parser.add_argument("--output_root", default="/data/sources") 15 | return parser.parse_args() 16 | 17 | 18 | def main(): 19 | args = parse_args() 20 | annotations = mmcv.load(args.annotation_path) 21 | for sample in tqdm(annotations["images"]): 22 | source_root = osp.join(args.output_root, sample["source"]) 23 | os.makedirs(source_root, exist_ok=True) 24 | copyfile(osp.join(args.image_root, sample["file_name"]), osp.join(source_root, sample["file_name"])) 25 | 26 | 27 | if __name__ == "__main__": 28 | main() 29 | -------------------------------------------------------------------------------- /gwd/eda/submission.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | 5 | import cv2 6 | import numpy as np 7 | import pandas as pd 8 | from tqdm import tqdm 9 | 10 | from gwd.eda.visualization import draw_bounding_boxes_on_image 11 | 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--submission_path", default="/data/pseudo_universe_submission.csv") 16 | parser.add_argument("--img_root", default="/data/test") 17 | parser.add_argument("--output_root", default="/data/eda") 18 | return parser.parse_args() 19 | 20 | 21 | def convert_bboxes(bboxes): 22 | bboxes = np.concatenate([bboxes[:, 1:], bboxes[:, :1]], axis=1) 23 | bboxes[:, 2] += bboxes[:, 0] 24 | bboxes[:, 3] += bboxes[:, 1] 25 | return bboxes 26 | 27 | 28 | def main(): 29 | args = parse_args() 30 | os.makedirs(args.output_root, exist_ok=True) 31 | submission = pd.read_csv(args.submission_path) 32 | for _, row in tqdm(submission.iterrows(), total=len(submission)): 33 | image = cv2.imread(osp.join(args.img_root, f"{row.image_id}.jpg")) 34 | bboxes = convert_bboxes(np.array(list(map(float, row.PredictionString.split()))).reshape(-1, 5)) 35 | draw_bounding_boxes_on_image(image, bboxes, use_normalized_coordinates=False, thickness=5) 36 | cv2.imwrite(osp.join(args.output_root, f"{row.image_id}.jpg"), image) 37 | 38 | 39 | if __name__ == "__main__": 40 | main() 41 | -------------------------------------------------------------------------------- /gwd/eda/visualization.py: -------------------------------------------------------------------------------- 1 | from itertools import zip_longest 2 | from typing import NoReturn 3 | 4 | import numpy as np 5 | import PIL.Image as Image 6 | import PIL.ImageDraw as ImageDraw 7 | import PIL.ImageFont as ImageFont 8 | 9 | DEFAULT_COLORS = {"bbox": (128, 0, 0), "text": (255, 255, 255)} 10 | 11 | 12 | def draw_bounding_box_on_image( 13 | image: Image, 14 | x_min, 15 | y_min, 16 | x_max, 17 | y_max, 18 | color, 19 | thickness=4, 20 | display_str=(), 21 | use_normalized_coordinates=True, 22 | fontsize=20, 23 | ) -> NoReturn: 24 | draw = ImageDraw.Draw(image) 25 | im_width, im_height = image.size 26 | if use_normalized_coordinates: 27 | (left, right, top, bottom) = (x_min * 
im_width, x_max * im_width, y_min * im_height, y_max * im_height) 28 | else: 29 | (left, right, top, bottom) = (x_min, x_max, y_min, y_max) 30 | draw.line( 31 | [(left, top), (left, bottom), (right, bottom), (right, top), (left, top)], width=thickness, fill=color["bbox"] 32 | ) 33 | 34 | try: 35 | font = ImageFont.truetype("/data/DejaVuSansMono.ttf", fontsize) 36 | except IOError: 37 | font = ImageFont.load_default() 38 | 39 | # If the total height of the display strings added to the top of the bounding 40 | # box exceeds the top of the image, stack the strings below the bounding box 41 | # instead of above. 42 | text_left = left 43 | # Reverse list and print from bottom to top. 44 | for display_str in display_str: 45 | text_width, text_height = font.getsize(display_str) 46 | margin = np.ceil(0.05 * text_width) 47 | draw.rectangle([(text_left, top - text_height), (text_left + text_width + 2 * margin, top)], fill=color["bbox"]) 48 | draw.text((text_left, top - text_height), display_str, fill=color["text"], font=font) 49 | text_left += text_width - 2 * margin 50 | 51 | 52 | def draw_bounding_boxes_on_image( 53 | image, 54 | bboxes, 55 | labels=(), 56 | label2colors=None, 57 | thickness=4, 58 | display_str_list=(), 59 | use_normalized_coordinates=True, 60 | fontsize=20, 61 | ): 62 | if label2colors is None: 63 | label2colors = {} 64 | image_pil = Image.fromarray(image) 65 | for bbox, label, display_str in zip_longest(bboxes, labels, display_str_list): 66 | draw_bounding_box_on_image( 67 | image=image_pil, 68 | x_min=bbox[0], 69 | y_min=bbox[1], 70 | x_max=bbox[2], 71 | y_max=bbox[3], 72 | color=label2colors.get(label, DEFAULT_COLORS), 73 | thickness=thickness, 74 | display_str=[] if display_str is None else display_str, 75 | use_normalized_coordinates=use_normalized_coordinates, 76 | fontsize=fontsize, 77 | ) 78 | np.copyto(image, np.array(image_pil)) 79 | -------------------------------------------------------------------------------- /gwd/jigsaw/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amirassov/kaggle-global-wheat-detection/b26295ea257f73089f1a067b70b4a7ee638f6b83/gwd/jigsaw/__init__.py -------------------------------------------------------------------------------- /gwd/jigsaw/calculate_distance.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on https://github.com/lRomul/argus-tgs-salt/blob/master/mosaic/create_mosaic.py 3 | """ 4 | import argparse 5 | import os 6 | import os.path as osp 7 | from multiprocessing import Pool 8 | 9 | import cv2 10 | import numpy as np 11 | import pandas as pd 12 | from tqdm import tqdm 13 | 14 | 15 | def parse_args(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--annotation-path", default="/data/train.csv") 18 | parser.add_argument("--img-root", default="/data/train") 19 | parser.add_argument("--output-root", default="/data/distances") 20 | return parser.parse_args() 21 | 22 | 23 | def get_descriptor(image_path): 24 | img = cv2.imread(image_path).astype(float) 25 | return [(img[:, 0], img[:, 1]), (img[:, -2], img[:, -1]), (img[0, :], img[1, :]), (img[-2, :], img[-1, :])] 26 | 27 | 28 | def get_descriptors(image_paths): 29 | with Pool(32) as p: 30 | descriptors = list( 31 | tqdm(iterable=p.imap(get_descriptor, image_paths), total=len(image_paths), desc="Get descriptors...") 32 | ) 33 | return descriptors 34 | 35 | 36 | def calc_pair_metric(args): 37 | i, desc_i, j, desc_j = args 38 | if i != j: 39 
| left_metric = calc_metric(desc_i[0], desc_j[1]) 40 | top_metric = calc_metric(desc_i[2], desc_j[3]) 41 | else: 42 | left_metric = top_metric = 1e6 43 | return left_metric, top_metric 44 | 45 | 46 | def norm(a): 47 | return np.mean(a ** 2) 48 | 49 | 50 | def calc_metric(d1, d2): 51 | """ 52 | d2 | d1 53 | [-2 -1] | [0 1] 54 | """ 55 | return (norm((d1[1] + d2[1] - 2 * d1[0]) / 2) + norm((d2[0] + d1[0] - 2 * d2[1]) / 2)) / 2 56 | 57 | 58 | def get_metrics(descriptors): 59 | n_samples = len(descriptors) 60 | left_matrix = np.zeros((n_samples, n_samples)) 61 | top_matrix = np.zeros((n_samples, n_samples)) 62 | for i, desc_i in tqdm(enumerate(descriptors), total=n_samples, desc="Get metrics..."): 63 | with Pool(32) as p: 64 | metrics = list((p.imap(calc_pair_metric, [(i, desc_i, j, desc_j) for j, desc_j in enumerate(descriptors)]))) 65 | for j, metric in enumerate(metrics): 66 | left_matrix[i, j], top_matrix[i, j] = metric 67 | return left_matrix, top_matrix 68 | 69 | 70 | def get_distances(image_paths): 71 | descriptors = get_descriptors(image_paths) 72 | left_matrix, top_matrix = get_metrics(descriptors) 73 | return left_matrix, top_matrix 74 | 75 | 76 | def main(): 77 | args = parse_args() 78 | os.makedirs(args.output_root, exist_ok=True) 79 | annotations = pd.read_csv(args.annotation_path).drop_duplicates("image_id") 80 | for source, source_annotations in annotations.groupby("source"): 81 | print(f"Source: {source}") 82 | image_paths = source_annotations["image_id"].apply(lambda x: osp.join(args.img_root, f"{x}.jpg")).tolist() 83 | left_matrix, top_matrix = get_distances(image_paths) 84 | np.savez( 85 | osp.join(args.output_root, f"{source}.npz"), 86 | left_matrix=left_matrix, 87 | top_matrix=top_matrix, 88 | paths=image_paths, 89 | ) 90 | 91 | 92 | if __name__ == "__main__": 93 | main() 94 | -------------------------------------------------------------------------------- /gwd/jigsaw/collect_bboxes.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import ast 3 | import os.path as osp 4 | 5 | import mmcv 6 | import pandas as pd 7 | from tqdm import tqdm 8 | 9 | IMG_SIZE = 1024 10 | 11 | 12 | def parse_args(): 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--mosaics-path", default="/data/mosaics.json") 15 | parser.add_argument("--annotation-path", default="/data/train.csv") 16 | parser.add_argument("--output-path", default="/data/train_mosaic.csv") 17 | return parser.parse_args() 18 | 19 | 20 | def main(): 21 | args = parse_args() 22 | annotations = pd.read_csv(args.annotation_path, converters={"bbox": ast.literal_eval}) 23 | annotations["tile_id"] = 0 24 | mosaics = mmcv.load(args.mosaics_path) 25 | for mosaic in tqdm(mosaics): 26 | mosaic_image_id = "_".join([osp.basename(tile["path"]).split(".")[0] for tile in mosaic]) 27 | for i, tile in enumerate(mosaic): 28 | image_id = osp.basename(tile["path"]).split(".")[0] 29 | mask = annotations["image_id"] == image_id 30 | annotations.loc[mask, "bbox"] = annotations.loc[mask, "bbox"].apply( 31 | lambda bbox: [bbox[0] + IMG_SIZE * tile["i"], bbox[1] + IMG_SIZE * tile["j"], bbox[2], bbox[3]] 32 | ) 33 | annotations.loc[mask, "tile_id"] = i 34 | annotations.loc[mask, "image_id"] = mosaic_image_id 35 | mosaic_width = (max([x["i"] for x in mosaic]) + 1) * IMG_SIZE 36 | mosaic_height = (max([x["j"] for x in mosaic]) + 1) * IMG_SIZE 37 | annotations.loc[annotations["image_id"] == mosaic_image_id, "width"] = mosaic_width 38 | annotations.loc[annotations["image_id"] == 
mosaic_image_id, "height"] = mosaic_height 39 | print(annotations.head()) 40 | annotations.to_csv(args.output_path, index=False) 41 | 42 | 43 | if __name__ == "__main__": 44 | main() 45 | -------------------------------------------------------------------------------- /gwd/jigsaw/crop.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | from functools import partial 5 | from multiprocessing import Pool 6 | 7 | import cv2 8 | import pandas as pd 9 | from tqdm import tqdm 10 | 11 | from gwd.converters import images2coco 12 | 13 | CROP_SIZE = 1024 14 | OFFSET_SIZE = 512 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--annotation-path", default="/data/folds_v2/0/mosaic_train.csv") 20 | parser.add_argument("--img-root", default="/data/mosaics") 21 | parser.add_argument("--output-root", default="/data/crops_fold0") 22 | parser.add_argument("--output-path", default="/data/coco_crops_fold0.json") 23 | return parser.parse_args() 24 | 25 | 26 | def crop_and_save(img_path, output_root): 27 | img = cv2.imread(img_path) 28 | h, w = img.shape[:2] 29 | if h <= 1024 and w <= 1024: 30 | return 31 | 32 | for i in range(0, h // OFFSET_SIZE - 1): 33 | for j in range(0, w // OFFSET_SIZE - 1): 34 | if i % 2 or j % 2: 35 | crop = img[i * OFFSET_SIZE : i * OFFSET_SIZE + CROP_SIZE, j * OFFSET_SIZE : j * OFFSET_SIZE + CROP_SIZE] 36 | img_name = osp.basename(img_path) 37 | crop_path = osp.join(output_root, f"{i}_{j}_{img_name}") 38 | cv2.imwrite(crop_path, crop) 39 | 40 | 41 | def main(): 42 | args = parse_args() 43 | os.makedirs(args.output_root, exist_ok=True) 44 | annotations = pd.read_csv(args.annotation_path) 45 | annotations["img_path"] = annotations["image_id"].apply(lambda x: f"{args.img_root}/{x}.jpg") 46 | img_paths = annotations["img_path"].drop_duplicates().tolist() 47 | with Pool(32) as p: 48 | list( 49 | tqdm(iterable=p.imap(partial(crop_and_save, output_root=args.output_root), img_paths), total=len(img_paths)) 50 | ) 51 | images2coco.main(img_pattern=osp.join(args.output_root, "*.jpg"), output_path=args.output_path) 52 | 53 | 54 | if __name__ == "__main__": 55 | main() 56 | -------------------------------------------------------------------------------- /gwd/losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amirassov/kaggle-global-wheat-detection/b26295ea257f73089f1a067b70b4a7ee638f6b83/gwd/losses/__init__.py -------------------------------------------------------------------------------- /gwd/losses/cross_entropy_loss.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch.nn.functional as F 4 | 5 | from mmdet.models.builder import LOSSES 6 | from mmdet.models.losses import CrossEntropyLoss, binary_cross_entropy, mask_cross_entropy 7 | from mmdet.models.losses.utils import weight_reduce_loss 8 | 9 | 10 | def label_smooth_cross_entropy( 11 | pred, label, weight=None, reduction="mean", avg_factor=None, class_weight=None, label_smooth=None 12 | ): 13 | # element-wise losses 14 | if label_smooth is None: 15 | loss = F.cross_entropy(pred, label, reduction="none") 16 | else: 17 | num_classes = pred.size(1) 18 | target = F.one_hot(label, num_classes).type_as(pred) 19 | target = target.sub_(label_smooth).clamp_(0).add_(label_smooth / num_classes) 20 | loss = F.kl_div(pred.log_softmax(1), target, 
reduction="none").sum(1) 21 | 22 | # apply weights and do the reduction 23 | if weight is not None: 24 | weight = weight.float() 25 | loss = weight_reduce_loss(loss, weight=weight, reduction=reduction, avg_factor=avg_factor) 26 | 27 | return loss 28 | 29 | 30 | @LOSSES.register_module() 31 | class LabelSmoothCrossEntropyLoss(CrossEntropyLoss): 32 | def __init__( 33 | self, use_sigmoid=False, use_mask=False, reduction="mean", loss_weight=1.0, class_weight=None, label_smooth=0.1 34 | ): 35 | super(LabelSmoothCrossEntropyLoss, self).__init__() 36 | assert (use_sigmoid is False) or (use_mask is False) 37 | self.use_sigmoid = use_sigmoid 38 | self.use_mask = use_mask 39 | self.reduction = reduction 40 | self.loss_weight = loss_weight 41 | self.class_weight = class_weight 42 | self.label_smooth = label_smooth 43 | 44 | if self.use_sigmoid: 45 | self.cls_criterion = binary_cross_entropy 46 | elif self.use_mask: 47 | self.cls_criterion = mask_cross_entropy 48 | elif self.label_smooth is None: 49 | self.cls_criterion = label_smooth_cross_entropy 50 | else: 51 | self.cls_criterion = partial(label_smooth_cross_entropy, label_smooth=self.label_smooth) 52 | -------------------------------------------------------------------------------- /gwd/misc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amirassov/kaggle-global-wheat-detection/b26295ea257f73089f1a067b70b4a7ee638f6b83/gwd/misc/__init__.py -------------------------------------------------------------------------------- /gwd/misc/logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch.distributed as dist 4 | 5 | logger_initialized = {} 6 | 7 | 8 | def get_logger(name, log_file=None, log_mode="w", log_level=logging.INFO): 9 | logger = logging.getLogger(name) 10 | if name in logger_initialized: 11 | return logger 12 | # handle hierarchical names 13 | # e.g., logger "a" is initialized, then logger "a.b" will skip the 14 | # initialization since it is a child of "a". 
15 | for logger_name in logger_initialized: 16 | if name.startswith(logger_name): 17 | return logger 18 | 19 | stream_handler = logging.StreamHandler() 20 | handlers = [stream_handler] 21 | 22 | if dist.is_available() and dist.is_initialized(): 23 | rank = dist.get_rank() 24 | else: 25 | rank = 0 26 | 27 | # only rank 0 will add a FileHandler 28 | if rank == 0 and log_file is not None: 29 | file_handler = logging.FileHandler(log_file, log_mode) 30 | handlers.append(file_handler) 31 | 32 | formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") 33 | for handler in handlers: 34 | handler.setFormatter(formatter) 35 | handler.setLevel(log_level) 36 | logger.addHandler(handler) 37 | 38 | if rank == 0: 39 | logger.setLevel(log_level) 40 | else: 41 | logger.setLevel(logging.ERROR) 42 | 43 | logger_initialized[name] = True 44 | 45 | return logger 46 | -------------------------------------------------------------------------------- /gwd/necks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amirassov/kaggle-global-wheat-detection/b26295ea257f73089f1a067b70b4a7ee638f6b83/gwd/necks/__init__.py -------------------------------------------------------------------------------- /gwd/necks/sepc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from mmdet.core import auto_fp16 6 | from mmdet.models.builder import NECKS 7 | 8 | from .sepc_dconv import SEPCConv 9 | 10 | 11 | @NECKS.register_module() 12 | class SEPC(nn.Module): 13 | def __init__( 14 | self, 15 | in_channels=[256] * 5, 16 | out_channels=256, 17 | num_outs=5, 18 | stacked_convs=4, 19 | pconv_deform=False, 20 | lcconv_deform=False, 21 | ibn=False, 22 | lcconv_padding=0, 23 | ): 24 | super(SEPC, self).__init__() 25 | assert isinstance(in_channels, list) 26 | self.in_channels = in_channels 27 | self.out_channels = out_channels 28 | self.num_ins = len(in_channels) 29 | self.num_outs = num_outs 30 | assert num_outs == 5 31 | self.fp16_enabled = False 32 | self.ibn = ibn 33 | self.pconvs = nn.ModuleList() 34 | 35 | for i in range(stacked_convs): 36 | self.pconvs.append(PConvModule(in_channels[i], out_channels, ibn=self.ibn, part_deform=pconv_deform)) 37 | 38 | self.lconv = SEPCConv(256, 256, kernel_size=3, padding=lcconv_padding, dilation=1, part_deform=lcconv_deform) 39 | self.cconv = SEPCConv(256, 256, kernel_size=3, padding=lcconv_padding, dilation=1, part_deform=lcconv_deform) 40 | self.relu = nn.ReLU() 41 | if self.ibn: 42 | self.lbn = nn.BatchNorm2d(256) 43 | self.cbn = nn.BatchNorm2d(256) 44 | self.init_weights() 45 | 46 | def init_weights(self): 47 | for str in ["l", "c"]: 48 | m = getattr(self, str + "conv") 49 | nn.init.normal_(m.weight.data, 0, 0.01) 50 | if m.bias is not None: 51 | m.bias.data.zero_() 52 | 53 | @auto_fp16() 54 | def forward(self, inputs): 55 | assert len(inputs) == len(self.in_channels) 56 | x = inputs 57 | for pconv in self.pconvs: 58 | x = pconv(x) 59 | cls_feats = [self.cconv(level, item) for level, item in enumerate(x)] 60 | loc_feats = [self.lconv(level, item) for level, item in enumerate(x)] 61 | if self.ibn: 62 | cls_feats = integrated_bn(cls_feats, self.cbn) 63 | loc_feats = integrated_bn(loc_feats, self.lbn) 64 | outs = [[self.relu(cls_feat), self.relu(loc_feat)] for cls_feat, loc_feat in zip(cls_feats, loc_feats)] 65 | return tuple(outs) 66 | 67 | 68 | class PConvModule(nn.Module): 69 | 
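    # Pyramid convolution block from SEPC: each pyramid level is combined with its
    # neighbours via three shared convs (pconv[1] on the current level, pconv[2]
    # with stride 2 on the finer level below, and pconv[0] on the coarser level
    # above followed by bilinear upsampling), optionally normalized by a single
    # BatchNorm shared across all levels (integrated_bn).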
def __init__( 70 | self, 71 | in_channels=256, 72 | out_channels=256, 73 | kernel_size=[3, 3, 3], 74 | dilation=[1, 1, 1], 75 | groups=[1, 1, 1], 76 | ibn=False, 77 | part_deform=False, 78 | ): 79 | super(PConvModule, self).__init__() 80 | 81 | self.ibn = ibn 82 | self.pconv = nn.ModuleList() 83 | self.pconv.append( 84 | SEPCConv( 85 | in_channels, 86 | out_channels, 87 | kernel_size=kernel_size[0], 88 | dilation=dilation[0], 89 | groups=groups[0], 90 | padding=(kernel_size[0] + (dilation[0] - 1) * 2) // 2, 91 | part_deform=part_deform, 92 | ) 93 | ) 94 | self.pconv.append( 95 | SEPCConv( 96 | in_channels, 97 | out_channels, 98 | kernel_size=kernel_size[1], 99 | dilation=dilation[1], 100 | groups=groups[1], 101 | padding=(kernel_size[1] + (dilation[1] - 1) * 2) // 2, 102 | part_deform=part_deform, 103 | ) 104 | ) 105 | self.pconv.append( 106 | SEPCConv( 107 | in_channels, 108 | out_channels, 109 | kernel_size=kernel_size[2], 110 | dilation=dilation[2], 111 | groups=groups[2], 112 | padding=(kernel_size[2] + (dilation[2] - 1) * 2) // 2, 113 | stride=2, 114 | part_deform=part_deform, 115 | ) 116 | ) 117 | 118 | if self.ibn: 119 | self.bn = nn.BatchNorm2d(256) 120 | 121 | self.relu = nn.ReLU() 122 | self.init_weights() 123 | 124 | def init_weights(self): 125 | for m in self.pconv: 126 | nn.init.normal_(m.weight.data, 0, 0.01) 127 | if m.bias is not None: 128 | m.bias.data.zero_() 129 | 130 | def forward(self, x): 131 | next_x = [] 132 | for level, feature in enumerate(x): 133 | temp_fea = self.pconv[1](level, feature) 134 | if level > 0: 135 | temp_fea += self.pconv[2](level, x[level - 1]) 136 | if level < len(x) - 1: 137 | temp_fea += F.interpolate( 138 | self.pconv[0](level, x[level + 1]), 139 | size=[temp_fea.size(2), temp_fea.size(3)], 140 | mode="bilinear", 141 | align_corners=True, 142 | ) 143 | next_x.append(temp_fea) 144 | if self.ibn: 145 | next_x = integrated_bn(next_x, self.bn) 146 | next_x = [self.relu(item) for item in next_x] 147 | return next_x 148 | 149 | 150 | def integrated_bn(fms, bn): 151 | sizes = [p.shape[2:] for p in fms] 152 | n, c = fms[0].shape[0], fms[0].shape[1] 153 | fm = torch.cat([p.view(n, c, 1, -1) for p in fms], dim=-1) 154 | fm = bn(fm) 155 | fm = torch.split(fm, [s[0] * s[1] for s in sizes], dim=-1) 156 | return [p.view(n, c, s[0], s[1]) for p, s in zip(fm, sizes)] 157 | -------------------------------------------------------------------------------- /gwd/necks/sepc_dconv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.modules.utils import _pair 4 | 5 | from mmdet.ops.dcn import DeformConv, ModulatedDeformConv, deform_conv, modulated_deform_conv 6 | 7 | 8 | class SEPCConv(DeformConv): 9 | def __init__(self, *args, part_deform=False, **kwargs): 10 | super(SEPCConv, self).__init__(*args, **kwargs) 11 | self.part_deform = part_deform 12 | if self.part_deform: 13 | self.conv_offset = nn.Conv2d( 14 | self.in_channels, 15 | self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1], 16 | kernel_size=self.kernel_size, 17 | stride=_pair(self.stride), 18 | padding=_pair(self.padding), 19 | dilation=_pair(self.dilation), 20 | bias=True, 21 | ) 22 | self.init_offset() 23 | 24 | self.bias = nn.Parameter(torch.zeros(self.out_channels)) 25 | self.start_level = 1 26 | 27 | def init_offset(self): 28 | self.conv_offset.weight.data.zero_() 29 | self.conv_offset.bias.data.zero_() 30 | 31 | def forward(self, i, x): 32 | if i < self.start_level or not self.part_deform: 
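        # Levels below self.start_level, or the non-deformable configuration, fall
        # back to a plain conv2d with the shared weights; otherwise offsets are
        # predicted by self.conv_offset and a deformable convolution is applied.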
33 | return torch.nn.functional.conv2d( 34 | x, 35 | self.weight, 36 | bias=self.bias, 37 | stride=self.stride, 38 | padding=self.padding, 39 | dilation=self.dilation, 40 | groups=self.groups, 41 | ) 42 | 43 | offset = self.conv_offset(x) 44 | return deform_conv( 45 | x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups, self.deformable_groups 46 | ) + self.bias.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) 47 | 48 | 49 | class ModulatedSEPCConv(ModulatedDeformConv): 50 | 51 | _version = 2 52 | 53 | def __init__(self, *args, part_deform=False, **kwargs): 54 | super(ModulatedSEPCConv, self).__init__(*args, **kwargs) 55 | self.part_deform = part_deform 56 | if self.part_deform: 57 | self.conv_offset = nn.Conv2d( 58 | self.in_channels, 59 | self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1], 60 | kernel_size=self.kernel_size, 61 | stride=_pair(self.stride), 62 | padding=_pair(self.padding), 63 | dilation=_pair(self.dilation), 64 | bias=True, 65 | ) 66 | self.init_offset() 67 | 68 | self.bias = nn.Parameter(torch.zeros(self.out_channels)) 69 | self.start_level = 1 70 | 71 | def init_offset(self): 72 | self.conv_offset.weight.data.zero_() 73 | self.conv_offset.bias.data.zero_() 74 | 75 | def forward(self, i, x): 76 | if i < self.start_level or not self.part_deform: 77 | return torch.nn.functional.conv2d( 78 | x, 79 | self.weight, 80 | bias=self.bias, 81 | stride=self.stride, 82 | padding=self.padding, 83 | dilation=self.dilation, 84 | groups=self.groups, 85 | ) 86 | 87 | out = self.conv_offset(x) 88 | o1, o2, mask = torch.chunk(out, 3, dim=1) 89 | offset = torch.cat((o1, o2), dim=1) 90 | mask = torch.sigmoid(mask) 91 | 92 | return modulated_deform_conv( 93 | x, 94 | offset, 95 | mask, 96 | self.weight, 97 | None, 98 | self.stride, 99 | self.padding, 100 | self.dilation, 101 | self.groups, 102 | self.deformable_groups, 103 | ) + self.bias.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) 104 | -------------------------------------------------------------------------------- /gwd/patches.py: -------------------------------------------------------------------------------- 1 | from mmcv.utils import build_from_cfg 2 | 3 | from mmdet.datasets.builder import DATASETS, _concat_dataset 4 | from mmdet.datasets.dataset_wrappers import ClassBalancedDataset, ConcatDataset, RepeatDataset 5 | 6 | from .datasets.source_balanced_dataset import SourceBalancedDataset 7 | 8 | 9 | def build_dataset(cfg, default_args=None): 10 | if isinstance(cfg, (list, tuple)): 11 | dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg]) 12 | elif cfg["type"] == "RepeatDataset": 13 | dataset = RepeatDataset(build_dataset(cfg["dataset"], default_args), cfg["times"]) 14 | elif cfg["type"] == "ClassBalancedDataset": 15 | dataset = ClassBalancedDataset(build_dataset(cfg["dataset"], default_args), cfg["oversample_thr"]) 16 | elif cfg["type"] == "SourceBalancedDataset": 17 | dataset = SourceBalancedDataset(build_dataset(cfg["dataset"], default_args), cfg["oversample_thr"]) 18 | elif isinstance(cfg.get("ann_file"), (list, tuple)): 19 | dataset = _concat_dataset(cfg, default_args) 20 | else: 21 | dataset = build_from_cfg(cfg, DATASETS, default_args) 22 | 23 | return dataset 24 | -------------------------------------------------------------------------------- /gwd/prepare_pseudo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import mmcv 4 | from mmcv import Config 5 | 6 | from gwd.datasets.wheat_detection import WheatDataset 7 | 
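# This script turns saved test-time predictions into pseudo-labels: it loads the
# pickled results, pairs them with the crops dataset defined by the config's test
# split, and writes the confident boxes to a COCO-style annotation file via
# WheatDataset.pseudo_results (confidence threshold 0.8 below).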
from mmdet.datasets import build_dataset 8 | 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--config-path", default="configs/detectors/detectors_r50_ga_mstrain_stage2.py") 13 | parser.add_argument("--ann-file", default="/data/coco_crops_fold0.json") 14 | parser.add_argument("--output-path", default="/data/folds_v2/0/coco_pseudo_train.json") 15 | parser.add_argument("--predictions-path", default="/data/crops_fold0_predictions.pkl") 16 | parser.add_argument("--fold", default=0, type=int) 17 | return parser.parse_args() 18 | 19 | 20 | def main(): 21 | args = parse_args() 22 | cfg = Config.fromfile(args.config_path) 23 | cfg.data.test.ann_file = cfg.data.test.ann_file.format(fold=args.fold) 24 | cfg.data.test.test_mode = True 25 | dataset: WheatDataset = build_dataset(cfg.data.test) 26 | predictions = mmcv.load(args.predictions_path) 27 | print(len(predictions), len(dataset)) 28 | dataset.pseudo_results(predictions, output_path=args.output_path, pseudo_confidence_threshold=0.8) 29 | 30 | 31 | if __name__ == "__main__": 32 | main() 33 | -------------------------------------------------------------------------------- /gwd/select_anchors.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from gwd.eda.kmeans import kmeans 5 | from mmdet.core.anchor import AnchorGenerator, build_anchor_generator 6 | 7 | 8 | def main(): 9 | anchor_generator_cfg = dict(type="AnchorGenerator", scales=[8], ratios=[0.5, 1.0, 2.0], strides=[4, 8, 16, 32, 64]) 10 | anchor_generator: AnchorGenerator = build_anchor_generator(anchor_generator_cfg) 11 | multi_level_anchors = anchor_generator.grid_anchors( 12 | featmap_sizes=[ 13 | torch.Size([256, 256]), 14 | torch.Size([128, 128]), 15 | torch.Size([64, 64]), 16 | torch.Size([32, 32]), 17 | torch.Size([16, 16]), 18 | ], 19 | device="cpu", 20 | ) 21 | anchors = torch.cat(multi_level_anchors).numpy() 22 | widths = anchors[:, 2] - anchors[:, 0] 23 | heights = anchors[:, 3] - anchors[:, 1] 24 | data = np.stack([heights, widths], axis=1) 25 | clusters = kmeans(data, k=50) 26 | print(f"aspect rations: {clusters[: 0] / clusters[: 1]}") 27 | print(f"sizes: {np.sqrt(clusters[: 0] * clusters[: 1])}") 28 | 29 | 30 | if __name__ == "__main__": 31 | main() 32 | -------------------------------------------------------------------------------- /gwd/split_folds.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import ast 3 | import os 4 | import os.path as osp 5 | 6 | import numpy as np 7 | import pandas as pd 8 | 9 | from iterstrat.ml_stratifiers import MultilabelStratifiedKFold 10 | 11 | SOURCES = ["ethz_1", "arvalis_1", "arvalis_3", "usask_1", "rres_1", "inrae_1", "arvalis_2"] 12 | VAL_SOURCES = ["usask_1", "ethz_1"] 13 | 14 | 15 | def parse_args(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--mosaic_path", default="/data/train_mosaic.csv") 18 | parser.add_argument("--annotation_path", default="/data/train.csv") 19 | parser.add_argument("--n_splits", type=int, default=5) 20 | parser.add_argument("--output_root", default="/data/folds_v2") 21 | return parser.parse_args() 22 | 23 | 24 | def save_split(annotations, val_ids, output_root, prefix): 25 | os.makedirs(output_root, exist_ok=True) 26 | train = annotations[~annotations["image_id"].isin(val_ids)] 27 | val = annotations[annotations["image_id"].isin(val_ids)] 28 | print(f"{prefix} train length: {len(set(train['image_id']))}") 29 | 
print(f"{prefix} val length: {len(set(val['image_id']))}\n") 30 | train.to_csv(osp.join(output_root, f"{prefix}_train.csv"), index=False) 31 | val.to_csv(osp.join(output_root, f"{prefix}_val.csv"), index=False) 32 | 33 | 34 | def main(): 35 | args = parse_args() 36 | os.makedirs(args.output_root, exist_ok=True) 37 | tile_annotations = pd.read_csv(args.annotation_path) 38 | mosaic_annotations = pd.read_csv(args.mosaic_path, converters={"bbox": ast.literal_eval}) 39 | mosaic_annotations["num_of_bboxes"] = mosaic_annotations["image_id"].map( 40 | mosaic_annotations["image_id"].value_counts() 41 | ) 42 | mosaic_annotations["median_area"] = mosaic_annotations["bbox"].apply(lambda x: np.sqrt(x[-1] * x[-2])) 43 | mosaic_annotations["source_index"] = mosaic_annotations["source"].apply(lambda x: SOURCES.index(x)) 44 | images = ( 45 | mosaic_annotations[["image_id", "source_index", "median_area", "num_of_bboxes", "source"]] 46 | .copy() 47 | .drop_duplicates("image_id") 48 | ) 49 | images = images[~images["source"].isin(VAL_SOURCES)] 50 | splitter = MultilabelStratifiedKFold(n_splits=args.n_splits, shuffle=True, random_state=3) 51 | for i, (train_index, test_index) in enumerate( 52 | splitter.split(images, images[["source_index", "median_area", "num_of_bboxes"]]) 53 | ): 54 | mosaic_val_ids = images.iloc[test_index, images.columns.get_loc("image_id")] 55 | tile_val_ids = sum([x.split("_") for x in mosaic_val_ids], []) 56 | 57 | fold_root = osp.join(args.output_root, str(i)) 58 | 59 | save_split(mosaic_annotations, mosaic_val_ids, fold_root, prefix="mosaic") 60 | save_split(tile_annotations, tile_val_ids, fold_root, prefix="tile") 61 | 62 | 63 | if __name__ == "__main__": 64 | main() 65 | -------------------------------------------------------------------------------- /gwd/stylize/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amirassov/kaggle-global-wheat-detection/b26295ea257f73089f1a067b70b4a7ee638f6b83/gwd/stylize/__init__.py -------------------------------------------------------------------------------- /gwd/stylize/function.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def calc_mean_std(feat, eps=1e-5): 5 | # eps is a small value added to the variance to avoid divide-by-zero. 
6 | size = feat.data.size() 7 | assert len(size) == 4 8 | N, C = size[:2] 9 | feat_var = feat.view(N, C, -1).var(dim=2) + eps 10 | feat_std = feat_var.sqrt().view(N, C, 1, 1) 11 | feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) 12 | return feat_mean, feat_std 13 | 14 | 15 | def adaptive_instance_normalization(content_feat, style_feat): 16 | assert content_feat.data.size()[:2] == style_feat.data.size()[:2] 17 | size = content_feat.data.size() 18 | style_mean, style_std = calc_mean_std(style_feat) 19 | content_mean, content_std = calc_mean_std(content_feat) 20 | 21 | normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) 22 | return normalized_feat * style_std.expand(size) + style_mean.expand(size) 23 | 24 | 25 | def _calc_feat_flatten_mean_std(feat): 26 | # takes 3D feat (C, H, W), return mean and std of array within channels 27 | assert feat.size()[0] == 3 28 | assert isinstance(feat, torch.FloatTensor) 29 | feat_flatten = feat.view(3, -1) 30 | mean = feat_flatten.mean(dim=-1, keepdim=True) 31 | std = feat_flatten.std(dim=-1, keepdim=True) 32 | return feat_flatten, mean, std 33 | 34 | 35 | def _mat_sqrt(x): 36 | U, D, V = torch.svd(x) 37 | return torch.mm(torch.mm(U, D.pow(0.5).diag()), V.t()) 38 | 39 | 40 | def coral(source, target): 41 | # assume both source and target are 3D array (C, H, W) 42 | # Note: flatten -> f 43 | 44 | source_f, source_f_mean, source_f_std = _calc_feat_flatten_mean_std(source) 45 | source_f_norm = (source_f - source_f_mean.expand_as(source_f)) / source_f_std.expand_as(source_f) 46 | source_f_cov_eye = torch.mm(source_f_norm, source_f_norm.t()) + torch.eye(3) 47 | 48 | target_f, target_f_mean, target_f_std = _calc_feat_flatten_mean_std(target) 49 | target_f_norm = (target_f - target_f_mean.expand_as(target_f)) / target_f_std.expand_as(target_f) 50 | target_f_cov_eye = torch.mm(target_f_norm, target_f_norm.t()) + torch.eye(3) 51 | 52 | source_f_norm_transfer = torch.mm( 53 | _mat_sqrt(target_f_cov_eye), torch.mm(torch.inverse(_mat_sqrt(source_f_cov_eye)), source_f_norm) 54 | ) 55 | 56 | source_f_transfer = source_f_norm_transfer * target_f_std.expand_as(source_f_norm) + target_f_mean.expand_as( 57 | source_f_norm 58 | ) 59 | 60 | return source_f_transfer.view(source.size()) 61 | -------------------------------------------------------------------------------- /gwd/stylize/net.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.autograd import Variable 3 | 4 | from .function import adaptive_instance_normalization as adain 5 | from .function import calc_mean_std 6 | 7 | decoder = nn.Sequential( 8 | nn.ReflectionPad2d((1, 1, 1, 1)), 9 | nn.Conv2d(512, 256, (3, 3)), 10 | nn.ReLU(), 11 | nn.Upsample(scale_factor=2), 12 | nn.ReflectionPad2d((1, 1, 1, 1)), 13 | nn.Conv2d(256, 256, (3, 3)), 14 | nn.ReLU(), 15 | nn.ReflectionPad2d((1, 1, 1, 1)), 16 | nn.Conv2d(256, 256, (3, 3)), 17 | nn.ReLU(), 18 | nn.ReflectionPad2d((1, 1, 1, 1)), 19 | nn.Conv2d(256, 256, (3, 3)), 20 | nn.ReLU(), 21 | nn.ReflectionPad2d((1, 1, 1, 1)), 22 | nn.Conv2d(256, 128, (3, 3)), 23 | nn.ReLU(), 24 | nn.Upsample(scale_factor=2), 25 | nn.ReflectionPad2d((1, 1, 1, 1)), 26 | nn.Conv2d(128, 128, (3, 3)), 27 | nn.ReLU(), 28 | nn.ReflectionPad2d((1, 1, 1, 1)), 29 | nn.Conv2d(128, 64, (3, 3)), 30 | nn.ReLU(), 31 | nn.Upsample(scale_factor=2), 32 | nn.ReflectionPad2d((1, 1, 1, 1)), 33 | nn.Conv2d(64, 64, (3, 3)), 34 | nn.ReLU(), 35 | nn.ReflectionPad2d((1, 1, 1, 1)), 36 | nn.Conv2d(64, 3, 
(3, 3)), 37 | ) 38 | 39 | vgg = nn.Sequential( 40 | nn.Conv2d(3, 3, (1, 1)), 41 | nn.ReflectionPad2d((1, 1, 1, 1)), 42 | nn.Conv2d(3, 64, (3, 3)), 43 | nn.ReLU(), # relu1-1 44 | nn.ReflectionPad2d((1, 1, 1, 1)), 45 | nn.Conv2d(64, 64, (3, 3)), 46 | nn.ReLU(), # relu1-2 47 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 48 | nn.ReflectionPad2d((1, 1, 1, 1)), 49 | nn.Conv2d(64, 128, (3, 3)), 50 | nn.ReLU(), # relu2-1 51 | nn.ReflectionPad2d((1, 1, 1, 1)), 52 | nn.Conv2d(128, 128, (3, 3)), 53 | nn.ReLU(), # relu2-2 54 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 55 | nn.ReflectionPad2d((1, 1, 1, 1)), 56 | nn.Conv2d(128, 256, (3, 3)), 57 | nn.ReLU(), # relu3-1 58 | nn.ReflectionPad2d((1, 1, 1, 1)), 59 | nn.Conv2d(256, 256, (3, 3)), 60 | nn.ReLU(), # relu3-2 61 | nn.ReflectionPad2d((1, 1, 1, 1)), 62 | nn.Conv2d(256, 256, (3, 3)), 63 | nn.ReLU(), # relu3-3 64 | nn.ReflectionPad2d((1, 1, 1, 1)), 65 | nn.Conv2d(256, 256, (3, 3)), 66 | nn.ReLU(), # relu3-4 67 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 68 | nn.ReflectionPad2d((1, 1, 1, 1)), 69 | nn.Conv2d(256, 512, (3, 3)), 70 | nn.ReLU(), # relu4-1, this is the last layer used 71 | nn.ReflectionPad2d((1, 1, 1, 1)), 72 | nn.Conv2d(512, 512, (3, 3)), 73 | nn.ReLU(), # relu4-2 74 | nn.ReflectionPad2d((1, 1, 1, 1)), 75 | nn.Conv2d(512, 512, (3, 3)), 76 | nn.ReLU(), # relu4-3 77 | nn.ReflectionPad2d((1, 1, 1, 1)), 78 | nn.Conv2d(512, 512, (3, 3)), 79 | nn.ReLU(), # relu4-4 80 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 81 | nn.ReflectionPad2d((1, 1, 1, 1)), 82 | nn.Conv2d(512, 512, (3, 3)), 83 | nn.ReLU(), # relu5-1 84 | nn.ReflectionPad2d((1, 1, 1, 1)), 85 | nn.Conv2d(512, 512, (3, 3)), 86 | nn.ReLU(), # relu5-2 87 | nn.ReflectionPad2d((1, 1, 1, 1)), 88 | nn.Conv2d(512, 512, (3, 3)), 89 | nn.ReLU(), # relu5-3 90 | nn.ReflectionPad2d((1, 1, 1, 1)), 91 | nn.Conv2d(512, 512, (3, 3)), 92 | nn.ReLU(), # relu5-4 93 | ) 94 | 95 | 96 | class Net(nn.Module): 97 | def __init__(self, encoder, decoder): 98 | super(Net, self).__init__() 99 | enc_layers = list(encoder.children()) 100 | self.enc_1 = nn.Sequential(*enc_layers[:4]) # input -> relu1_1 101 | self.enc_2 = nn.Sequential(*enc_layers[4:11]) # relu1_1 -> relu2_1 102 | self.enc_3 = nn.Sequential(*enc_layers[11:18]) # relu2_1 -> relu3_1 103 | self.enc_4 = nn.Sequential(*enc_layers[18:31]) # relu3_1 -> relu4_1 104 | self.decoder = decoder 105 | self.mse_loss = nn.MSELoss() 106 | 107 | # extract relu1_1, relu2_1, relu3_1, relu4_1 from input image 108 | def encode_with_intermediate(self, input): 109 | results = [input] 110 | for i in range(4): 111 | func = getattr(self, "enc_{:d}".format(i + 1)) 112 | results.append(func(results[-1])) 113 | return results[1:] 114 | 115 | # extract relu4_1 from input image 116 | def encode(self, input): 117 | for i in range(4): 118 | input = getattr(self, "enc_{:d}".format(i + 1))(input) 119 | return input 120 | 121 | def calc_content_loss(self, input, target): 122 | assert input.data.size() == target.data.size() 123 | assert target.requires_grad is False 124 | return self.mse_loss(input, target) 125 | 126 | def calc_style_loss(self, input, target): 127 | assert input.data.size() == target.data.size() 128 | assert target.requires_grad is False 129 | input_mean, input_std = calc_mean_std(input) 130 | target_mean, target_std = calc_mean_std(target) 131 | return self.mse_loss(input_mean, target_mean) + self.mse_loss(input_std, target_std) 132 | 133 | def forward(self, content, style): 134 | style_feats = self.encode_with_intermediate(style) 
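        # AdaIN objective: decode the AdaIN-normalized content feature t, then
        # compare the re-encoded result against t (content loss) and against the
        # style features' channel-wise mean/std at relu1_1..relu4_1 (style loss).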
135 | t = adain(self.encode(content), style_feats[-1]) 136 | 137 | g_t = self.decoder(Variable(t.data, requires_grad=True)) 138 | g_t_feats = self.encode_with_intermediate(g_t) 139 | 140 | loss_c = self.calc_content_loss(g_t_feats[-1], t) 141 | loss_s = self.calc_style_loss(g_t_feats[0], style_feats[0]) 142 | for i in range(1, 4): 143 | loss_s += self.calc_style_loss(g_t_feats[i], style_feats[i]) 144 | return loss_c, loss_s 145 | -------------------------------------------------------------------------------- /gwd/stylize/run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | from pathlib import Path 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torchvision.transforms 8 | from PIL import Image 9 | from torchvision.utils import save_image 10 | from tqdm import tqdm 11 | 12 | from gwd.stylize import net 13 | from gwd.stylize.function import adaptive_instance_normalization 14 | 15 | parser = argparse.ArgumentParser( 16 | description="This script applies the AdaIN style transfer method to arbitrary datasets." 17 | ) 18 | parser.add_argument( 19 | "--content-dir", default="/data/SPIKE_Dataset/images", help="Directory path to a batch of content images" 20 | ) 21 | parser.add_argument("--style-dir", default="/data/test", help="Directory path to a batch of style images") 22 | parser.add_argument( 23 | "--output-dir", default="/data/SPIKE_Dataset/stylized_images_v4", help="Directory to save the output images" 24 | ) 25 | parser.add_argument( 26 | "--num-styles", 27 | type=int, 28 | default=1, 29 | help="Number of styles to \ 30 | create for each image (default: 1)", 31 | ) 32 | parser.add_argument( 33 | "--alpha", 34 | type=float, 35 | default=0.7, 36 | help="The weight that controls the degree of \ 37 | stylization. 
Should be between 0 and 1", 38 | ) 39 | parser.add_argument( 40 | "--extensions", 41 | nargs="+", 42 | type=str, 43 | default=["png", "jpeg", "jpg"], 44 | help="List of image extensions to scan style and content directory for (case sensitive), default: png, jpeg, jpg", 45 | ) 46 | 47 | # Advanced options 48 | parser.add_argument( 49 | "--content-size", 50 | type=int, 51 | default=0, 52 | help="New (minimum) size for the content image, \ 53 | keeping the original size if set to 0", 54 | ) 55 | parser.add_argument( 56 | "--style-size", 57 | type=int, 58 | default=1024, 59 | help="New (minimum) size for the style image, \ 60 | keeping the original size if set to 0", 61 | ) 62 | parser.add_argument( 63 | "--crop", 64 | type=int, 65 | default=0, 66 | help="If set to anything else than 0, center crop of this size will be applied to the content image \ 67 | after resizing in order to create a squared image (default: 0)", 68 | ) 69 | 70 | 71 | # random.seed(131213) 72 | 73 | 74 | def input_transform(size, crop): 75 | transform_list = [] 76 | if size != 0: 77 | transform_list.append(torchvision.transforms.Resize(size)) 78 | if crop != 0: 79 | transform_list.append(torchvision.transforms.CenterCrop(crop)) 80 | transform_list.append(torchvision.transforms.ToTensor()) 81 | transform = torchvision.transforms.Compose(transform_list) 82 | return transform 83 | 84 | 85 | def style_transfer(vgg, decoder, content, style, alpha=1.0): 86 | assert 0.0 <= alpha <= 1.0 87 | content_f = vgg(content) 88 | style_f = vgg(style) 89 | feat = adaptive_instance_normalization(content_f, style_f) 90 | feat = feat * alpha + content_f * (1 - alpha) 91 | return decoder(feat) 92 | 93 | 94 | def main(): 95 | args = parser.parse_args() 96 | 97 | # set content and style directories 98 | content_dir = Path(args.content_dir) 99 | style_dir = Path(args.style_dir) 100 | style_dir = style_dir.resolve() 101 | output_dir = Path(args.output_dir) 102 | output_dir = output_dir.resolve() 103 | assert style_dir.is_dir(), "Style directory not found" 104 | 105 | # collect content files 106 | extensions = args.extensions 107 | assert len(extensions) > 0, "No file extensions specified" 108 | content_dir = Path(content_dir) 109 | content_dir = content_dir.resolve() 110 | assert content_dir.is_dir(), "Content directory not found" 111 | dataset = [] 112 | for ext in extensions: 113 | dataset += list(content_dir.rglob("*." + ext)) 114 | 115 | assert len(dataset) > 0, "No images with specified extensions found in content directory" + content_dir 116 | content_paths = sorted(dataset) 117 | print("Found %d content images in %s" % (len(content_paths), content_dir)) 118 | 119 | # collect style files 120 | styles = [] 121 | for ext in extensions: 122 | styles += list(style_dir.rglob("*." 
+ ext)) 123 | 124 | assert len(styles) > 0, "No images with specified extensions found in style directory" + style_dir 125 | styles = sorted(styles) 126 | print("Found %d style images in %s" % (len(styles), style_dir)) 127 | 128 | decoder = net.decoder 129 | vgg = net.vgg 130 | 131 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 132 | 133 | decoder.eval() 134 | vgg.eval() 135 | 136 | decoder.load_state_dict(torch.load("/dumps/decoder.pth")) 137 | vgg.load_state_dict(torch.load("/dumps/vgg_normalised.pth")) 138 | vgg = nn.Sequential(*list(vgg.children())[:31]) 139 | 140 | vgg.to(device) 141 | decoder.to(device) 142 | 143 | content_tf = input_transform(args.content_size, args.crop) 144 | style_tf = input_transform(args.style_size, 0) 145 | 146 | # disable decompression bomb errors 147 | Image.MAX_IMAGE_PIXELS = None 148 | skipped_imgs = [] 149 | 150 | # actual style transfer as in AdaIN 151 | with tqdm(total=len(content_paths)) as pbar: 152 | for content_path in content_paths: 153 | try: 154 | content_img = Image.open(content_path).convert("RGB") 155 | for style_path in random.sample(styles, args.num_styles): 156 | style_img = Image.open(style_path).convert("RGB") 157 | 158 | content = content_tf(content_img) 159 | style = style_tf(style_img) 160 | style = style.to(device).unsqueeze(0) 161 | content = content.to(device).unsqueeze(0) 162 | with torch.no_grad(): 163 | output = style_transfer(vgg, decoder, content, style, args.alpha) 164 | output = output.cpu() 165 | 166 | rel_path = content_path.relative_to(content_dir) 167 | out_dir = output_dir.joinpath(rel_path.parent) 168 | 169 | # create directory structure if it does not exist 170 | if not out_dir.is_dir(): 171 | out_dir.mkdir(parents=True) 172 | 173 | content_name = content_path.stem 174 | # style_name = style_path.stem 175 | # out_filename = ( 176 | # content_name + "-stylized-" + style_name + content_path.suffix 177 | # ) 178 | out_filename = content_name + content_path.suffix 179 | output_name = out_dir.joinpath(out_filename) 180 | 181 | # save_image( 182 | # torch.cat([style.cpu(), content.cpu(), output]), output_name, padding=0 183 | # ) 184 | save_image(output, output_name, padding=0) # default image padding is 2. 
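                    # The stylized output keeps the content image's filename and
                    # relative directory, so the output tree mirrors the content
                    # directory structure.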
185 | style_img.close() 186 | content_img.close() 187 | except OSError as e: 188 | print("Skipping stylization of %s due to an error" % (content_path), e) 189 | skipped_imgs.append(content_path) 190 | continue 191 | except RuntimeError as e: 192 | print("Skipping stylization of %s due to an error" % (content_path), e) 193 | skipped_imgs.append(content_path) 194 | continue 195 | finally: 196 | pbar.update(1) 197 | 198 | if len(skipped_imgs) > 0: 199 | with open(output_dir.joinpath("skipped_imgs.txt"), "w") as f: 200 | for item in skipped_imgs: 201 | f.write("%s\n" % item) 202 | 203 | 204 | if __name__ == "__main__": 205 | main() 206 | -------------------------------------------------------------------------------- /gwd/submit.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os.path as osp 3 | 4 | from gwd import test, wbf 5 | from gwd.converters import images2coco 6 | 7 | 8 | def parse_args(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--img_prefix", default="/data/test") 11 | parser.add_argument("--flip", action="store_true") 12 | parser.add_argument("--ann_file", default="/data/test_coco.json") 13 | parser.add_argument( 14 | "--configs", 15 | default=[ 16 | "configs/rfp_pseudo/rfp_r50_ga_mstrain_pseudo_stage1.py", 17 | "configs/universe_spike/universe_r50_mstrain_spike_stage1.py", 18 | ], 19 | nargs="+", 20 | ) 21 | parser.add_argument( 22 | "--checkpoints", 23 | default=[ 24 | "/dumps/rfp_r50_ga_mstrain_pseudo_stage1_epoch_10.pth", 25 | "/dumps/universe_r50_mstrain_spike_stage1_epoch_12.pth", 26 | ], 27 | nargs="+", 28 | ) 29 | parser.add_argument("--weights", default=[0.75, 0.25], type=float, nargs="+") 30 | parser.add_argument("--submission_path", default="/data/test_submission.csv") 31 | parser.add_argument("--pseudo_path", default="/data/coco_pseudo_test.json") 32 | parser.add_argument("--iou_thr", default=0.55, type=float) 33 | parser.add_argument("--score_thr", default=0.45, type=float) 34 | return parser.parse_args() 35 | 36 | 37 | def main( 38 | img_prefix, 39 | configs, 40 | checkpoints, 41 | weights, 42 | submission_path, 43 | ann_file, 44 | pseudo_path=None, 45 | format_only=True, 46 | flip=False, 47 | iou_thr=0.55, 48 | score_thr=0.45, 49 | pseudo_score_threshold=0.8, 50 | pseudo_confidence_threshold=0.65, 51 | ): 52 | images2coco.main(osp.join(img_prefix, "*"), output_path=ann_file) 53 | dataset = None 54 | all_predictions = [] 55 | for config, checkpoint in zip(configs, checkpoints): 56 | print(f"\nconfig: {config}, checkpoint: {checkpoint}") 57 | predictions, dataset = test.main( 58 | config=config, 59 | checkpoint=checkpoint, 60 | img_prefix=img_prefix, 61 | ann_file=ann_file, 62 | format_only=format_only, 63 | options=dict(output_path=submission_path), 64 | flip=flip, 65 | ) 66 | all_predictions.append(predictions) 67 | if len(all_predictions) > 1: 68 | ensemble_predictions = wbf.main(all_predictions, weights=weights, iou_thr=iou_thr, score_thr=score_thr) 69 | else: 70 | ensemble_predictions = all_predictions[0] 71 | dataset.format_results(results=ensemble_predictions, output_path=submission_path) 72 | if pseudo_path is not None: 73 | dataset.pseudo_results( 74 | results=ensemble_predictions, 75 | output_path=pseudo_path, 76 | pseudo_score_threshold=pseudo_score_threshold, 77 | pseudo_confidence_threshold=pseudo_confidence_threshold, 78 | ) 79 | 80 | 81 | if __name__ == "__main__": 82 | main(**vars(parse_args())) 83 | 
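# Illustrative programmatic usage (the checkpoint path below is a placeholder,
# not a file shipped with the repository):
#
#   from gwd.submit import main as submit
#   submit(
#       img_prefix="/data/test",
#       configs=["configs/detectors/detectors_r50_ga_mstrain_stage2.py"],
#       checkpoints=["/dumps/detectors_r50_ga_mstrain_stage2_epoch.pth"],
#       weights=[1.0],
#       submission_path="/data/test_submission.csv",
#       ann_file="/data/test_coco.json",
#   )
#
# With a single (config, checkpoint) pair the WBF step is skipped; passing
# several pairs plus per-model weights runs wbf.main before the submission
# file is written.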
-------------------------------------------------------------------------------- /gwd/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from typing import Tuple 4 | 5 | import mmcv 6 | import torch 7 | from mmcv import Config, DictAction 8 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel 9 | from mmcv.runner import get_dist_info, init_dist, load_checkpoint 10 | from mmcv.utils import print_log 11 | 12 | from gwd.datasets.wheat_detection import WheatDataset 13 | from gwd.misc.logging import get_logger 14 | from mmdet.apis import multi_gpu_test, single_gpu_test 15 | from mmdet.core import wrap_fp16_model 16 | from mmdet.datasets import build_dataloader, build_dataset 17 | from mmdet.models import build_detector 18 | 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser(description="MMDet test (and eval) a model") 22 | parser.add_argument("--config", help="test config file path") 23 | parser.add_argument("--checkpoint", help="checkpoint file") 24 | parser.add_argument("--img-prefix") 25 | parser.add_argument("--ann-file") 26 | parser.add_argument("--flip", action="store_true") 27 | parser.add_argument("--out", help="output result file in pickle format") 28 | parser.add_argument( 29 | "--format-only", 30 | action="store_true", 31 | help="Format the output results without perform evaluation. It is" 32 | "useful when you want to format the result to a specific format and " 33 | "submit it to the test server", 34 | ) 35 | parser.add_argument( 36 | "--eval", 37 | type=str, 38 | nargs="+", 39 | help='evaluation metrics, which depends on the dataset, e.g., "bbox",' 40 | ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC', 41 | ) 42 | parser.add_argument("--show", action="store_true", help="show results") 43 | parser.add_argument("--show-dir", help="directory where painted images will be saved") 44 | parser.add_argument("--log-file", default=None) 45 | parser.add_argument("--fold", type=int) 46 | parser.add_argument("--show-score-thr", type=float, default=0.3, help="score threshold (default: 0.3)") 47 | parser.add_argument("--gpu-collect", action="store_true", help="whether to use gpu to collect results.") 48 | parser.add_argument( 49 | "--tmpdir", 50 | help="tmp directory used for collecting results from multiple " 51 | "workers, available when gpu-collect is not specified", 52 | ) 53 | parser.add_argument("--options", nargs="+", action=DictAction, help="arguments in dict") 54 | parser.add_argument( 55 | "--launcher", choices=["none", "pytorch", "slurm", "mpi"], default="pytorch", help="job launcher" 56 | ) 57 | parser.add_argument("--local_rank", type=int, default=0) 58 | parser.add_argument("--score_thr", type=float) 59 | parser.add_argument("--iou_thr", type=float) 60 | args = parser.parse_args() 61 | if "LOCAL_RANK" not in os.environ: 62 | os.environ["LOCAL_RANK"] = str(args.local_rank) 63 | return args 64 | 65 | 66 | def main( 67 | config=None, 68 | checkpoint=None, 69 | img_prefix=None, 70 | ann_file=None, 71 | flip=False, 72 | out=None, 73 | format_only=False, 74 | eval=None, 75 | show=False, 76 | show_dir=None, 77 | log_file=None, 78 | fold=None, 79 | show_score_thr=0.3, 80 | gpu_collect=False, 81 | tmpdir=None, 82 | options=None, 83 | launcher="none", 84 | local_rank=0, 85 | score_thr=None, 86 | iou_thr=None, 87 | ) -> Tuple[list, WheatDataset]: 88 | logger = get_logger("inference", log_file=log_file, log_mode="a") 89 | assert out or eval or format_only or show or 
show_dir, ( 90 | "Please specify at least one operation (save/eval/format/show the " 91 | 'results / save the results) with the argument "--out", "--eval"' 92 | ', "--format-only", "--show" or "--show-dir"' 93 | ) 94 | 95 | if eval and format_only: 96 | raise ValueError("--eval and --format_only cannot be both specified") 97 | 98 | if out is not None and not out.endswith((".pkl", ".pickle")): 99 | raise ValueError("The output file must be a pkl file.") 100 | 101 | cfg = Config.fromfile(config) 102 | if fold is not None: 103 | cfg.data.test.ann_file = cfg.data.test.ann_file.format(fold=fold) 104 | if img_prefix is not None: 105 | cfg.data.test.img_prefix = img_prefix 106 | if ann_file is not None: 107 | cfg.data.test.ann_file = ann_file 108 | if flip: 109 | print(cfg.data.test.pipeline[-1]) 110 | print(f"flip: {cfg.data.test.pipeline[-1].flip}") 111 | cfg.data.test.pipeline[-1].flip_direction = ["horizontal", "vertical"] 112 | cfg.data.test.pipeline[-1].flip = flip 113 | if score_thr is not None: 114 | if "rcnn" in cfg.test_cfg: 115 | cfg.test_cfg.rcnn.score_thr = score_thr 116 | else: 117 | cfg.test_cfg.score_thr = score_thr 118 | if iou_thr is not None: 119 | if "rcnn" in cfg.test_cfg: 120 | cfg.test_cfg.rcnn.nms.iou_thr = iou_thr 121 | else: 122 | cfg.test_cfg.nms.iou_thr = iou_thr 123 | # set cudnn_benchmark 124 | if cfg.get("cudnn_benchmark", False): 125 | torch.backends.cudnn.benchmark = True 126 | cfg.model.pretrained = None 127 | cfg.data.test.test_mode = True 128 | 129 | # init distributed env first, since logger depends on the dist info. 130 | if launcher == "none": 131 | distributed = False 132 | else: 133 | distributed = True 134 | init_dist(launcher, **cfg.dist_params) 135 | 136 | # build the dataloader 137 | # TODO: support multiple images per gpu (only minor changes are needed) 138 | dataset = build_dataset(cfg.data.test) 139 | data_loader = build_dataloader( 140 | dataset, samples_per_gpu=1, workers_per_gpu=cfg.data.workers_per_gpu, dist=distributed, shuffle=False 141 | ) 142 | 143 | # build the model and load checkpoint 144 | model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg) 145 | fp16_cfg = cfg.get("fp16", None) 146 | if fp16_cfg is not None: 147 | wrap_fp16_model(model) 148 | checkpoint = load_checkpoint(model, checkpoint, map_location="cpu") 149 | # old versions did not save class info in checkpoints, this walkaround is 150 | # for backward compatibility 151 | if "CLASSES" in checkpoint["meta"]: 152 | model.CLASSES = checkpoint["meta"]["CLASSES"] 153 | else: 154 | model.CLASSES = dataset.CLASSES 155 | 156 | if not distributed: 157 | model = MMDataParallel(model, device_ids=[0]) 158 | outputs = single_gpu_test(model, data_loader, show, show_dir, show_score_thr) 159 | else: 160 | model = MMDistributedDataParallel( 161 | model.cuda(), device_ids=[torch.cuda.current_device()], broadcast_buffers=False 162 | ) 163 | outputs = multi_gpu_test(model, data_loader, tmpdir, gpu_collect) 164 | 165 | rank, _ = get_dist_info() 166 | if rank == 0: 167 | if out: 168 | print(f"\nwriting results to {out}") 169 | mmcv.dump(outputs, out) 170 | kwargs = {} if options is None else options 171 | if format_only: 172 | dataset.format_results(outputs, **kwargs) 173 | if eval: 174 | print_log(f"config: {config}", logger) 175 | print_log(f"score_thr: {score_thr}", logger) 176 | print_log(f"iou_thr: {iou_thr}", logger) 177 | dataset.evaluate(outputs, logger=logger, **kwargs) 178 | return outputs, dataset 179 | 180 | 181 | if __name__ == "__main__": 182 | 
main(**vars(parse_args())) 183 | -------------------------------------------------------------------------------- /gwd/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import os 4 | import os.path as osp 5 | import time 6 | import warnings # noqa E402 7 | 8 | import mmcv 9 | import torch 10 | from mmcv import Config, DictAction 11 | from mmcv.runner import init_dist 12 | 13 | import gwd # noqa F401 14 | from gwd.patches import build_dataset 15 | from mmdet import __version__ 16 | from mmdet.apis import set_random_seed, train_detector 17 | from mmdet.models import build_detector 18 | from mmdet.utils import collect_env, get_root_logger 19 | 20 | warnings.simplefilter("ignore", UserWarning) # noqa E402 21 | 22 | 23 | def parse_args(): 24 | parser = argparse.ArgumentParser(description="Train a detector") 25 | parser.add_argument("--config", help="train config file path") 26 | parser.add_argument("--fold", type=int, default=None) 27 | parser.add_argument("--work-dir", help="the dir to save logs and models") 28 | parser.add_argument("--resume-from", help="the checkpoint file to resume from") 29 | parser.add_argument("--load-from", help="the checkpoint file to load from") 30 | parser.add_argument( 31 | "--no-validate", action="store_true", help="whether not to evaluate the checkpoint during training" 32 | ) 33 | group_gpus = parser.add_mutually_exclusive_group() 34 | group_gpus.add_argument( 35 | "--gpus", type=int, help="number of gpus to use " "(only applicable to non-distributed training)" 36 | ) 37 | group_gpus.add_argument( 38 | "--gpu-ids", type=int, nargs="+", help="ids of gpus to use " "(only applicable to non-distributed training)" 39 | ) 40 | parser.add_argument("--seed", type=int, default=None, help="random seed") 41 | parser.add_argument( 42 | "--deterministic", action="store_true", help="whether to set deterministic options for CUDNN backend." 
43 | ) 44 | parser.add_argument("--options", nargs="+", action=DictAction, help="arguments in dict") 45 | parser.add_argument("--launcher", choices=["none", "pytorch", "slurm", "mpi"], default="none", help="job launcher") 46 | parser.add_argument("--local_rank", type=int, default=0) 47 | args = parser.parse_args() 48 | if "LOCAL_RANK" not in os.environ: 49 | os.environ["LOCAL_RANK"] = str(args.local_rank) 50 | 51 | return args 52 | 53 | 54 | def main( 55 | config=None, 56 | fold=None, 57 | work_dir=None, 58 | resume_from=None, 59 | load_from=None, 60 | no_validate=False, 61 | gpus=None, 62 | gpu_ids=None, 63 | seed=None, 64 | deterministic=False, 65 | options=None, 66 | launcher="none", 67 | local_rank=0, 68 | ): 69 | cfg = Config.fromfile(config) 70 | 71 | if fold is not None: 72 | if "ann_file" in cfg.data.train: 73 | if isinstance(cfg.data.train.ann_file, list): 74 | cfg.data.train.ann_file = [x.format(fold=fold) for x in cfg.data.train.ann_file] 75 | elif isinstance(cfg.data.train.ann_file, str): 76 | cfg.data.train.ann_file = cfg.data.train.ann_file.format(fold=fold) 77 | else: 78 | cfg.data.train.dataset.ann_file = cfg.data.train.dataset.ann_file.format(fold=fold) 79 | cfg.data.val.ann_file = cfg.data.val.ann_file.format(fold=fold) 80 | cfg.data.test.ann_file = cfg.data.test.ann_file.format(fold=fold) 81 | if options is not None: 82 | cfg.merge_from_dict(options) 83 | # set cudnn_benchmark 84 | if cfg.get("cudnn_benchmark", False): 85 | torch.backends.cudnn.benchmark = True 86 | 87 | # work_dir is determined in this priority: CLI > segment in file > filename 88 | if work_dir is not None: 89 | # update configs according to CLI args if args.work_dir is not None 90 | cfg.work_dir = work_dir 91 | elif cfg.get("work_dir", None) is None: 92 | # use config filename as default work_dir if cfg.work_dir is None 93 | cfg.work_dir = osp.join("/dumps/work_dirs", osp.splitext(osp.basename(config))[0], str(fold)) 94 | if resume_from is not None: 95 | cfg.resume_from = resume_from 96 | if load_from is not None: 97 | cfg.load_from = load_from 98 | if gpu_ids is not None: 99 | cfg.gpu_ids = gpu_ids 100 | else: 101 | cfg.gpu_ids = range(1) if gpus is None else range(gpus) 102 | 103 | # init distributed env first, since logger depends on the dist info. 
104 | if launcher == "none": 105 | distributed = False 106 | else: 107 | distributed = True 108 | init_dist(launcher, **cfg.dist_params) 109 | 110 | # create work_dir 111 | mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) 112 | # init the logger before other steps 113 | timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime()) 114 | log_file = osp.join(cfg.work_dir, f"{timestamp}.log") 115 | logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) 116 | 117 | # init the meta dict to record some important information such as 118 | # environment info and seed, which will be logged 119 | meta = dict() 120 | # log env info 121 | env_info_dict = collect_env() 122 | env_info = "\n".join([f"{k}: {v}" for k, v in env_info_dict.items()]) 123 | dash_line = "-" * 60 + "\n" 124 | logger.info("Environment info:\n" + dash_line + env_info + "\n" + dash_line) 125 | meta["env_info"] = env_info 126 | 127 | # log some basic info 128 | logger.info(f"Distributed training: {distributed}") 129 | logger.info(f"Config:\n{cfg.pretty_text}") 130 | 131 | # set random seeds 132 | if seed is not None: 133 | logger.info(f"Set random seed to {seed}, " f"deterministic: {deterministic}") 134 | set_random_seed(seed, deterministic=deterministic) 135 | cfg.seed = seed 136 | meta["seed"] = seed 137 | 138 | model = build_detector(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg) 139 | datasets = [build_dataset(cfg.data.train)] 140 | if len(cfg.workflow) == 2: 141 | val_dataset = copy.deepcopy(cfg.data.val) 142 | val_dataset.pipeline = cfg.data.train.pipeline 143 | datasets.append(build_dataset(val_dataset)) 144 | if cfg.checkpoint_config is not None: 145 | # save mmdet version, config file content and class names in 146 | # checkpoints as meta data 147 | cfg.checkpoint_config.meta = dict( 148 | mmdet_version=__version__, config=cfg.pretty_text, CLASSES=datasets[0].CLASSES 149 | ) 150 | # add an attribute for visualization convenience 151 | model.CLASSES = datasets[0].CLASSES 152 | train_detector( 153 | model, datasets, cfg, distributed=distributed, validate=(not no_validate), timestamp=timestamp, meta=meta 154 | ) 155 | 156 | 157 | if __name__ == "__main__": 158 | main(**vars(parse_args())) 159 | -------------------------------------------------------------------------------- /gwd/wbf.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from functools import partial 3 | from multiprocessing import Pool 4 | 5 | import mmcv 6 | import numpy as np 7 | from ensemble_boxes import weighted_boxes_fusion 8 | from scipy.stats import rankdata 9 | from tqdm import tqdm 10 | 11 | import gwd # noqa F401 12 | from mmdet.datasets import build_dataset 13 | 14 | WHEAT_CLASS_ID = 0 15 | IMAGE_SIZE = 1024 16 | 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument( 21 | "--prediction_paths", 22 | default=[ 23 | "/data/rfp_r50_ga_mstrain_stage2_fold0_predictions.pkl", 24 | "/data/universe_r50_mstrain_spike_stage1_epoch_12_predictions.pkl", 25 | ], 26 | nargs="+", 27 | ) 28 | parser.add_argument("--iou_thr", default=0.55, type=float) 29 | parser.add_argument("--score_thr", default=0.45, type=float) 30 | parser.add_argument("--weights", default=[1.0, 1.0], type=float, nargs="+") 31 | return parser.parse_args() 32 | 33 | 34 | def mmdet2wbf(prediction): 35 | wheat_prediction = prediction[WHEAT_CLASS_ID] 36 | bboxes = wheat_prediction[:, :4] / IMAGE_SIZE 37 | scores = np.clip(wheat_prediction[:, 4], 0, 1.0) 38 | scores = 0.5 * 
(rankdata(scores) / len(scores)) + 0.5 39 | labels = np.zeros_like(scores) 40 | return bboxes, scores, labels 41 | 42 | 43 | def wbf_per_sample(sample_predictions, weights, iou_thr, score_thr): 44 | bboxes_list = [] 45 | scores_list = [] 46 | labels_list = [] 47 | for prediction in sample_predictions: 48 | bboxes, scores, labels = mmdet2wbf(prediction) 49 | bboxes_list.append(bboxes) 50 | scores_list.append(scores) 51 | labels_list.append(labels) 52 | bboxes, scores, labels = weighted_boxes_fusion( 53 | boxes_list=bboxes_list, 54 | scores_list=scores_list, 55 | labels_list=labels_list, 56 | weights=weights, 57 | iou_thr=iou_thr, 58 | skip_box_thr=score_thr, 59 | ) 60 | return [np.concatenate([bboxes * IMAGE_SIZE, scores.reshape(-1, 1)], axis=1)] 61 | 62 | 63 | def main(all_predictions, weights, iou_thr, score_thr): 64 | with Pool(32) as p: 65 | results = list( 66 | ( 67 | tqdm( 68 | p.imap( 69 | partial(wbf_per_sample, weights=weights, iou_thr=iou_thr, score_thr=score_thr), 70 | zip(*all_predictions), 71 | ), 72 | total=len(all_predictions[0]), 73 | ) 74 | ) 75 | ) 76 | return results 77 | 78 | 79 | if __name__ == "__main__": 80 | from mmcv import Config 81 | 82 | cfg = Config.fromfile("configs/_base_/datasets/wheat_detection_mstrain.py") 83 | cfg.data.test.ann_file = cfg.data.test.ann_file.format(fold=0) 84 | cfg.data.test.test_mode = True 85 | dataset = build_dataset(cfg.data.test) 86 | args = vars(parse_args()) 87 | prediction_paths = args.pop("prediction_paths") 88 | _all_predictions = [mmcv.load(path) if isinstance(path, str) else path for path in prediction_paths] 89 | for predictions in _all_predictions: 90 | dataset.evaluate(results=predictions) 91 | 92 | metrics = {} 93 | # for _iou_thr in np.arange(0.3, 0.7, 0.1): 94 | # for _score_thr in np.arange(0.45, 0.7, 0.1): 95 | # for w in np.arange(0.1, 1.0, 0.1): 96 | for _iou_thr in [0.55]: 97 | for _score_thr in [0.45]: 98 | for w in [0.75]: 99 | _weights = [w, 1 - w] 100 | print(f"iou_thr: {_iou_thr}, score_thr: {_score_thr}, weights: {_weights}") 101 | metrics[(_iou_thr, _score_thr, w)] = dataset.evaluate( 102 | main(iou_thr=_iou_thr, score_thr=_score_thr, weights=_weights, all_predictions=_all_predictions) 103 | ) 104 | print(metrics) 105 | best_parameters = max(metrics, key=metrics.get) 106 | print(f"best_parameters: {best_parameters}, best_score: {metrics[best_parameters]}") 107 | -------------------------------------------------------------------------------- /gwd/weights/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amirassov/kaggle-global-wheat-detection/b26295ea257f73089f1a067b70b4a7ee638f6b83/gwd/weights/__init__.py -------------------------------------------------------------------------------- /gwd/weights/prepare_weights.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import OrderedDict 3 | from typing import Dict, List 4 | 5 | import torch 6 | 7 | from mmdet.datasets.coco import CocoDataset 8 | 9 | BACKGROUND_INDEX = 0 10 | NUM_COCO_CLASSES = 80 11 | 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--weights_path", default="/dumps/DetectoRS_X101-ed983634_v2.pth") 16 | parser.add_argument("--output_path", default="/dumps/DetectoRS_X101-ed983634_v3.pth") 17 | parser.add_argument("--classes", default=["broccoli"], nargs="+") 18 | return parser.parse_args() 19 | 20 | 21 | def change_weights(name: str, w: torch.Tensor, 
class_indices: List[int]): 22 | old_shape = w.shape 23 | if name.startswith("roi_head.bbox_head"): 24 | if "fc_cls" in name: 25 | w = w[[BACKGROUND_INDEX] + class_indices] 26 | elif "fc_reg" in name and w.shape[0] == 4 * NUM_COCO_CLASSES: 27 | w = w[sum([[i, i + 1, i + 2, i + 3] for i in class_indices], [])] 28 | print(f"{name}: {old_shape} -> {w.shape}") 29 | return w 30 | 31 | 32 | def main(): 33 | args = parse_args() 34 | class_indices = [CocoDataset.CLASSES.index(cls) for cls in args.classes] 35 | print(class_indices) 36 | weights: Dict[str, OrderedDict] = torch.load(args.weights_path, map_location="cpu") 37 | for name, w in weights["state_dict"].items(): 38 | weights["state_dict"][name] = change_weights(name, w, class_indices) 39 | torch.save(weights, args.output_path) 40 | 41 | 42 | if __name__ == "__main__": 43 | main() 44 | -------------------------------------------------------------------------------- /gwd/weights/rm_optimizer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | 5 | 6 | def parse_args(): 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--weights", default="/dumps/work_dirs/universe_r101_gfl_mstrain_stage1/0/epoch_11.pth") 9 | parser.add_argument( 10 | "--output", 11 | default="/dumps/work_dirs/universe_r101_gfl_mstrain_stage1/0/universe_r101_gfl_mstrain_stage1_epoch_11.pth", 12 | ) 13 | return parser.parse_args() 14 | 15 | 16 | def main(): 17 | args = parse_args() 18 | weights = torch.load(args.weights, map_location="cpu") 19 | del weights["optimizer"] 20 | torch.save(weights, args.output) 21 | 22 | 23 | if __name__ == "__main__": 24 | main() 25 | -------------------------------------------------------------------------------- /gwd/weights/upgrade_model_version.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import re 3 | from collections import OrderedDict 4 | 5 | import torch 6 | 7 | 8 | def is_head(key): 9 | valid_head_list = ["bbox_head", "mask_head", "semantic_head", "grid_head", "mask_iou_head"] 10 | 11 | return any(key.startswith(h) for h in valid_head_list) 12 | 13 | 14 | def parse_config(config_strings): 15 | # temp_file = tempfile.NamedTemporaryFile() 16 | # config_path = f'{temp_file.name}.py' 17 | # with open(config_path, 'w') as f: 18 | # f.write(config_strings) 19 | # 20 | # config = Config.fromfile(config_path) 21 | # is_two_stage = True 22 | # is_ssd = False 23 | # is_retina = False 24 | # reg_cls_agnostic = False 25 | # if 'rpn_head' not in config.model: 26 | # is_two_stage = False 27 | # # check whether it is SSD 28 | # if config.model.bbox_head.type == 'SSDHead': 29 | # is_ssd = True 30 | # elif config.model.bbox_head.type == 'RetinaHead': 31 | # is_retina = True 32 | # elif isinstance(config.model['bbox_head'], list): 33 | # reg_cls_agnostic = True 34 | # elif 'reg_class_agnostic' in config.model.bbox_head: 35 | # reg_cls_agnostic = config.model.bbox_head \ 36 | # .reg_class_agnostic 37 | # temp_file.close() 38 | return False, False, False, False 39 | 40 | 41 | def reorder_cls_channel(val, num_classes=81): 42 | # bias 43 | if val.dim() == 1: 44 | new_val = torch.cat((val[1:], val[:1]), dim=0) 45 | # weight 46 | else: 47 | out_channels, in_channels = val.shape[:2] 48 | # conv_cls for softmax output 49 | if out_channels != num_classes and out_channels % num_classes == 0: 50 | new_val = val.reshape(-1, num_classes, in_channels, *val.shape[2:]) 51 | new_val = torch.cat((new_val[:, 1:], new_val[:, 
:1]), dim=1) 52 | new_val = new_val.reshape(val.size()) 53 | # fc_cls 54 | elif out_channels == num_classes: 55 | new_val = torch.cat((val[1:], val[:1]), dim=0) 56 | # agnostic | retina_cls | rpn_cls 57 | else: 58 | new_val = val 59 | 60 | return new_val 61 | 62 | 63 | def truncate_cls_channel(val, num_classes=81): 64 | 65 | # bias 66 | if val.dim() == 1: 67 | if val.size(0) % num_classes == 0: 68 | new_val = val[: num_classes - 1] 69 | else: 70 | new_val = val 71 | # weight 72 | else: 73 | out_channels, in_channels = val.shape[:2] 74 | # conv_logits 75 | if out_channels % num_classes == 0: 76 | new_val = val.reshape(num_classes, in_channels, *val.shape[2:])[1:] 77 | new_val = new_val.reshape(-1, *val.shape[1:]) 78 | # agnostic 79 | else: 80 | new_val = val 81 | 82 | return new_val 83 | 84 | 85 | def truncate_reg_channel(val, num_classes=81): 86 | # bias 87 | if val.dim() == 1: 88 | # fc_reg|rpn_reg 89 | if val.size(0) % num_classes == 0: 90 | new_val = val.reshape(num_classes, -1)[: num_classes - 1] 91 | new_val = new_val.reshape(-1) 92 | # agnostic 93 | else: 94 | new_val = val 95 | # weight 96 | else: 97 | out_channels, in_channels = val.shape[:2] 98 | # fc_reg|rpn_reg 99 | if out_channels % num_classes == 0: 100 | new_val = val.reshape(num_classes, -1, in_channels, *val.shape[2:])[1:] 101 | new_val = new_val.reshape(-1, *val.shape[1:]) 102 | # agnostic 103 | else: 104 | new_val = val 105 | 106 | return new_val 107 | 108 | 109 | def convert(in_file, out_file, num_classes): 110 | """Convert keys in checkpoints. 111 | 112 | There can be some breaking changes during the development of mmdetection, 113 | and this tool is used for upgrading checkpoints trained with old versions 114 | to the latest one. 115 | """ 116 | checkpoint = torch.load(in_file) 117 | in_state_dict = checkpoint.pop("state_dict") 118 | out_state_dict = OrderedDict() 119 | meta_info = checkpoint["meta"] 120 | is_two_stage, is_ssd, is_retina, reg_cls_agnostic = parse_config(meta_info["config"]) 121 | if meta_info["mmdet_version"] <= "0.5.3" and is_retina: 122 | upgrade_retina = True 123 | else: 124 | upgrade_retina = False 125 | 126 | for key, val in in_state_dict.items(): 127 | new_key = key 128 | new_val = val 129 | if is_two_stage and is_head(key): 130 | new_key = "roi_head.{}".format(key) 131 | 132 | # classification 133 | m = re.search(r"(conv_cls|retina_cls|rpn_cls|fc_cls|fcos_cls|" r"fovea_cls).(weight|bias)", new_key) 134 | if m is not None: 135 | print(f"reorder cls channels of {new_key}") 136 | new_val = reorder_cls_channel(val, num_classes) 137 | 138 | # regression 139 | m = re.search(r"(fc_reg|rpn_reg).(weight|bias)", new_key) 140 | if m is not None and not reg_cls_agnostic: 141 | print(f"truncate regression channels of {new_key}") 142 | new_val = truncate_reg_channel(val, num_classes) 143 | 144 | # mask head 145 | m = re.search(r"(conv_logits).(weight|bias)", new_key) 146 | if m is not None: 147 | print(f"truncate mask prediction channels of {new_key}") 148 | new_val = truncate_cls_channel(val, num_classes) 149 | 150 | m = re.search(r"(cls_convs|reg_convs).\d.(weight|bias)", key) 151 | # Legacy issues in RetinaNet since V1.x 152 | # Use ConvModule instead of nn.Conv2d in RetinaNet 153 | # cls_convs.0.weight -> cls_convs.0.conv.weight 154 | if m is not None and upgrade_retina: 155 | param = m.groups()[1] 156 | new_key = key.replace(param, f"conv.{param}") 157 | out_state_dict[new_key] = val 158 | print(f"rename the name of {key} to {new_key}") 159 | continue 160 | 161 | m = 
re.search(r"(cls_convs).\d.(weight|bias)", key) 162 | if m is not None and is_ssd: 163 | print(f"reorder cls channels of {new_key}") 164 | new_val = reorder_cls_channel(val, num_classes) 165 | 166 | out_state_dict[new_key] = new_val 167 | checkpoint["state_dict"] = out_state_dict 168 | torch.save(checkpoint, out_file) 169 | 170 | 171 | def main(): 172 | parser = argparse.ArgumentParser(description="Upgrade model version") 173 | parser.add_argument("in_file", help="input checkpoint file") 174 | parser.add_argument("out_file", help="output checkpoint file") 175 | parser.add_argument("--num-classes", type=int, default=81, help="number of classes of the original model") 176 | args = parser.parse_args() 177 | convert(args.in_file, args.out_file, args.num_classes) 178 | 179 | 180 | if __name__ == "__main__": 181 | main() 182 | -------------------------------------------------------------------------------- /requirements.in: -------------------------------------------------------------------------------- 1 | torch==1.5.0 2 | numpy 3 | matplotlib 4 | mmcv>=0.5.9 5 | numpy 6 | Pillow<=6.2.2 7 | six 8 | terminaltables 9 | torchvision 10 | loguru 11 | pytest 12 | cython 13 | addict 14 | yapf 15 | pandas 16 | ipython 17 | albumentations 18 | pytest 19 | ipdb 20 | ensemble-boxes 21 | bbaug 22 | timm 23 | iterative-stratification 24 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile 3 | # To update, run: 4 | # 5 | # pip-compile 6 | # 7 | addict==2.2.1 # via -r requirements.in, mmcv 8 | albumentations==0.4.5 # via -r requirements.in 9 | attrs==19.3.0 # via pytest 10 | backcall==0.1.0 # via ipython 11 | bbaug==0.4.2 # via -r requirements.in 12 | cycler==0.10.0 # via matplotlib 13 | cython==0.29.17 # via -r requirements.in 14 | decorator==4.4.2 # via ipython, networkx, traitlets 15 | ensemble-boxes==1.0.2 # via -r requirements.in 16 | future==0.18.2 # via torch 17 | imagecodecs==2020.2.18 # via tifffile 18 | imageio==2.8.0 # via scikit-image 19 | imgaug==0.2.6 # via albumentations, bbaug 20 | importlib-metadata==1.6.0 # via pluggy, pytest 21 | ipdb==0.13.2 # via -r requirements.in 22 | ipython-genutils==0.2.0 # via traitlets 23 | ipython==7.14.0 # via -r requirements.in, ipdb 24 | iterative-stratification==0.1.6 # via -r requirements.in 25 | jedi==0.17.0 # via ipython 26 | joblib==0.16.0 # via scikit-learn 27 | kiwisolver==1.2.0 # via matplotlib 28 | loguru==0.5.0 # via -r requirements.in 29 | matplotlib==3.2.1 # via -r requirements.in, scikit-image 30 | mmcv==0.5.9 # via -r requirements.in 31 | more-itertools==8.3.0 # via pytest 32 | networkx==2.4 # via scikit-image 33 | numpy==1.18.4 # via -r requirements.in, albumentations, ensemble-boxes, imagecodecs, imageio, imgaug, iterative-stratification, matplotlib, mmcv, opencv-python, pandas, pywavelets, scikit-image, scikit-learn, scipy, tifffile, torch, torchvision 34 | opencv-python==4.2.0.34 # via albumentations, mmcv 35 | packaging==20.3 # via pytest 36 | pandas==1.0.3 # via -r requirements.in, ensemble-boxes 37 | parso==0.7.0 # via jedi 38 | pexpect==4.8.0 # via ipython 39 | pickleshare==0.7.5 # via ipython 40 | pillow>=8.1.1 # via -r requirements.in, imageio, scikit-image, torchvision 41 | pluggy==0.13.1 # via pytest 42 | prompt-toolkit==3.0.5 # via ipython 43 | ptyprocess==0.6.0 # via pexpect 44 | py>=1.10.0 # via pytest 45 | Pygments>=2.7.4 # via ipython 46 | 
pyparsing==2.4.7 # via matplotlib, packaging 47 | pytest==5.4.2 # via -r requirements.in 48 | python-dateutil==2.8.1 # via matplotlib, pandas 49 | pytz==2020.1 # via pandas 50 | pywavelets==1.1.1 # via scikit-image 51 | PyYAML>=5.4 # via albumentations, mmcv 52 | scikit-image==0.17.2 # via imgaug 53 | scikit-learn==0.23.2 # via iterative-stratification 54 | scipy==1.4.1 # via albumentations, imgaug, iterative-stratification, scikit-image, scikit-learn 55 | six==1.14.0 # via -r requirements.in, cycler, imgaug, packaging, python-dateutil, traitlets 56 | terminaltables==3.1.0 # via -r requirements.in 57 | threadpoolctl==2.1.0 # via scikit-learn 58 | tifffile==2020.5.11 # via scikit-image 59 | timm==0.1.30 # via -r requirements.in 60 | torch==1.5.0 # via -r requirements.in, timm, torchvision 61 | torchvision==0.6.0 # via -r requirements.in, timm 62 | traitlets==4.3.3 # via ipython 63 | wcwidth==0.1.9 # via prompt-toolkit, pytest 64 | yapf==0.30.0 # via -r requirements.in, mmcv 65 | zipp==3.1.0 # via importlib-metadata 66 | 67 | # The following packages are considered to be unsafe in a requirements file: 68 | # setuptools 69 | -------------------------------------------------------------------------------- /script_template.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import gzip 3 | import os 4 | import os.path as osp 5 | import sys 6 | from pathlib import Path 7 | from typing import Dict 8 | 9 | from gwd import submit, test, train, wbf # noqa E402 10 | from gwd.converters import images2coco, kaggle2coco # noqa E402 11 | 12 | MODELS_ROOT = "/kaggle/input/gwd-models" 13 | WHEELS_ROOT = "/kaggle/input/mmdetection-wheels" 14 | INPUT_ROOT = "/kaggle/input/global-wheat-detection" 15 | KAGGLE_WORKING = "/kaggle/working" 16 | 17 | # this is base64 encoded source code 18 | file_data: Dict = {file_data} 19 | 20 | for path, encoded in file_data.items(): 21 | path = Path(KAGGLE_WORKING) / path 22 | print(path) 23 | path.parent.mkdir(parents=True, exist_ok=True) 24 | path.write_bytes(gzip.decompress(base64.b64decode(encoded))) 25 | 26 | 27 | def run(command): 28 | os.system(f"export PYTHONPATH=${{PYTHONPATH}}:{KAGGLE_WORKING} && " + command) 29 | 30 | 31 | run(f"python {KAGGLE_WORKING}/setup.py develop") 32 | sys.path.append(KAGGLE_WORKING) 33 | 34 | for wheel_name in [ 35 | "addict-2.2.1-py3-none-any.whl", 36 | "mmcv-0.6.2-cp37-cp37m-linux_x86_64.whl", 37 | "terminal-0.4.0-py3-none-any.whl", 38 | "terminaltables-3.1.0-py3-none-any.whl", 39 | "pycocotools-12.0-cp37-cp37m-linux_x86_64.whl", 40 | "mmdet-2.2.0cdfc6a1-cp37-cp37m-linux_x86_64.whl", 41 | "ensemble_boxes-1.0.0-py3-none-any.whl", 42 | ]: 43 | wheel_path = osp.join(WHEELS_ROOT, wheel_name) 44 | run(f"pip install {wheel_path}") 45 | 46 | 47 | CONFIG_DETECTORS = osp.join(KAGGLE_WORKING, "configs/detectors/detectors_r50_ga_mstrain_stage2.py") 48 | CHECKPOINT_DETECTORS = osp.join(MODELS_ROOT, "detectors_r50_ga_mstrainv2_stage1_epoch_24.pth") 49 | CHECKPOINT_UNIVERSE = osp.join(MODELS_ROOT, "universe_r101_gfl_mstrain_stage1_epoch_11.pth") 50 | IMG_PREFIX = osp.join(INPUT_ROOT, "test") 51 | ANN_FILE = "coco_test_images.json" 52 | PSEUDO_PATH = f"{KAGGLE_WORKING}/coco_pseudo_test.json" 53 | PSEUDO_SCORE_THRESHOLD = 0.75 54 | PSEUDO_CONFIDENCE_THRESHOLD = 0.6 55 | SUBMISSION_PATH = "submission.csv" 56 | 57 | # make pseudo prediction 58 | submit.main( 59 | img_prefix=IMG_PREFIX, 60 | configs=[CONFIG_DETECTORS], 61 | checkpoints=[CHECKPOINT_DETECTORS], 62 | 
submission_path=f"detectors_{SUBMISSION_PATH}", 63 | ann_file=ANN_FILE, 64 | pseudo_path=PSEUDO_PATH, 65 | pseudo_score_threshold=PSEUDO_SCORE_THRESHOLD, 66 | pseudo_confidence_threshold=PSEUDO_CONFIDENCE_THRESHOLD, 67 | flip=True, 68 | format_only=True, 69 | weights=None, 70 | iou_thr=None, 71 | score_thr=None, 72 | ) 73 | 74 | # convert train samples to COCO 75 | kaggle2coco.main( 76 | annotation_path=f"{INPUT_ROOT}/train.csv", 77 | output_path=f"{KAGGLE_WORKING}/coco_train.json", 78 | exclude_sources=["usask_1", "ethz_1"], 79 | ) 80 | 81 | if len(os.listdir(IMG_PREFIX)) == 10: 82 | PSEUDO_CONFIG_DETECTORS = f"{KAGGLE_WORKING}/configs/detectors/detectors_r50_ga_mstrain_public_pseudo.py" 83 | PSEUDO_CONFIG_UNIVERSE = f"{KAGGLE_WORKING}/configs/universe_r101_gfl/universe_r101_gfl_mstrain_public_pseudo.py" 84 | else: 85 | PSEUDO_CONFIG_DETECTORS = f"{KAGGLE_WORKING}/configs/detectors/detectors_r50_ga_mstrain_private_pseudo.py" 86 | PSEUDO_CONFIG_UNIVERSE = f"{KAGGLE_WORKING}/configs/universe_r101_gfl/universe_r101_gfl_mstrain_private_pseudo.py" 87 | 88 | # retrain DetectoRS 89 | run( 90 | f"python {KAGGLE_WORKING}/gwd/train.py " 91 | f"--config {PSEUDO_CONFIG_DETECTORS} " 92 | "--no-validate " 93 | f"--load-from {CHECKPOINT_DETECTORS} " 94 | f"--work-dir {KAGGLE_WORKING}/pseudo_detectors" 95 | ) 96 | detectors_predictions, test_dataset = test.main( 97 | config=PSEUDO_CONFIG_DETECTORS, 98 | checkpoint=f"{KAGGLE_WORKING}/pseudo_detectors/latest.pth", 99 | img_prefix=IMG_PREFIX, 100 | ann_file=ANN_FILE, 101 | flip=True, 102 | format_only=True, 103 | options=dict(output_path=f"pseudo_detectors_{SUBMISSION_PATH}"), 104 | ) 105 | test_dataset.pseudo_results( 106 | results=detectors_predictions, 107 | output_path=PSEUDO_PATH, 108 | pseudo_score_threshold=PSEUDO_SCORE_THRESHOLD, 109 | pseudo_confidence_threshold=PSEUDO_CONFIDENCE_THRESHOLD, 110 | ) 111 | 112 | # retrain UniverseNet 113 | run( 114 | f"python {KAGGLE_WORKING}/gwd/train.py " 115 | f"--config {PSEUDO_CONFIG_UNIVERSE} " 116 | "--no-validate " 117 | f"--load-from {CHECKPOINT_UNIVERSE} " 118 | f"--work-dir {KAGGLE_WORKING}/pseudo_universe" 119 | ) 120 | universe_predictions, test_dataset = test.main( 121 | config=PSEUDO_CONFIG_UNIVERSE, 122 | checkpoint=f"{KAGGLE_WORKING}/pseudo_universe/latest.pth", 123 | img_prefix=IMG_PREFIX, 124 | ann_file=ANN_FILE, 125 | flip=True, 126 | format_only=True, 127 | options=dict(output_path=f"pseudo_universe_{SUBMISSION_PATH}"), 128 | ) 129 | ensemble_predictions = wbf.main( 130 | [detectors_predictions, universe_predictions], weights=[0.65, 0.35], iou_thr=0.55, score_thr=0.45 131 | ) 132 | test_dataset.format_results(results=ensemble_predictions, output_path=SUBMISSION_PATH) 133 | -------------------------------------------------------------------------------- /scripts/colorization.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | for folder in train crops_fold0 6 | do 7 | python "$PROJECT_ROOT"/gwd/colorization/generate.py \ 8 | --img_pattern=/data/${folder}/*jpg \ 9 | --weights_path=/dumps/pix2pix_gen.pth \ 10 | --output_root=/data/colored_${folder} 11 | done 12 | -------------------------------------------------------------------------------- /scripts/kaggle2coco.sh: -------------------------------------------------------------------------------- 1 | set -e 2 | 3 | for fold in /data/folds_v2/[0-9]* 4 | do 5 | for mode in mosaic tile 6 | do 7 | for split in train val 8 | do 9 | echo "$fold"/"$mode"_"$split" 10 | python 
"$PROJECT_ROOT"/gwd/converters/kaggle2coco.py \ 11 | --annotation_path="$fold"/"$mode"_"$split".csv \ 12 | --output_path="$fold"/coco_"$mode"_"$split".json 13 | wait 14 | done 15 | done 16 | done 17 | -------------------------------------------------------------------------------- /scripts/seach_thresholds.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | CONFIG=configs/detectors/detectors_r50_ga_mstrain_stage2.py 6 | 7 | GPUS=$1 8 | CHECKPOINT=/dumps/work_dirs/detectors_r50_ga_mstrainv2_stage1/0/detectors_r50_ga_mstrainv2_stage1_epoch_24.pth 9 | 10 | for score_thr in 0.5 11 | do 12 | for iou_thr in 0.5 13 | do 14 | python -m torch.distributed.launch --nproc_per_node="$GPUS" --master_port="$RANDOM" \ 15 | "$PROJECT_ROOT"/gwd/test.py \ 16 | --config "$CONFIG" \ 17 | --checkpoint "$CHECKPOINT" \ 18 | --eval bbox \ 19 | --fold 0 \ 20 | --score_thr $score_thr \ 21 | --iou_thr $iou_thr \ 22 | --log-file /dumps/thr_logs.log 23 | done 24 | done 25 | -------------------------------------------------------------------------------- /scripts/stylize.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | for folder in train crops_fold0 SPIKE_images 6 | do 7 | for i in {1..4} 8 | do 9 | python "$PROJECT_ROOT"/gwd/stylize/run.py \ 10 | --content-dir=/data/${folder} \ 11 | --style-dir=/data/test \ 12 | --output-dir=/data/stylized_${folder}_v${i} 13 | done 14 | done 15 | -------------------------------------------------------------------------------- /scripts/test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | GPUS=8 6 | PORT=${PORT:-29500} 7 | 8 | CONFIG=configs/universe_r101_gfl_v2/universe_r101_gfl_mstrainv2_stage1.py 9 | CHECKPOINT=/dumps/work_dirs/universe_r101_gfl_mstrainv2_stage1/0/epoch_23.pth 10 | 11 | python -m torch.distributed.launch --nproc_per_node="$GPUS" --master_port="$PORT" \ 12 | "$PROJECT_ROOT"/gwd/test.py \ 13 | --config "$CONFIG" \ 14 | --checkpoint "$CHECKPOINT" \ 15 | --out /data/universe_r101_gfl_mstrainv2_stage1_1024_crops_predictions.pkl \ 16 | --eval bbox \ 17 | --fold 0 18 | -------------------------------------------------------------------------------- /scripts/test_crops.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | GPUS=8 6 | PORT=${PORT:-29500} 7 | 8 | CONFIG=configs/detectors/detectors_r50_ga_mstrain_stage2.py 9 | CHECKPOINT=/dumps/PL_detectors_r50.pth 10 | 11 | python -m torch.distributed.launch --nproc_per_node="$GPUS" --master_port="$PORT" \ 12 | "$PROJECT_ROOT"/gwd/test.py \ 13 | --config "$CONFIG" \ 14 | --checkpoint "$CHECKPOINT" \ 15 | --out /data/crops_fold0_predictions.pkl \ 16 | --ann-file /data/coco_crops_fold0.json \ 17 | --img-prefix /data/crops_fold0 18 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | CONFIG=$1 6 | GPUS=$2 7 | PORT=${PORT:-29500} 8 | 9 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$RANDOM \ 10 | "$PROJECT_ROOT"/gwd/train.py --config $CONFIG --launcher pytorch --fold 0 ${@:3} 11 | -------------------------------------------------------------------------------- /scripts/train_detectors.sh: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | bash scripts/train.sh configs/detectors/detectors_r50_ga_mstrain_stage0.py 4 5 | wait 6 | bash scripts/train.sh configs/detectors/detectors_r50_ga_mstrain_stage1.py 4 7 | wait 8 | bash scripts/train.sh configs/detectors/detectors_r50_ga_mstrain_stage2.py 4 9 | -------------------------------------------------------------------------------- /scripts/train_universenet.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | bash scripts/train.sh configs/universe_r101_gfl/universe_r101_gfl_mstrain_stage0.py 4 5 | wait 6 | bash scripts/train.sh configs/universe_r101_gfl/universe_r101_gfl_mstrain_stage1.py 4 7 | wait 8 | bash scripts/train.sh configs/universe_r101_gfl/universe_r101_gfl_mstrain_stage2.py 4 9 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E203 3 | max-line-length = 120 4 | exclude = .git,__pycache__,gwd/__init__.py 5 | 6 | [isort] 7 | multi_line_output=3 8 | include_trailing_comma=True 9 | force_grid_wrap=0 10 | use_parentheses=True 11 | line_length=120 12 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup(name="gwd", packages=["gwd"]) 4 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amirassov/kaggle-global-wheat-detection/b26295ea257f73089f1a067b70b4a7ee638f6b83/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_evaluation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from mmcv import Config 4 | 5 | from gwd.datasets.evaluation import calc_tpfpfn, kaggle_map 6 | from mmdet.datasets import build_dataset 7 | 8 | 9 | @pytest.fixture(scope="session") 10 | def dataset(): 11 | cfg = Config.fromfile("configs/_base_/datasets/wheat_detection.py") 12 | return build_dataset(cfg) 13 | 14 | 15 | @pytest.fixture(scope="session") 16 | def gt_bboxes(): 17 | return np.array([[954, 391, 1024, 481], [660, 220, 755, 322]]) 18 | 19 | 20 | @pytest.fixture(scope="session") 21 | def one_bboxes(): 22 | return np.array([[954, 391, 1024, 481, 0]]) 23 | 24 | 25 | @pytest.fixture(scope="session") 26 | def annotations(gt_bboxes): 27 | return [{"bboxes": gt_bboxes, "labels": np.zeros(len(gt_bboxes))}] 28 | 29 | 30 | def test_non_overlapping(gt_bboxes): 31 | det_bboxes = np.array([[0, 0, 10, 10, 0]]) 32 | tp, fp, fn = calc_tpfpfn(det_bboxes, gt_bboxes) 33 | assert tp == 0 34 | assert fp == 1 35 | assert fn == len(gt_bboxes) 36 | 37 | 38 | def test_one(gt_bboxes, one_bboxes): 39 | tp, fp, fn = calc_tpfpfn(one_bboxes, gt_bboxes) 40 | assert tp == 1 41 | assert fp == 0 42 | assert fn == len(gt_bboxes) - 1 43 | 44 | 45 | def test_empty_map(annotations): 46 | mean_ap, _ = kaggle_map([[np.array([]).reshape(-1, 5)]], annotations, iou_thrs=[0.5]) 47 | assert mean_ap == 0 48 | 49 | 50 | def test_one_map(annotations, one_bboxes): 51 | mean_ap, _ = kaggle_map([[one_bboxes]], annotations, 
iou_thrs=[0.5]) 52 | assert mean_ap == 1 / 2 53 | 54 | 55 | def test_map(): 56 | gt_bboxes = np.array( 57 | [ 58 | [954, 391, 1024, 481], 59 | [660, 220, 755, 322], 60 | [64, 209, 140, 266], 61 | [896, 99, 998, 168], 62 | [747, 460, 819, 537], 63 | [885, 163, 988, 232], 64 | [514, 399, 604, 496], 65 | [702, 794, 799, 893], 66 | [721, 624, 819, 732], 67 | [826, 512, 908, 606], 68 | [883, 944, 962, 1018], 69 | [247, 594, 370, 686], 70 | [673, 514, 768, 627], 71 | [829, 847, 931, 957], 72 | [94, 737, 186, 844], 73 | [588, 568, 663, 675], 74 | [158, 890, 261, 954], 75 | [744, 906, 819, 985], 76 | [826, 33, 898, 107], 77 | [601, 69, 668, 156], 78 | ] 79 | ) 80 | annotations = [{"bboxes": gt_bboxes, "labels": np.zeros(len(gt_bboxes))}] 81 | det_bboxes = np.array( 82 | [ 83 | [956.0, 409.0, 1024.0, 494.0, 0.997], 84 | [883.0, 945.0, 968.0, 1022.0, 0.996], 85 | [745.0, 468.0, 826.0, 555.0, 0.995], 86 | [658.0, 239.0, 761.0, 344.0, 0.994], 87 | [518.0, 419.0, 609.0, 519.0, 0.993], 88 | [711.0, 805.0, 803.0, 911.0, 0.992], 89 | [62.0, 213.0, 134.0, 277.0, 0.991], 90 | [884.0, 175.0, 993.0, 243.0, 0.99], 91 | [721.0, 626.0, 817.0, 730.0, 0.98], 92 | [878.0, 619.0, 999.0, 700.0, 0.97], 93 | [887.0, 107.0, 998.0, 178.0, 0.95], 94 | [827.0, 525.0, 915.0, 608.0, 0.94], 95 | [816.0, 868.0, 918.0, 954.0, 0.93], 96 | [166.0, 882.0, 244.0, 957.0, 0.92], 97 | [603.0, 563.0, 681.0, 660.0, 0.91], 98 | [744.0, 916.0, 812.0, 968.0, 0.89], 99 | [582.0, 86.0, 668.0, 158.0, 0.88], 100 | [79.0, 715.0, 170.0, 816.0, 0.86], 101 | [246.0, 586.0, 341.0, 666.0, 0.85], 102 | [181.0, 512.0, 274.0, 601.0, 0.84], 103 | [655.0, 527.0, 754.0, 617.0, 0.80], 104 | [568.0, 363.0, 629.0, 439.0, 0.79], 105 | [9.0, 717.0, 161.0, 827.0, 0.74], 106 | [576.0, 698.0, 651.0, 776.0, 0.60], 107 | [805.0, 974.0, 880.0, 1024.0, 0.59], 108 | [10.0, 15.0, 88.0, 79.0, 0.55], 109 | [826.0, 40.0, 895.0, 114.0, 0.53], 110 | [32.0, 983.0, 138.0, 1023.0, 0.50], 111 | ] 112 | ) 113 | assert abs(kaggle_map([[det_bboxes]], annotations, iou_thrs=[0.5])[0] - 0.6552) < 1e-3 114 | assert abs(kaggle_map([[det_bboxes]], annotations, iou_thrs=[0.75])[0] - 0.0909) < 1e-3 115 | assert abs(kaggle_map([[det_bboxes]], annotations, iou_thrs=[0.5, 0.55, 0.6, 0.65, 0.7, 0.75])[0] - 0.3663) < 1e-3 116 | --------------------------------------------------------------------------------