├── .gitignore
├── 00-installation.ipynb
├── 01a-coco.ipynb
├── 01b-open-images.ipynb
├── 02-format-annotations.ipynb
├── 03-train-small.ipynb
├── 04-train-big.ipynb
├── 05-kaggle-global-wheat-detection.ipynb
├── README.md
├── images
│   ├── input.jpg
│   ├── output_03.png
│   ├── output_04.png
│   ├── output_05.png
│   └── sample.jpg
├── mscoco.py
├── open-images-v5
│   ├── .gitignore
│   └── downloadOI.py
├── open-images-v6
│   ├── .gitignore
│   ├── downloadOI.py
│   ├── eval.sh
│   ├── main.py
│   └── train.sh
└── wheat-detection
    ├── eval.sh
    ├── main.py
    └── train.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode
2 | .ipynb_checkpoints
3 | kaggle.json
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Object Detection with Detectron2
2 |
3 |
4 |
5 | A series of notebooks that dive deep into popular object detection datasets and show how to train Detectron2 on a custom dataset.
6 |
7 | - [Notebook 00](https://github.com/chriskhanhtran/object-detection-detectron2/blob/master/00-installation.ipynb): Install Detectron2
8 | - [Notebook 01a](https://github.com/chriskhanhtran/object-detection-detectron2/blob/master/01a-coco.ipynb): Load and read COCO dataset with `COCO PythonAPI` and `GluonCV`
9 | - [Notebook 01b](https://github.com/chriskhanhtran/object-detection-detectron2/blob/master/01b-open-images.ipynb): Load and read Open Images v5
10 | - [Notebook 02](https://github.com/chriskhanhtran/object-detection-detectron2/blob/master/02-format-annotations.ipynb): Format Open Images annotations for Detectron2
11 | - [Notebook 03](https://github.com/chriskhanhtran/object-detection-detectron2/blob/master/03-train-small.ipynb): Train Detectron2 on a small dataset of Camera and Tripod
12 | - [Notebook 04](https://github.com/chriskhanhtran/object-detection-detectron2/blob/master/04-train-big.ipynb): Put all steps together: download Open Images v6 and train Detectron2 on a large dataset of 11 musical instruments.
13 | - [Notebook 05](https://github.com/chriskhanhtran/object-detection-detectron2/blob/master/05-kaggle-global-wheat-detection.ipynb): Apply Detectron2 to solve a real world challenge ([Kaggle Global Wheat Detection Competition](https://www.kaggle.com/c/global-wheat-detection))
14 |
15 | ## Sample Outputs
16 |
17 | *From Notebook 03: Camera and Tripod*
18 |
19 | ![Sample output from Notebook 03](images/output_03.png)
20 |
21 | *From Notebook 04: 11 Musical Instruments*
22 |
23 | ![Sample output from Notebook 04](images/output_04.png)
24 |
25 | *From Notebook 05: Global Wheat Detection Challenge*
26 |
27 | ![Sample output from Notebook 05](images/output_05.png)
--------------------------------------------------------------------------------
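For orientation, here is the core Detectron2 workflow that the notebooks and the `main.py` scripts below implement, condensed into a minimal sketch; the dataset name, class list, and the `get_detectron_dicts` stub are placeholders rather than the repo's exact helpers:

```python
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.engine import DefaultTrainer

def get_detectron_dicts():
    # Return a list of records: file_name, image_id, height, width, and an
    # "annotations" list of {bbox, bbox_mode, category_id} dicts.
    ...

# Register the custom dataset and its class names
DatasetCatalog.register("my_train", get_detectron_dicts)
MetadataCatalog.get("my_train").set(thing_classes=["Camera", "Tripod"])

# Start from a model-zoo config and its pretrained weights
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")
cfg.DATASETS.TRAIN = ("my_train",)
cfg.DATASETS.TEST = ()
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2  # must match thing_classes

# Fine-tune
trainer = DefaultTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()
```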
/images/input.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chriskhanhtran/object-detection-detectron2/91d013a94dac70cda3b4619cd1025d572b718fdc/images/input.jpg
--------------------------------------------------------------------------------
/images/output_03.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chriskhanhtran/object-detection-detectron2/91d013a94dac70cda3b4619cd1025d572b718fdc/images/output_03.png
--------------------------------------------------------------------------------
/images/output_04.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chriskhanhtran/object-detection-detectron2/91d013a94dac70cda3b4619cd1025d572b718fdc/images/output_04.png
--------------------------------------------------------------------------------
/images/output_05.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chriskhanhtran/object-detection-detectron2/91d013a94dac70cda3b4619cd1025d572b718fdc/images/output_05.png
--------------------------------------------------------------------------------
/images/sample.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chriskhanhtran/object-detection-detectron2/91d013a94dac70cda3b4619cd1025d572b718fdc/images/sample.jpg
--------------------------------------------------------------------------------
/mscoco.py:
--------------------------------------------------------------------------------
1 | """Prepare MS COCO datasets"""
2 | import os
3 | import shutil
4 | import argparse
5 | import zipfile
6 | from gluoncv.utils import download, makedirs
7 | from gluoncv.data.mscoco.utils import try_import_pycocotools
8 |
9 | _TARGET_DIR = os.path.expanduser('~/.mxnet/datasets/coco')
10 |
11 | def parse_args():
12 | parser = argparse.ArgumentParser(
13 | description='Initialize MS COCO dataset.',
14 | epilog='Example: python mscoco.py --download-dir ~/mscoco',
15 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
16 | parser.add_argument('--download-dir', type=str, default='~/mscoco/', help='dataset directory on disk')
17 | parser.add_argument('--no-download', action='store_true', help='disable automatic download if set')
18 | parser.add_argument('--overwrite', action='store_true', help='overwrite downloaded files if set, in case they are corrupted')
19 | args = parser.parse_args()
20 | return args
21 |
22 | def download_coco(path, overwrite=False):
23 | _DOWNLOAD_URLS = [
24 | ('http://images.cocodataset.org/zips/train2017.zip',
25 | '10ad623668ab00c62c096f0ed636d6aff41faca5'),
26 | ('http://images.cocodataset.org/annotations/annotations_trainval2017.zip',
27 | '8551ee4bb5860311e79dace7e79cb91e432e78b3'),
28 | ('http://images.cocodataset.org/zips/val2017.zip',
29 | '4950dc9d00dbe1c933ee0170f5797584351d2a41'),
30 | # ('http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip',
31 | # '46cdcf715b6b4f67e980b529534e79c2edffe084'),
32 | # test2017.zip, for those who want to attend the competition.
33 | # ('http://images.cocodataset.org/zips/test2017.zip',
34 | # '4e443f8a2eca6b1dac8a6c57641b67dd40621a49'),
35 | ]
36 | makedirs(path)
37 | for url, checksum in _DOWNLOAD_URLS:
38 | filename = download(url, path=path, overwrite=overwrite, sha1_hash=checksum)
39 | # extract
40 | with zipfile.ZipFile(filename) as zf:
41 | zf.extractall(path=path)
42 |
43 | if __name__ == '__main__':
44 | args = parse_args()
45 | path = os.path.expanduser(args.download_dir)
46 | if not os.path.isdir(path) or not os.path.isdir(os.path.join(path, 'train2017')) \
47 | or not os.path.isdir(os.path.join(path, 'val2017')) \
48 | or not os.path.isdir(os.path.join(path, 'annotations')):
49 | if args.no_download:
50 |             raise ValueError(('{} is not a valid directory, make sure it is present.'
51 |                               ' Or drop "--no-download" to let this script download it.'.format(path)))
52 | else:
53 | download_coco(path, overwrite=args.overwrite)
54 |
55 | # make symlink
56 | makedirs(os.path.expanduser('~/.mxnet/datasets'))
57 |     if os.path.islink(_TARGET_DIR):  # os.remove() can replace a stale symlink but would fail on a real directory
58 |         os.remove(_TARGET_DIR)
59 | os.symlink(path, _TARGET_DIR)
60 | try_import_pycocotools()
61 |
--------------------------------------------------------------------------------
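`gluoncv.utils.download` verifies each archive against the SHA-1 checksum paired with its URL in `_DOWNLOAD_URLS` before extraction. A minimal sketch of that check in plain Python (the file name and hash below are the `train2017.zip` entry from the script):

```python
import hashlib

def sha1_matches(filename, expected_sha1):
    """Hash the file in 1 MB chunks and compare against the expected digest."""
    sha1 = hashlib.sha1()
    with open(filename, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            sha1.update(chunk)
    return sha1.hexdigest() == expected_sha1

# e.g. sha1_matches("train2017.zip", "10ad623668ab00c62c096f0ed636d6aff41faca5")
```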
/open-images-v5/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
3 | !downloadOI.py
--------------------------------------------------------------------------------
/open-images-v5/downloadOI.py:
--------------------------------------------------------------------------------
1 | #Author : Sunita Nayak, Big Vision LLC
2 |
3 | #### Usage example: python3 downloadOI.py --classes 'Ice_cream,Cookie' --mode train
4 |
5 | import argparse
6 | import csv
7 | import subprocess
8 | import os
9 | from tqdm import tqdm
10 | import multiprocessing
11 | from multiprocessing import Pool as thread_pool
12 |
13 | cpu_count = multiprocessing.cpu_count()
14 |
15 | parser = argparse.ArgumentParser(description='Download class-specific images from Open Images')
16 | parser.add_argument("--mode", help="Dataset category - train, validation or test", required=True)
17 | parser.add_argument("--classes", help="Names of object classes to be downloaded", required=True)
18 | parser.add_argument("--nthreads", help="Number of threads to use", required=False, type=int, default=cpu_count*2)
19 | parser.add_argument("--occluded", help="Include occluded images", required=False, type=int, default=1)
20 | parser.add_argument("--truncated", help="Include truncated images", required=False, type=int, default=1)
21 | parser.add_argument("--groupOf", help="Include groupOf images", required=False, type=int, default=1)
22 | parser.add_argument("--depiction", help="Include depiction images", required=False, type=int, default=1)
23 | parser.add_argument("--inside", help="Include inside images", required=False, type=int, default=1)
24 |
25 | args = parser.parse_args()
26 |
27 | run_mode = args.mode
28 |
29 | threads = args.nthreads
30 |
31 | classes = []
32 | for class_name in args.classes.split(','):
33 | classes.append(class_name)
34 |
35 | # Create a dictionary {class:class_id}
36 | with open('./class-descriptions-boxable.csv', mode='r') as infile:
37 | reader = csv.reader(infile)
38 | dict_list = {rows[1]:rows[0] for rows in reader}
39 |
40 | subprocess.run(['rm', '-rf', run_mode])
41 | subprocess.run([ 'mkdir', run_mode])
42 |
43 | pool = thread_pool(threads)
44 | commands = []
45 | cnt = 0
46 |
47 | for ind in range(0, len(classes)):
48 |
49 | class_name = classes[ind]
50 | print("Class "+str(ind) + " : " + class_name)
51 |
52 | # Create a folder for each class
53 | subprocess.run([ 'mkdir', run_mode+'/'+class_name])
54 |
55 |     # Grep this class's rows from the annotations-bbox.csv file
56 | command = "grep "+dict_list[class_name.replace('_', ' ')] + " ./" + run_mode + "-annotations-bbox.csv"
57 | class_annotations = subprocess.run(command.split(), stdout=subprocess.PIPE).stdout.decode('utf-8')
58 | class_annotations = class_annotations.splitlines()
59 |
60 | # For each line we grepped
61 | for line in class_annotations:
62 |
63 | line_parts = line.split(',')
64 |
65 | #IsOccluded,IsTruncated,IsGroupOf,IsDepiction,IsInside
66 |         if (args.occluded==0 and int(line_parts[8])>0):
67 |             print("Skipped %s" % line_parts[0])
68 |             continue
69 |         if (args.truncated==0 and int(line_parts[9])>0):
70 |             print("Skipped %s" % line_parts[0])
71 |             continue
72 |         if (args.groupOf==0 and int(line_parts[10])>0):
73 |             print("Skipped %s" % line_parts[0])
74 |             continue
75 |         if (args.depiction==0 and int(line_parts[11])>0):
76 |             print("Skipped %s" % line_parts[0])
77 |             continue
78 |         if (args.inside==0 and int(line_parts[12])>0):
79 |             print("Skipped %s" % line_parts[0])
80 |             continue
81 |
82 | cnt = cnt + 1
83 |
84 | # Command to download image
85 | command = 'aws s3 --no-sign-request --only-show-errors cp s3://open-images-dataset/'+run_mode+'/'+line_parts[0]+'.jpg '+ run_mode+'/'+class_name+'/'+line_parts[0]+'.jpg'
86 | commands.append(command)
87 |
88 | with open('%s/%s/%s.txt'%(run_mode,class_name,line_parts[0]),'a') as f:
89 | f.write(','.join([class_name, line_parts[4], line_parts[5], line_parts[6], line_parts[7]])+'\n')
90 |
91 | print("Annotation Count : "+str(cnt))
92 | commands = list(set(commands))
93 | print("Number of images to be downloaded : "+str(len(commands)))
94 |
95 | list(tqdm(pool.imap(os.system, commands), total = len(commands) ))
96 |
97 | pool.close()
98 | pool.join()
99 |
--------------------------------------------------------------------------------
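Both versions of `downloadOI.py` index `line_parts` by position. For reference, the positional indices follow the column order of the Open Images `*-annotations-bbox.csv` files; a small helper (not in the script) that names the fields:

```python
# Column order of the Open Images *-annotations-bbox.csv files
COLUMNS = [
    "ImageID", "Source", "LabelName", "Confidence",
    "XMin", "XMax", "YMin", "YMax",  # box corners, normalized to [0, 1]
    "IsOccluded", "IsTruncated", "IsGroupOf", "IsDepiction", "IsInside",
]

def parse_row(line):
    """Turn one CSV line into a dict, so row["XMin"] replaces line_parts[4]."""
    return dict(zip(COLUMNS, line.split(",")))
```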
/open-images-v6/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !*.py
3 | !*.sh
--------------------------------------------------------------------------------
/open-images-v6/downloadOI.py:
--------------------------------------------------------------------------------
1 | # Author : Sunita Nayak, Big Vision LLC
2 |
3 | #### Usage example: python3 downloadOI.py --classes 'Ice_cream,Cookie' --mode train
4 |
5 | import argparse
6 | import csv
7 | import subprocess
8 | import os
9 | from tqdm import tqdm
10 | import multiprocessing
11 | from multiprocessing import Pool as thread_pool
12 |
13 | cpu_count = multiprocessing.cpu_count()
14 |
15 | parser = argparse.ArgumentParser(description="Download class-specific images from Open Images")
16 | parser.add_argument("--mode", help="Dataset category - train, validation or test", required=True)
17 | parser.add_argument("--classes", help="Names of object classes to be downloaded", required=True)
18 | parser.add_argument("--nthreads", help="Number of threads to use", required=False, type=int, default=cpu_count * 2)
19 | parser.add_argument("--occluded", help="Include occluded images", required=False, type=int, default=1)
20 | parser.add_argument("--truncated", help="Include truncated images", required=False, type=int, default=1)
21 | parser.add_argument("--groupOf", help="Include groupOf images", required=False, type=int, default=1)
22 | parser.add_argument("--depiction", help="Include depiction images", required=False, type=int, default=1)
23 | parser.add_argument("--inside", help="Include inside images", required=False, type=int, default=1)
24 |
25 | args = parser.parse_args()
26 |
27 | run_mode = args.mode
28 |
29 | threads = args.nthreads
30 |
31 | classes = []
32 | for class_name in args.classes.split(","):
33 | classes.append(class_name)
34 |
35 | # Read `class-descriptions-boxable.csv`
36 | with open("./class-descriptions-boxable.csv", mode="r") as infile:
37 | reader = csv.reader(infile)
38 | dict_list = {rows[1]: rows[0] for rows in reader} # rows[1] is ClassName, rows[0] is ClassCode
39 |
40 | # Keep previously downloaded images in place so the "If image exists, skip" check below can resume a run
41 | os.makedirs(run_mode, exist_ok=True)
42 |
43 | pool = thread_pool(threads)
44 | commands = []
45 | cnt = 0
46 |
47 | for ind in range(0, len(classes)):
48 | class_name = classes[ind]
49 | print("Class " + str(ind) + " : " + class_name)
50 |
51 | command = "grep " + dict_list[class_name.replace("_", " ")] + " ./" + run_mode + "-annotations-bbox.csv"
52 | class_annotations = subprocess.run(command.split(), stdout=subprocess.PIPE).stdout.decode("utf-8")
53 | class_annotations = class_annotations.splitlines()
54 |
55 | for line in class_annotations:
56 | line_parts = line.split(",")
57 | img_id = line_parts[0]
58 | save_path = os.path.join(run_mode, img_id + ".jpg")
59 |
60 | # If image exists, skip
61 | if os.path.exists(save_path):
62 | continue
63 |
64 | # Download options: IsOccluded, IsTruncated, IsGroupOf, IsDepiction, IsInside
65 |         if args.occluded == 0 and int(line_parts[8]) > 0:
66 |             print("Skipped %s" % img_id)
67 |             continue
68 |         if args.truncated == 0 and int(line_parts[9]) > 0:
69 |             print("Skipped %s" % img_id)
70 |             continue
71 |         if args.groupOf == 0 and int(line_parts[10]) > 0:
72 |             print("Skipped %s" % img_id)
73 |             continue
74 |         if args.depiction == 0 and int(line_parts[11]) > 0:
75 |             print("Skipped %s" % img_id)
76 |             continue
77 |         if args.inside == 0 and int(line_parts[12]) > 0:
78 |             print("Skipped %s" % img_id)
79 |             continue
80 |
81 | # Command to download
82 |         command = f"aws s3 --no-sign-request --only-show-errors cp s3://open-images-dataset/{run_mode}/{img_id}.jpg {save_path}"
83 | commands.append(command)
84 | cnt += 1
85 |
86 | print("Annotation Count : " + str(cnt))
87 | commands = list(set(commands))
88 | print("Number of images to be downloaded : " + str(len(commands)))
89 |
90 | list(tqdm(pool.imap(os.system, commands), total=len(commands)))
91 |
92 | pool.close()
93 | pool.join()
94 |
--------------------------------------------------------------------------------
/open-images-v6/eval.sh:
--------------------------------------------------------------------------------
1 | python main.py \
2 | --eval \
3 | --train_annot_fp "./small-train-annotations-bbox-target.csv" \
4 | --val_annot_fp "./validation-annotations-bbox-target.csv" \
5 | --model_dir "retinanet_R_101_FPN_3x__iter-100000__lr-0.0005__warmup-1000"
--------------------------------------------------------------------------------
/open-images-v6/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from tqdm import tqdm
3 | import os, json, shutil, cv2
4 | import numpy as np
5 | import pandas as pd
6 |
7 | # detectron2
8 | import detectron2
9 | from detectron2.utils.logger import setup_logger
10 | from detectron2 import model_zoo
11 | from detectron2.config import get_cfg
12 | from detectron2.structures import BoxMode
13 | from detectron2.data import DatasetCatalog, MetadataCatalog
14 | from detectron2.engine import DefaultTrainer
15 | from detectron2.evaluation import COCOEvaluator
16 |
17 | models = [
18 | "faster_rcnn_R_50_FPN_3x",
19 | "faster_rcnn_R_101_FPN_3x",
20 | "faster_rcnn_X_101_32x8d_FPN_3x", # best
21 | "retinanet_R_50_FPN_3x",
22 | "retinanet_R_101_FPN_3x",
23 | ]
24 |
25 | target_classes = [
26 | 'Accordion',
27 | 'Cello',
28 | 'Drum',
29 | 'French horn',
30 | 'Guitar',
31 | 'Musical keyboard',
32 | 'Piano',
33 | 'Saxophone',
34 | 'Trombone',
35 | 'Trumpet',
36 | 'Violin'
37 | ]
38 |
39 |
40 | def get_args_parser():
41 | parser = argparse.ArgumentParser('Set up Detectron2', add_help=False)
42 | parser.add_argument('--model', default=None, type=str)
43 | parser.add_argument('--model_dir', default=None, type=str)
44 | parser.add_argument('--train', action='store_true')
45 | parser.add_argument('--eval', action='store_true')
46 | parser.add_argument('--train_annot_fp', default=None, type=str)
47 | parser.add_argument('--val_annot_fp', default=None, type=str)
48 | parser.add_argument('--max_iter', default=10000, type=int)
49 | parser.add_argument('--lr', default=3e-4, type=float)
50 | parser.add_argument('--ims_per_batch', default=4, type=int)
51 | parser.add_argument('--warmup_iters', default=1000, type=int)
52 | parser.add_argument('--gamma', default=0.5, type=float)
53 | parser.add_argument('--lr_decay_steps', default=[100000,], type=int, nargs='*')
54 | return parser
55 |
56 |
57 | def denormalize_bboxes(bboxes, height, width):
58 | """Denormalize bounding boxes in format of (xmin, ymin, xmax, ymax)."""
59 | bboxes[:, [0, 2]] = bboxes[:, [0, 2]] * width
60 | bboxes[:, [1, 3]] = bboxes[:, [1, 3]] * height
61 | return np.round(bboxes)
62 |
63 |
64 | def get_detectron_dicts(annot_fp):
65 |     """
66 |     Create Detectron2's standard dataset dicts from an annotation file.
67 |
68 |     Args:
69 |         annot_fp (str): path to the annotation CSV file.
70 |     Return:
71 |         dataset_dicts (list[dict]): List of annotation dictionaries for Detectron2.
72 |     """
73 |     # Load annotations
74 | annot_df = pd.read_csv(annot_fp)
75 |
76 | # Get image ids
77 | img_ids = annot_df["ImageID"].unique().tolist()
78 |
79 | dataset_dicts = []
80 | for img_id in tqdm(img_ids):
81 | file_name = f'images/{img_id}.jpg'
82 | if not os.path.exists(file_name): continue
83 | height, width = cv2.imread(file_name).shape[:2]
84 |
85 | record = {}
86 | record['file_name'] = file_name
87 | record['image_id'] = img_id
88 | record['height'] = height
89 | record['width'] = width
90 |
91 | # Extract bboxes from annotation file
92 | bboxes = annot_df[['XMin', 'YMin', 'XMax', 'YMax']][annot_df['ImageID'] == img_id].values
93 | bboxes = denormalize_bboxes(bboxes, height, width)
94 | class_ids = annot_df[['ClassID']][annot_df['ImageID'] == img_id].values
95 |
96 | annots = []
97 | for i, bbox in enumerate(bboxes.tolist()):
98 | annot = {
99 | "bbox": bbox,
100 | "bbox_mode": BoxMode.XYXY_ABS,
101 | "category_id": int(class_ids[i]),
102 | }
103 | annots.append(annot)
104 |
105 | record["annotations"] = annots
106 | dataset_dicts.append(record)
107 | return dataset_dicts
108 |
109 |
110 | def main(args):
111 | # Register datasets
112 | print("Registering music_train")
113 | DatasetCatalog.register("music_train", lambda path=args.train_annot_fp: get_detectron_dicts(path))
114 | MetadataCatalog.get("music_train").set(thing_classes=target_classes)
115 |
116 | print("Registering music_val")
117 | DatasetCatalog.register("music_val", lambda path=args.val_annot_fp: get_detectron_dicts(path))
118 | MetadataCatalog.get("music_val").set(thing_classes=target_classes)
119 |
120 | # Set up configurations
121 | cfg = get_cfg()
122 | if not args.model_dir:
123 | cfg.merge_from_file(model_zoo.get_config_file(f"COCO-Detection/{args.model}.yaml"))
124 | cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(f"COCO-Detection/{args.model}.yaml")
125 | cfg.DATASETS.TRAIN = ("music_train",)
126 | cfg.DATASETS.TEST = ("music_val",)
127 |
128 | cfg.SOLVER.IMS_PER_BATCH = args.ims_per_batch
129 | cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512
130 | cfg.SOLVER.BASE_LR = args.lr
131 | cfg.SOLVER.MAX_ITER = args.max_iter
132 | cfg.SOLVER.WARMUP_ITERS = args.warmup_iters
133 | cfg.SOLVER.GAMMA = args.gamma
134 | cfg.SOLVER.STEPS = args.lr_decay_steps
135 |
136 | cfg.DATALOADER.NUM_WORKERS = 6
137 | cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(target_classes)
138 | cfg.MODEL.RETINANET.NUM_CLASSES = len(target_classes)
139 | cfg.OUTPUT_DIR = f"{args.model}__iter-{args.max_iter}__lr-{args.lr}"
140 | if os.path.exists(cfg.OUTPUT_DIR): shutil.rmtree(cfg.OUTPUT_DIR)
141 | os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
142 |
143 | # Save config
144 | with open(os.path.join(cfg.OUTPUT_DIR, "config.yaml"), "w") as f:
145 | f.write(cfg.dump())
146 | else:
147 | print("Loading model from ", args.model_dir)
148 | cfg.merge_from_file(os.path.join(args.model_dir, "config.yaml"))
149 | cfg.MODEL.WEIGHTS = os.path.join(args.model_dir, "model_final.pth")
150 | cfg.OUTPUT_DIR = args.model_dir
151 | cfg.DATASETS.TRAIN = ("music_train",)
152 | cfg.DATASETS.TEST = ("music_val",)
153 |
154 | # Set up trainer
155 | setup_logger(output=os.path.join(cfg.OUTPUT_DIR, "terminal_output.log"))
156 | trainer = DefaultTrainer(cfg)
157 | trainer.resume_or_load(resume=False)
158 |
159 | # Train
160 | if args.train:
161 | trainer.train()
162 |
163 | # Evaluate
164 | if args.eval:
165 | evaluator = COCOEvaluator("music_val", cfg, False, output_dir=cfg.OUTPUT_DIR)
166 | eval_results = trainer.test(cfg=cfg, model=trainer.model, evaluators=evaluator)
167 | with open(os.path.join(cfg.OUTPUT_DIR, "eval_results.json"), "w") as f:
168 | json.dump(eval_results, f)
169 |
170 | if __name__ == '__main__':
171 | parser = get_args_parser()
172 | args = parser.parse_args()
173 | main(args)
174 |
--------------------------------------------------------------------------------
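A hypothetical companion snippet (not in the repo) showing how a model trained by `main.py` could be loaded for inference; the model directory follows the naming used in `eval.sh` above, and the score threshold is an assumption:

```python
import cv2
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor

model_dir = "retinanet_R_101_FPN_3x__iter-100000__lr-0.0005__warmup-1000"

cfg = get_cfg()
cfg.merge_from_file(f"{model_dir}/config.yaml")     # config saved by main.py
cfg.MODEL.WEIGHTS = f"{model_dir}/model_final.pth"  # final checkpoint
cfg.MODEL.RETINANET.SCORE_THRESH_TEST = 0.5         # keep confident detections only

predictor = DefaultPredictor(cfg)
outputs = predictor(cv2.imread("images/input.jpg"))  # Detectron2 expects a BGR array
print(outputs["instances"].pred_boxes)
print(outputs["instances"].pred_classes)
```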
/open-images-v6/train.sh:
--------------------------------------------------------------------------------
1 | python main.py \
2 | --train \
3 | --train_annot_fp "./small-train-annotations-bbox-target.csv" \
4 | --eval \
5 | --val_annot_fp "./validation-annotations-bbox-target.csv" \
6 | --model "faster_rcnn_X_101_32x8d_FPN_3x" \
7 | --max_iter 300 \
8 | --lr 5e-4 \
9 | --gamma 0.5 \
10 |     --lr_decay_steps 250 300
--------------------------------------------------------------------------------
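A rough illustration (not part of the repo) of how `--lr`, `--gamma`, and `--lr_decay_steps` map onto Detectron2's step schedule: `SOLVER.BASE_LR` is multiplied by `SOLVER.GAMMA` at each iteration listed in `SOLVER.STEPS` (warmup over the first `WARMUP_ITERS` iterations is ignored here):

```python
base_lr, gamma, steps = 5e-4, 0.5, [250, 300]

def lr_at(iteration):
    # Multiply by gamma once for every decay step already passed
    return base_lr * gamma ** sum(iteration >= s for s in steps)

print(lr_at(0), lr_at(249), lr_at(250))  # 0.0005 0.0005 0.00025
```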
/wheat-detection/eval.sh:
--------------------------------------------------------------------------------
1 | python main.py \
2 | --eval \
3 | --val_annot_fp "./annot_val.csv" \
4 |     --model_dir retinanet_R_101_FPN_3x__iter-10000__lr-0.0005
--------------------------------------------------------------------------------
/wheat-detection/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from tqdm import tqdm
3 | import os
4 | import numpy as np
5 | import pandas as pd
6 | import ast
7 | import json
8 | import shutil
9 |
10 | # import some common detectron2 utilities
11 | import detectron2
12 | from detectron2.utils.logger import setup_logger
13 | from detectron2 import model_zoo
14 | from detectron2.config import get_cfg
15 | from detectron2.structures import BoxMode
16 | from detectron2.data import DatasetCatalog, MetadataCatalog
17 | from detectron2.engine import DefaultTrainer
18 | from detectron2.evaluation import COCOEvaluator
19 |
20 |
21 | models = [
22 | "faster_rcnn_R_50_FPN_3x",
23 | "faster_rcnn_R_101_FPN_3x",
24 | "faster_rcnn_X_101_32x8d_FPN_3x",
25 | "retinanet_R_50_FPN_3x",
26 | "retinanet_R_101_FPN_3x",
27 | ]
28 |
29 |
30 | def get_args_parser():
31 | parser = argparse.ArgumentParser('Set up Detectron2', add_help=False)
32 | parser.add_argument('--model', default=None, type=str)
33 | parser.add_argument('--model_dir', default=None, type=str)
34 | parser.add_argument('--train', action='store_true')
35 | parser.add_argument('--eval', action='store_true')
36 | parser.add_argument('--train_annot_fp', default="./annot_train.csv", type=str)
37 | parser.add_argument('--val_annot_fp', default="./annot_val.csv", type=str)
38 | parser.add_argument('--max_iter', default=10000, type=int)
39 | parser.add_argument('--lr', default=3e-4, type=float)
40 | parser.add_argument('--ims_per_batch', default=4, type=int)
41 | parser.add_argument('--warmup_iters', default=1000, type=int)
42 | parser.add_argument('--gamma', default=0.5, type=float)
43 | parser.add_argument('--lr_decay_steps', default=[100000,], type=int, nargs='*')
44 | return parser
45 |
46 |
47 | def get_detectron_dicts(annot_path):
48 | """
49 |     Create Detectron2's standard dataset dicts from an annotation file.
50 |
51 | Args:
52 | annot_path (str): path to the annotation file.
53 | Return:
54 | dataset_dicts (list[dict]): List of annotation dictionaries for Detectron2.
55 | """
56 | print("Loading annotation from ", annot_path)
57 | # Load annotation DataFrame
58 | annot_df = pd.read_csv(annot_path)
59 |
60 | # Get all images in `annot_df`
61 | img_ids = annot_df["image_id"].unique().tolist()
62 |
63 | # Convert `bbox` column from string to list
64 | annot_df['bbox'] = annot_df['bbox'].apply(ast.literal_eval)
65 |
66 | dataset_dicts = []
67 | for img_id in tqdm(img_ids):
68 | file_name = f'train/{img_id}.jpg'
69 |
70 | record = {}
71 | record['file_name'] = file_name
72 | record['image_id'] = img_id
73 | record['height'] = 1024
74 | record['width'] = 1024
75 |
76 | # Extract bboxes from annotation file
77 | bboxes = annot_df[annot_df['image_id'] == img_id]['bbox']
78 |
79 | annots = []
80 | for bbox in bboxes:
81 | annot = {
82 | "bbox": bbox,
83 | "bbox_mode": BoxMode.XYWH_ABS,
84 | "category_id": 0,
85 | }
86 | annots.append(annot)
87 |
88 | record["annotations"] = annots
89 | dataset_dicts.append(record)
90 | return dataset_dicts
91 |
92 |
93 | def main(args):
94 | # Register datasets
95 | print("Registering wheat_detection_train")
96 | DatasetCatalog.register("wheat_detection_train", lambda path=args.train_annot_fp: get_detectron_dicts(path))
97 | MetadataCatalog.get("wheat_detection_train").set(thing_classes=["Wheat"])
98 |
99 | print("Registering wheat_detection_val")
100 | DatasetCatalog.register("wheat_detection_val", lambda path=args.val_annot_fp: get_detectron_dicts(path))
101 | MetadataCatalog.get("wheat_detection_val").set(thing_classes=["Wheat"])
102 |
103 | # Set up configurations
104 | cfg = get_cfg()
105 | if not args.model_dir:
106 | cfg.merge_from_file(model_zoo.get_config_file(f"COCO-Detection/{args.model}.yaml"))
107 | cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(f"COCO-Detection/{args.model}.yaml")
108 | cfg.DATASETS.TRAIN = ("wheat_detection_train",)
109 | cfg.DATASETS.TEST = ("wheat_detection_val",)
110 |
111 | cfg.SOLVER.IMS_PER_BATCH = args.ims_per_batch
112 | cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512
113 | cfg.SOLVER.BASE_LR = args.lr
114 | cfg.SOLVER.MAX_ITER = args.max_iter
115 | cfg.SOLVER.WARMUP_ITERS = args.warmup_iters
116 | cfg.SOLVER.GAMMA = args.gamma
117 | cfg.SOLVER.STEPS = args.lr_decay_steps
118 |
119 | cfg.DATALOADER.NUM_WORKERS = 6
120 | cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
121 | cfg.MODEL.RETINANET.NUM_CLASSES = 1
122 | cfg.OUTPUT_DIR = f"{args.model}__iter-{args.max_iter}__lr-{args.lr}"
123 | if os.path.exists(cfg.OUTPUT_DIR): shutil.rmtree(cfg.OUTPUT_DIR)
124 | os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
125 |
126 | # Save config
127 | with open(os.path.join(cfg.OUTPUT_DIR, "config.yaml"), "w") as f:
128 | f.write(cfg.dump())
129 | else:
130 | print("Loading model from ", args.model_dir)
131 | cfg.merge_from_file(os.path.join(args.model_dir, "config.yaml"))
132 | cfg.MODEL.WEIGHTS = os.path.join(args.model_dir, "model_final.pth")
133 | cfg.OUTPUT_DIR = args.model_dir
134 |
135 | # Train
136 | setup_logger(output=os.path.join(cfg.OUTPUT_DIR, "terminal_output.log"))
137 | trainer = DefaultTrainer(cfg)
138 | trainer.resume_or_load(resume=False)
139 |
140 | if args.train:
141 | trainer.train()
142 |
143 | # Evaluate
144 | if args.eval:
145 | evaluator = COCOEvaluator("wheat_detection_val", cfg, False, output_dir=cfg.OUTPUT_DIR)
146 | eval_results = trainer.test(cfg=cfg, model=trainer.model, evaluators=evaluator)
147 | with open(os.path.join(cfg.OUTPUT_DIR, "eval_results.json"), "w") as f:
148 | json.dump(eval_results, f)
149 |
150 | if __name__ == '__main__':
151 | parser = get_args_parser()
152 | args = parser.parse_args()
153 | main(args)
154 |
--------------------------------------------------------------------------------
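A sketch of the annotation rows `get_detectron_dicts` consumes: the competition's `train.csv` stores one box per row, with `bbox` as a stringified `[x, y, w, h]` list that `ast.literal_eval` converts back to numbers (hence `BoxMode.XYWH_ABS`). The image id and box values below are illustrative:

```python
import ast
import pandas as pd

annot_df = pd.DataFrame({
    "image_id": ["img_001", "img_001"],
    "bbox": ["[834.0, 222.0, 56.0, 36.0]", "[226.0, 548.0, 130.0, 58.0]"],
})

# String -> list, exactly as main.py does before building Detectron2 records
annot_df["bbox"] = annot_df["bbox"].apply(ast.literal_eval)
print(annot_df["bbox"].iloc[0])  # [834.0, 222.0, 56.0, 36.0] (pixels, XYWH)
```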
/wheat-detection/train.sh:
--------------------------------------------------------------------------------
1 | python main.py \
2 | --train \
3 | --train_annot_fp "./annot_train.csv" \
4 | --eval \
5 | --val_annot_fp "./annot_val.csv" \
6 | --model "faster_rcnn_X_101_32x8d_FPN_3x" \
7 | --max_iter 300 \
8 | --lr 5e-4 \
9 | --warmup_iters 100 \
10 | --gamma 0.5 \
11 | --lr_decay_steps 200 250
--------------------------------------------------------------------------------