├── .gitignore ├── 00-installation.ipynb ├── 01a-coco.ipynb ├── 01b-open-images.ipynb ├── 02-format-annotations.ipynb ├── 03-train-small.ipynb ├── 04-train-big.ipynb ├── 05-kaggle-global-wheat-detection.ipynb ├── README.md ├── images ├── input.jpg ├── output_03.png ├── output_04.png ├── output_05.png └── sample.jpg ├── mscoco.py ├── open-images-v5 ├── .gitignore └── downloadOI.py ├── open-images-v6 ├── .gitignore ├── downloadOI.py ├── eval.sh ├── main.py └── train.sh └── wheat-detection ├── eval.sh ├── main.py └── train.sh /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | .ipynb_checkpoints 3 | kaggle.json -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Object Detection with Detectron2 2 | 3 | 4 | 5 | A series of notebooks to dive deep into popular datasets for object detection and learn how to train Detectron2 on a custom dataset. 
6 | 7 | - [Notebook 00](https://github.com/chriskhanhtran/object-detection-detectron2/blob/master/00-installation.ipynb): Install Detectron2 8 | - [Notebook 01a](https://github.com/chriskhanhtran/object-detection-detectron2/blob/master/01a-coco.ipynb): Load and read COCO dataset with `COCO PythonAPI` and `GluonCV` 9 | - [Notebook 01b](https://github.com/chriskhanhtran/object-detection-detectron2/blob/master/01b-open-images.ipynb): Load and read Open Images v5 10 | - [Notebook 02](https://github.com/chriskhanhtran/object-detection-detectron2/blob/master/02-format-annotations.ipynb): Format Open Images annotations for Detectron2 11 | - [Notebook 03](https://github.com/chriskhanhtran/object-detection-detectron2/blob/master/03-train-small.ipynb): Train Detectron2 on a small dataset of Camera and Tripod 12 | - [Notebook 04](https://github.com/chriskhanhtran/object-detection-detectron2/blob/master/04-train-big.ipynb): Put all steps together: download Open Images v6 and train Detectron2 on a large dataset of 11 musical instruments. 
13 | - [Notebook 05](https://github.com/chriskhanhtran/object-detection-detectron2/blob/master/05-kaggle-global-wheat-detection.ipynb): Apply Detectron2 to solve a real world challenge ([Kaggle Global Wheat Detection Competition](https://www.kaggle.com/c/global-wheat-detection)) 14 | 15 | ## Sample Outputs 16 | 17 | *From Notebook 03: Camera and Tripod* 18 | 19 | 20 | *From Notebook 04: 11 Musical Instruments* 21 | 22 | 23 | *From Notebook 05: Global Wheat Detection Challenge* 24 | -------------------------------------------------------------------------------- /images/input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chriskhanhtran/object-detection-detectron2/91d013a94dac70cda3b4619cd1025d572b718fdc/images/input.jpg -------------------------------------------------------------------------------- /images/output_03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chriskhanhtran/object-detection-detectron2/91d013a94dac70cda3b4619cd1025d572b718fdc/images/output_03.png -------------------------------------------------------------------------------- /images/output_04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chriskhanhtran/object-detection-detectron2/91d013a94dac70cda3b4619cd1025d572b718fdc/images/output_04.png -------------------------------------------------------------------------------- /images/output_05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chriskhanhtran/object-detection-detectron2/91d013a94dac70cda3b4619cd1025d572b718fdc/images/output_05.png -------------------------------------------------------------------------------- /images/sample.jpg: -------------------------------------------------------------------------------- 
# Location where the prepared dataset is linked (a symlink to the download dir).
_TARGET_DIR = os.path.expanduser('~/.mxnet/datasets/coco')


def parse_args(argv=None):
    """Parse the command-line options of the COCO preparation script.

    Args:
        argv (list[str] | None): Arguments to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: with ``download_dir`` (str), ``no_download``
        (bool) and ``overwrite`` (bool).
    """
    parser = argparse.ArgumentParser(
        description='Initialize MS COCO dataset.',
        epilog='Example: python mscoco.py --download-dir ~/mscoco',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--download-dir', type=str, default='~/mscoco/',
                        help='dataset directory on disk')
    parser.add_argument('--no-download', action='store_true',
                        help='disable automatic download if set')
    parser.add_argument('--overwrite', action='store_true',
                        help='overwrite downloaded files if set, in case they are corrupted')
    return parser.parse_args(argv)


def download_coco(path, overwrite=False):
    """Fetch the COCO 2017 archives into *path* and unzip them in place.

    Args:
        path (str): Directory that receives the archives and their contents.
            It is handed to gluoncv's ``makedirs`` before anything is fetched.
        overwrite (bool): Forwarded unchanged to gluoncv's ``download``.
    """
    # (url, sha1) pairs; the sha1 is passed to ``download`` as ``sha1_hash``.
    download_urls = [
        ('http://images.cocodataset.org/zips/train2017.zip',
         '10ad623668ab00c62c096f0ed636d6aff41faca5'),
        ('http://images.cocodataset.org/annotations/annotations_trainval2017.zip',
         '8551ee4bb5860311e79dace7e79cb91e432e78b3'),
        ('http://images.cocodataset.org/zips/val2017.zip',
         '4950dc9d00dbe1c933ee0170f5797584351d2a41'),
        # ('http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip',
        #  '46cdcf715b6b4f67e980b529534e79c2edffe084'),
        # test2017.zip, for those who want to attend the competition.
        # ('http://images.cocodataset.org/zips/test2017.zip',
        #  '4e443f8a2eca6b1dac8a6c57641b67dd40621a49'),
    ]
    makedirs(path)
    for url, checksum in download_urls:
        filename = download(url, path=path, overwrite=overwrite, sha1_hash=checksum)
        # extract
        with zipfile.ZipFile(filename) as zf:
            zf.extractall(path=path)


def symlink_dataset(path, target=_TARGET_DIR):
    """Point the symlink *target* at the dataset directory *path*.

    Args:
        path (str): Dataset directory. It is expanded and made absolute,
            because a relative symlink target would be resolved relative to
            the link's own directory and end up dangling.
        target (str): Location of the symlink to (re)create.

    Raises:
        OSError: if *target* is a real directory rather than a symlink
            (``os.remove`` refuses to delete directories).
    """
    source = os.path.abspath(os.path.expanduser(path))
    parent = os.path.dirname(target)
    if parent:
        os.makedirs(parent, exist_ok=True)
    # islink() also catches dangling links, for which isdir() is False and
    # os.symlink() would otherwise fail with FileExistsError.
    if os.path.islink(target) or os.path.isdir(target):
        os.remove(target)
    os.symlink(source, target)


if __name__ == '__main__':
    args = parse_args()
    path = os.path.expanduser(args.download_dir)
    required_dirs = [path] + [os.path.join(path, sub)
                              for sub in ('train2017', 'val2017', 'annotations')]
    if not all(os.path.isdir(d) for d in required_dirs):
        if args.no_download:
            raise ValueError(
                '{} is not a valid directory, make sure it is present.'
                ' Or do not pass "--no-download" so that it can be downloaded'.format(path))
        download_coco(path, overwrite=args.overwrite)

    # make symlink
    symlink_dataset(path)
    try_import_pycocotools()
# Column of each "Is*" flag in an annotation row, keyed by the CLI option that
# controls it (IsOccluded, IsTruncated, IsGroupOf, IsDepiction, IsInside).
FLAG_COLUMNS = {
    "occluded": 8,
    "truncated": 9,
    "groupOf": 10,
    "depiction": 11,
    "inside": 12,
}


def build_parser():
    """Return the argparse parser describing the downloader's options."""
    parser = argparse.ArgumentParser(description='Download Class specific images from OpenImagesV4')
    parser.add_argument("--mode", help="Dataset category - train, validation or test", required=True)
    parser.add_argument("--classes", help="Names of object classes to be downloaded", required=True)
    parser.add_argument("--nthreads", help="Number of threads to use", required=False, type=int,
                        default=multiprocessing.cpu_count() * 2)
    # One integer switch per flag; 1 (default) keeps flagged boxes, 0 drops them.
    for option in FLAG_COLUMNS:
        parser.add_argument("--" + option, help="Include " + option + " images",
                            required=False, type=int, default=1)
    return parser


def is_filtered_out(line_parts, args):
    """Tell whether an annotation row is excluded by the ``--<flag> 0`` options.

    Args:
        line_parts (list[str]): One annotation row split on commas.
        args (argparse.Namespace): Parsed options from ``build_parser``.

    Returns:
        bool: True when some flag is switched off (== 0) on the command line
        and set (> 0) in the row.
    """
    return any(
        getattr(args, option) == 0 and int(line_parts[column]) > 0
        for option, column in FLAG_COLUMNS.items()
    )


def main(argv=None):
    """Collect the annotations of the requested classes and download the images.

    Reads ``./class-descriptions-boxable.csv`` and
    ``./<mode>-annotations-bbox.csv`` from the working directory, recreates the
    ``<mode>/`` directory with one sub-folder per class, appends one
    ``<image>.txt`` line per kept box, then runs one ``aws s3 cp`` command per
    image through a worker pool.

    Args:
        argv (list[str] | None): Command-line arguments; ``None`` means
            ``sys.argv[1:]``.
    """
    import shutil  # local import: not part of the file's import header

    args = build_parser().parse_args(argv)
    run_mode = args.mode
    classes = args.classes.split(',')

    # Create a dictionary {class: class_id}
    with open('./class-descriptions-boxable.csv', mode='r') as infile:
        class_ids = {row[1]: row[0] for row in csv.reader(infile)}

    # Resolve every class BEFORE wiping the output directory, so a typo in
    # --classes does not destroy an earlier download.
    requested = []
    for class_name in classes:
        lookup = class_name.replace('_', ' ')
        if lookup not in class_ids:
            raise SystemExit("Unknown class '%s': not listed in class-descriptions-boxable.csv"
                             % class_name)
        requested.append((class_name, class_ids[lookup]))

    # Start from an empty <mode>/ directory (equivalent of `rm -rf` + `mkdir`).
    if os.path.isdir(run_mode) and not os.path.islink(run_mode):
        shutil.rmtree(run_mode)
    elif os.path.lexists(run_mode):
        os.remove(run_mode)
    os.makedirs(run_mode)

    commands = []
    cnt = 0

    for ind, (class_name, class_id) in enumerate(requested):
        print("Class " + str(ind) + " : " + class_name)

        # Create a folder for each class
        os.makedirs(run_mode + '/' + class_name, exist_ok=True)

        # Grep rows of our classes in the annotations-bbox.csv
        grep = subprocess.run(["grep", class_id, "./" + run_mode + "-annotations-bbox.csv"],
                              stdout=subprocess.PIPE)

        # For each line we grepped
        for line in grep.stdout.decode('utf-8').splitlines():
            line_parts = line.split(',')
            image_id = line_parts[0]

            if is_filtered_out(line_parts, args):
                print("Skipped %s" % image_id)
                continue

            cnt += 1

            # Command to download image
            # NOTE(review): the command is run through a shell by os.system;
            # run_mode/class_name come straight from the command line.
            commands.append(
                'aws s3 --no-sign-request --only-show-errors cp s3://open-images-dataset/'
                + run_mode + '/' + image_id + '.jpg '
                + run_mode + '/' + class_name + '/' + image_id + '.jpg')

            # One line per box: class name followed by row columns 4..7.
            with open('%s/%s/%s.txt' % (run_mode, class_name, image_id), 'a') as f:
                f.write(','.join([class_name, line_parts[4], line_parts[5],
                                  line_parts[6], line_parts[7]]) + '\n')

    print("Annotation Count : " + str(cnt))
    # An image with several boxes yields the same command several times.
    commands = list(dict.fromkeys(commands))
    print("Number of images to be downloaded : " + str(len(commands)))

    # `thread_pool` is multiprocessing.Pool: each command runs in a worker process.
    with thread_pool(args.nthreads) as pool:
        list(tqdm(pool.imap(os.system, commands), total=len(commands)))


# Guard required by multiprocessing: under the "spawn" start method the worker
# processes re-import this module and must not re-run the script.
if __name__ == "__main__":
    main()
# Column of each "Is*" flag in an annotation row, keyed by the CLI option that
# controls it (IsOccluded, IsTruncated, IsGroupOf, IsDepiction, IsInside).
FLAG_COLUMNS = {
    "occluded": 8,
    "truncated": 9,
    "groupOf": 10,
    "depiction": 11,
    "inside": 12,
}


def build_parser():
    """Return the argparse parser describing the downloader's options."""
    parser = argparse.ArgumentParser(description="Download Class specific images from OpenImagesV4")
    parser.add_argument("--mode", help="Dataset category - train, validation or test", required=True)
    parser.add_argument("--classes", help="Names of object classes to be downloaded", required=True)
    parser.add_argument("--nthreads", help="Number of threads to use", required=False, type=int,
                        default=multiprocessing.cpu_count() * 2)
    # One integer switch per flag; 1 (default) keeps flagged boxes, 0 drops them.
    for option in FLAG_COLUMNS:
        parser.add_argument("--" + option, help="Include " + option + " images",
                            required=False, type=int, default=1)
    return parser


def is_filtered_out(line_parts, args):
    """Tell whether an annotation row is excluded by the ``--<flag> 0`` options.

    Args:
        line_parts (list[str]): One annotation row split on commas.
        args (argparse.Namespace): Parsed options from ``build_parser``.

    Returns:
        bool: True when some flag is switched off (== 0) on the command line
        and set (> 0) in the row.
    """
    return any(
        getattr(args, option) == 0 and int(line_parts[column]) > 0
        for option, column in FLAG_COLUMNS.items()
    )


def main(argv=None):
    """Collect the annotations of the requested classes and download the images.

    Reads ``./class-descriptions-boxable.csv`` and
    ``./<mode>-annotations-bbox.csv`` from the working directory, recreates the
    ``<mode>/`` directory, then runs one ``aws s3 cp`` command per image
    through a worker pool, saving each image as ``<mode>/<image id>.jpg``.

    Args:
        argv (list[str] | None): Command-line arguments; ``None`` means
            ``sys.argv[1:]``.
    """
    import shutil  # local import: not part of the file's import header

    args = build_parser().parse_args(argv)
    run_mode = args.mode
    classes = args.classes.split(",")

    # Read `class-descriptions-boxable.csv`
    with open("./class-descriptions-boxable.csv", mode="r") as infile:
        # row[1] is ClassName, row[0] is ClassCode
        class_codes = {row[1]: row[0] for row in csv.reader(infile)}

    # Resolve every class BEFORE wiping the output directory, so a typo in
    # --classes does not destroy an earlier download.
    requested = []
    for class_name in classes:
        lookup = class_name.replace("_", " ")
        if lookup not in class_codes:
            raise SystemExit("Unknown class '%s': not listed in class-descriptions-boxable.csv"
                             % class_name)
        requested.append((class_name, class_codes[lookup]))

    # Start from an empty <mode>/ directory (equivalent of `rm -rf` + `mkdir`).
    if os.path.isdir(run_mode) and not os.path.islink(run_mode):
        shutil.rmtree(run_mode)
    elif os.path.lexists(run_mode):
        os.remove(run_mode)
    os.makedirs(run_mode)

    commands = []
    cnt = 0

    for ind, (class_name, class_code) in enumerate(requested):
        print("Class " + str(ind) + " : " + class_name)

        grep = subprocess.run(["grep", class_code, "./" + run_mode + "-annotations-bbox.csv"],
                              stdout=subprocess.PIPE)

        for line in grep.stdout.decode("utf-8").splitlines():
            line_parts = line.split(",")
            img_id = line_parts[0]
            save_path = os.path.join(run_mode, img_id + ".jpg")

            # If image exists, skip
            # NOTE(review): the directory was just recreated empty and nothing
            # is downloaded until after this loop, so this looks like it can
            # never trigger — confirm whether the wipe above is intended.
            if os.path.exists(save_path):
                continue

            # Download options: IsOccluded, IsTruncated, IsGroupOf, IsDepiction, IsInside
            if is_filtered_out(line_parts, args):
                print("Skipped %s" % img_id)
                continue

            # Command to download (run through a shell by os.system, hence the quotes)
            commands.append(
                f"aws s3 --no-sign-request --only-show-errors cp s3://open-images-dataset/'{run_mode}'/'{img_id}'.jpg {save_path}"
            )
            cnt += 1

    print("Annotation Count : " + str(cnt))
    # An image with several boxes yields the same command several times.
    commands = list(dict.fromkeys(commands))
    print("Number of images to be downloaded : " + str(len(commands)))

    # `thread_pool` is multiprocessing.Pool: each command runs in a worker process.
    with thread_pool(args.nthreads) as pool:
        list(tqdm(pool.imap(os.system, commands), total=len(commands)))


# Guard required by multiprocessing: under the "spawn" start method the worker
# processes re-import this module and must not re-run the script.
if __name__ == "__main__":
    main()
"retinanet_R_101_FPN_3x__iter-100000__lr-0.0005__warmup-1000" -------------------------------------------------------------------------------- /open-images-v6/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from tqdm import tqdm 3 | import os, json, shutil, cv2 4 | import numpy as np 5 | import pandas as pd 6 | 7 | # detectron2 8 | import detectron2 9 | from detectron2.utils.logger import setup_logger 10 | from detectron2 import model_zoo 11 | from detectron2.config import get_cfg 12 | from detectron2.structures import BoxMode 13 | from detectron2.data import DatasetCatalog, MetadataCatalog 14 | from detectron2.engine import DefaultTrainer 15 | from detectron2.evaluation import COCOEvaluator 16 | 17 | models = [ 18 | "faster_rcnn_R_50_FPN_3x", 19 | "faster_rcnn_R_101_FPN_3x", 20 | "faster_rcnn_X_101_32x8d_FPN_3x", # best 21 | "retinanet_R_50_FPN_3x", 22 | "retinanet_R_101_FPN_3x", 23 | ] 24 | 25 | target_classes = [ 26 | 'Accordion', 27 | 'Cello', 28 | 'Drum', 29 | 'French horn', 30 | 'Guitar', 31 | 'Musical keyboard', 32 | 'Piano', 33 | 'Saxophone', 34 | 'Trombone', 35 | 'Trumpet', 36 | 'Violin' 37 | ] 38 | 39 | 40 | def get_args_parser(): 41 | parser = argparse.ArgumentParser('Set up Detectron2', add_help=False) 42 | parser.add_argument('--model', default=None, type=str) 43 | parser.add_argument('--model_dir', default=None, type=str) 44 | parser.add_argument('--train', action='store_true') 45 | parser.add_argument('--eval', action='store_true') 46 | parser.add_argument('--train_annot_fp', default=None, type=str) 47 | parser.add_argument('--val_annot_fp', default=None, type=str) 48 | parser.add_argument('--max_iter', default=10000, type=int) 49 | parser.add_argument('--lr', default=3e-4, type=float) 50 | parser.add_argument('--ims_per_batch', default=4, type=int) 51 | parser.add_argument('--warmup_iters', default=1000, type=int) 52 | parser.add_argument('--gamma', default=0.5, type=float) 53 | 
parser.add_argument('--lr_decay_steps', default=[100000,], type=int, nargs='*') 54 | return parser 55 | 56 | 57 | def denormalize_bboxes(bboxes, height, width): 58 | """Denormalize bounding boxes in format of (xmin, ymin, xmax, ymax).""" 59 | bboxes[:, [0, 2]] = bboxes[:, [0, 2]] * width 60 | bboxes[:, [1, 3]] = bboxes[:, [1, 3]] * height 61 | return np.round(bboxes) 62 | 63 | 64 | def get_detectron_dicts(annot_fp): 65 | """ 66 | Create Detectron2's standard dataset from an annotation file. 67 | 68 | Args: 69 | annot_df (pd.DataFrame): annotation dataframe. 70 | Return: 71 | dataset_dicts (list[dict]): List of annotation dictionaries for Detectron2. 72 | """ 73 | # Load annotatations 74 | annot_df = pd.read_csv(annot_fp) 75 | 76 | # Get image ids 77 | img_ids = annot_df["ImageID"].unique().tolist() 78 | 79 | dataset_dicts = [] 80 | for img_id in tqdm(img_ids): 81 | file_name = f'images/{img_id}.jpg' 82 | if not os.path.exists(file_name): continue 83 | height, width = cv2.imread(file_name).shape[:2] 84 | 85 | record = {} 86 | record['file_name'] = file_name 87 | record['image_id'] = img_id 88 | record['height'] = height 89 | record['width'] = width 90 | 91 | # Extract bboxes from annotation file 92 | bboxes = annot_df[['XMin', 'YMin', 'XMax', 'YMax']][annot_df['ImageID'] == img_id].values 93 | bboxes = denormalize_bboxes(bboxes, height, width) 94 | class_ids = annot_df[['ClassID']][annot_df['ImageID'] == img_id].values 95 | 96 | annots = [] 97 | for i, bbox in enumerate(bboxes.tolist()): 98 | annot = { 99 | "bbox": bbox, 100 | "bbox_mode": BoxMode.XYXY_ABS, 101 | "category_id": int(class_ids[i]), 102 | } 103 | annots.append(annot) 104 | 105 | record["annotations"] = annots 106 | dataset_dicts.append(record) 107 | return dataset_dicts 108 | 109 | 110 | def main(args): 111 | # Register datasets 112 | print("Registering music_train") 113 | DatasetCatalog.register("music_train", lambda path=args.train_annot_fp: get_detectron_dicts(path)) 114 | 
MetadataCatalog.get("music_train").set(thing_classes=target_classes) 115 | 116 | print("Registering music_val") 117 | DatasetCatalog.register("music_val", lambda path=args.val_annot_fp: get_detectron_dicts(path)) 118 | MetadataCatalog.get("music_val").set(thing_classes=target_classes) 119 | 120 | # Set up configurations 121 | cfg = get_cfg() 122 | if not args.model_dir: 123 | cfg.merge_from_file(model_zoo.get_config_file(f"COCO-Detection/{args.model}.yaml")) 124 | cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(f"COCO-Detection/{args.model}.yaml") 125 | cfg.DATASETS.TRAIN = ("music_train",) 126 | cfg.DATASETS.TEST = ("music_val",) 127 | 128 | cfg.SOLVER.IMS_PER_BATCH = args.ims_per_batch 129 | cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512 130 | cfg.SOLVER.BASE_LR = args.lr 131 | cfg.SOLVER.MAX_ITER = args.max_iter 132 | cfg.SOLVER.WARMUP_ITERS = args.warmup_iters 133 | cfg.SOLVER.GAMMA = args.gamma 134 | cfg.SOLVER.STEPS = args.lr_decay_steps 135 | 136 | cfg.DATALOADER.NUM_WORKERS = 6 137 | cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(target_classes) 138 | cfg.MODEL.RETINANET.NUM_CLASSES = len(target_classes) 139 | cfg.OUTPUT_DIR = f"{args.model}__iter-{args.max_iter}__lr-{args.lr}" 140 | if os.path.exists(cfg.OUTPUT_DIR): shutil.rmtree(cfg.OUTPUT_DIR) 141 | os.makedirs(cfg.OUTPUT_DIR, exist_ok=True) 142 | 143 | # Save config 144 | with open(os.path.join(cfg.OUTPUT_DIR, "config.yaml"), "w") as f: 145 | f.write(cfg.dump()) 146 | else: 147 | print("Loading model from ", args.model_dir) 148 | cfg.merge_from_file(os.path.join(args.model_dir, "config.yaml")) 149 | cfg.MODEL.WEIGHTS = os.path.join(args.model_dir, "model_final.pth") 150 | cfg.OUTPUT_DIR = args.model_dir 151 | cfg.DATASETS.TRAIN = ("music_train",) 152 | cfg.DATASETS.TEST = ("music_val",) 153 | 154 | # Set up trainer 155 | setup_logger(output=os.path.join(cfg.OUTPUT_DIR, "terminal_output.log")) 156 | trainer = DefaultTrainer(cfg) 157 | trainer.resume_or_load(resume=False) 158 | 159 | # Train 160 | if args.train: 
161 | trainer.train() 162 | 163 | # Evaluate 164 | if args.eval: 165 | evaluator = COCOEvaluator("music_val", cfg, False, output_dir=cfg.OUTPUT_DIR) 166 | eval_results = trainer.test(cfg=cfg, model=trainer.model, evaluators=evaluator) 167 | with open(os.path.join(cfg.OUTPUT_DIR, "eval_results.json"), "w") as f: 168 | json.dump(eval_results, f) 169 | 170 | if __name__ == '__main__': 171 | parser = get_args_parser() 172 | args = parser.parse_args() 173 | main(args) 174 | -------------------------------------------------------------------------------- /open-images-v6/train.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --train \ 3 | --train_annot_fp "./small-train-annotations-bbox-target.csv" \ 4 | --eval \ 5 | --val_annot_fp "./validation-annotations-bbox-target.csv" \ 6 | --model "faster_rcnn_X_101_32x8d_FPN_3x" \ 7 | --max_iter 300 \ 8 | --lr 5e-4 \ 9 | --gamma 0.5 \ 10 | --lr_decay_steps 250 300 \ -------------------------------------------------------------------------------- /wheat-detection/eval.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --eval \ 3 | --val_annot_fp "./annot_val.csv" \ 4 | --model_dir retinanet_R_101_FPN_3x__iter-10000__lr-0.0005 \ -------------------------------------------------------------------------------- /wheat-detection/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from tqdm import tqdm 3 | import os 4 | import numpy as np 5 | import pandas as pd 6 | import ast 7 | import json 8 | import shutil 9 | 10 | # import some common detectron2 utilities 11 | import detectron2 12 | from detectron2.utils.logger import setup_logger 13 | from detectron2 import model_zoo 14 | from detectron2.config import get_cfg 15 | from detectron2.structures import BoxMode 16 | from detectron2.data import DatasetCatalog, MetadataCatalog 17 | from detectron2.engine import 
# Model names accepted by --model; each is looked up as
# "COCO-Detection/<name>.yaml" in the Detectron2 model zoo.
models = [
    "faster_rcnn_R_50_FPN_3x",
    "faster_rcnn_R_101_FPN_3x",
    "faster_rcnn_X_101_32x8d_FPN_3x",
    "retinanet_R_50_FPN_3x",
    "retinanet_R_101_FPN_3x",
]


def get_args_parser():
    """Build the command-line parser (created with ``add_help=False``).

    Returns:
        argparse.ArgumentParser: parser for the model choice, the train/eval
        switches, the annotation file paths and the solver hyper-parameters.
    """
    parser = argparse.ArgumentParser('Set up Detectron2', add_help=False)
    parser.add_argument('--model', default=None, type=str)
    parser.add_argument('--model_dir', default=None, type=str)
    parser.add_argument('--train', action='store_true')
    parser.add_argument('--eval', action='store_true')
    parser.add_argument('--train_annot_fp', default="./annot_train.csv", type=str)
    parser.add_argument('--val_annot_fp', default="./annot_val.csv", type=str)
    parser.add_argument('--max_iter', default=10000, type=int)
    parser.add_argument('--lr', default=3e-4, type=float)
    parser.add_argument('--ims_per_batch', default=4, type=int)
    parser.add_argument('--warmup_iters', default=1000, type=int)
    parser.add_argument('--gamma', default=0.5, type=float)
    parser.add_argument('--lr_decay_steps', default=[100000,], type=int, nargs='*')
    return parser


def get_detectron_dicts(annot_path, image_dir='train', height=1024, width=1024):
    """
    Create a Detectron2's standard dataset dicts from an annotation file.

    Args:
        annot_path (str): path to the annotation file, a CSV with an
            ``image_id`` column and a ``bbox`` column holding the string form
            of a list.
        image_dir (str): directory used to build each ``file_name``.
        height (int): height recorded for every image.
        width (int): width recorded for every image.
    Return:
        dataset_dicts (list[dict]): List of annotation dictionaries for Detectron2.
    """
    print("Loading annotation from ", annot_path)
    # Load annotation DataFrame
    annot_df = pd.read_csv(annot_path)

    # Convert `bbox` column from string to list
    annot_df['bbox'] = annot_df['bbox'].apply(ast.literal_eval)

    # One group per image, visited in order of first appearance. Grouping once
    # replaces a full-table boolean filter per image.
    grouped = annot_df.groupby("image_id", sort=False)

    dataset_dicts = []
    for img_id, rows in tqdm(grouped, total=grouped.ngroups):
        annots = [
            {
                "bbox": bbox,
                "bbox_mode": BoxMode.XYWH_ABS,
                "category_id": 0,  # single class, see thing_classes in main()
            }
            for bbox in rows['bbox']
        ]
        dataset_dicts.append({
            'file_name': f'{image_dir}/{img_id}.jpg',
            'image_id': img_id,
            'height': height,
            'width': width,
            'annotations': annots,
        })
    return dataset_dicts


def main(args):
    """Register the datasets, build the config, then train and/or evaluate.

    Args:
        args (argparse.Namespace): options produced by ``get_args_parser``.

    Raises:
        ValueError: if neither ``--model`` nor ``--model_dir`` is given.
    """
    if not args.model_dir and not args.model:
        raise ValueError("Pass --model (start from the model zoo) or --model_dir (load a saved run).")

    # Register datasets; the lambda default binds each path at definition time.
    for name, annot_fp in (("wheat_detection_train", args.train_annot_fp),
                           ("wheat_detection_val", args.val_annot_fp)):
        print("Registering " + name)
        DatasetCatalog.register(name, lambda path=annot_fp: get_detectron_dicts(path))
        MetadataCatalog.get(name).set(thing_classes=["Wheat"])

    # Set up configurations
    cfg = get_cfg()
    if not args.model_dir:
        cfg.merge_from_file(model_zoo.get_config_file(f"COCO-Detection/{args.model}.yaml"))
        cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(f"COCO-Detection/{args.model}.yaml")
        cfg.DATASETS.TRAIN = ("wheat_detection_train",)
        cfg.DATASETS.TEST = ("wheat_detection_val",)

        cfg.SOLVER.IMS_PER_BATCH = args.ims_per_batch
        cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512
        cfg.SOLVER.BASE_LR = args.lr
        cfg.SOLVER.MAX_ITER = args.max_iter
        cfg.SOLVER.WARMUP_ITERS = args.warmup_iters
        cfg.SOLVER.GAMMA = args.gamma
        cfg.SOLVER.STEPS = args.lr_decay_steps

        cfg.DATALOADER.NUM_WORKERS = 6
        # Both head types get the class count; which one applies depends on the model.
        cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
        cfg.MODEL.RETINANET.NUM_CLASSES = 1
        cfg.OUTPUT_DIR = f"{args.model}__iter-{args.max_iter}__lr-{args.lr}"
        # A previous run with the same model/iterations/lr is discarded.
        if os.path.exists(cfg.OUTPUT_DIR):
            shutil.rmtree(cfg.OUTPUT_DIR)
        os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

        # Save config
        with open(os.path.join(cfg.OUTPUT_DIR, "config.yaml"), "w") as f:
            f.write(cfg.dump())
    else:
        print("Loading model from ", args.model_dir)
        cfg.merge_from_file(os.path.join(args.model_dir, "config.yaml"))
        cfg.MODEL.WEIGHTS = os.path.join(args.model_dir, "model_final.pth")
        cfg.OUTPUT_DIR = args.model_dir

    # Train
    setup_logger(output=os.path.join(cfg.OUTPUT_DIR, "terminal_output.log"))
    trainer = DefaultTrainer(cfg)
    trainer.resume_or_load(resume=False)

    if args.train:
        trainer.train()

    # Evaluate
    if args.eval:
        evaluator = COCOEvaluator("wheat_detection_val", cfg, False, output_dir=cfg.OUTPUT_DIR)
        eval_results = trainer.test(cfg=cfg, model=trainer.model, evaluators=evaluator)
        with open(os.path.join(cfg.OUTPUT_DIR, "eval_results.json"), "w") as f:
            json.dump(eval_results, f)


if __name__ == '__main__':
    parser = get_args_parser()
    args = parser.parse_args()
    main(args)
| --warmup_iters 100 \ 10 | --gamma 0.5 \ 11 | --lr_decay_steps 200 250 --------------------------------------------------------------------------------