├── README.md ├── main.py ├── requirements.txt ├── task1 ├── run.py ├── superglue │ └── utils.py └── task1_utils.py ├── task2 └── final_2nd │ ├── __pycache__ │ ├── model.cpython-36.pyc │ └── utils.cpython-36.pyc │ ├── lib │ ├── Task2.py │ ├── __pycache__ │ │ ├── Task2.cpython-36.pyc │ │ ├── Task2.cpython-37.pyc │ │ ├── marg_model.cpython-36.pyc │ │ ├── marg_model.cpython-37.pyc │ │ ├── marg_utils.cpython-36.pyc │ │ └── marg_utils.cpython-37.pyc │ ├── marg_model.py │ ├── marg_utils.py │ ├── mel_fbank.npy │ └── window.npy │ └── main.py ├── task3 └── conf │ └── task3.yaml └── task4 ├── main.py ├── main_inference.py └── setup.py /README.md: -------------------------------------------------------------------------------- 1 | # drone_ai_challenge 2 | >2021 Drone AI Challenge 3 | Integrated code 4 | 5 | **How to run** 6 | 7 | python main.py 8 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | import time 5 | #task3 6 | import task3.lib 7 | import task3.lib.predict as pred 8 | 9 | #task2 10 | from task2.lib.Task2 import * 11 | 12 | #task4 13 | from task4.main_inference import * 14 | 15 | #task1 16 | from task1.run import * 17 | 18 | 19 | def make_final_json(task1_answer,task2_answer, task3_answer,task4_answer): 20 | final_json = dict() 21 | final_json["task1_answer"] = task1_answer 22 | final_json["task2_answer"] = task2_answer 23 | final_json["task3_answer"] = task3_answer 24 | final_json["task4_answer"] = task4_answer 25 | 26 | with open(json_path, 'w', encoding='utf-8') as make_file: 27 | json.dump(final_json, make_file,ensure_ascii=False,indent=3 ) 28 | 29 | def save_task3_answer(set_num, pred_data): 30 | data, data_chk = {}, {} 31 | indent = 3 32 | 33 | data["task3_answer"] = [] 34 | 35 | data["task3_answer"].append({f"{set_num}": pred_data}) 36 | 37 | return data 38 | 39 | 40 | 41 | def main(args): 42 | start = time.time() 43 | if os.path.exists(args.json_path) == True: 44 | os.remove(args.json_path) 45 | #n_set = 5 if args.set_num == "all_sets" else 1 46 | #print(args.set_num) 47 | # task_2 - already implemented for all 5 sets 48 | # data_path: change to the competition dataset path later 49 | data_path = '/home/agc2021/dataset/' 50 | #~~~~~~~~~ task_2 ~~~~~~~~~~~ 51 | 52 | start_task2 = time.time() 53 | 54 | task2_answer = task2_inference(data_path,5) 55 | print("TASK2 : ",time.time()-start_task2) 56 | print(task2_answer) 57 | 58 | 59 | #~~~~~~~~~ task_4 ~~~~~~~~~~~ 60 | #data_path = '/home/shinyeong/final_dataset/' 61 | start_task4 = time.time() 62 | task4_answer = task4_main(data_path) 63 | print("TASK4 : ",time.time() - start_task4) 64 | print(task4_answer) 65 | 66 | 67 | ##############the dataset path must be an absolute path!!!!!################ 68 | 69 | for i in range(5): #n_set : number of set 70 | args.set_num = f"set_0" + str(i+1) 71 | #print("count : ", args.set_num) 72 | set_num = f"set_0" + str(i+1) 73 | #~~~~~~~~~ task_1 ~~~~~~~~~~~ 74 | 75 | task1_video_path = data_path + set_num #final_dataset/set_01 76 | task1_img_path = data_path + set_num 77 | task1_frame_skip = 30 78 | task1_answer = task1_main(task1_video_path, task1_img_path, task1_frame_skip) 79 | 80 | 81 | #~~~~~~~~~ task_3 ~~~~~~~~~~~ 82 | 83 | t3_data = [] 84 | t3 = pred.func_task3(args) 85 | t3_res_pred_move, t3_res_pred_stay, t3_res_pred_total = t3.run() 86 | t3_data.append(t3_res_pred_move) 87 | t3_data.append(t3_res_pred_stay) 88 | t3_data.append(t3_res_pred_total) 89 | task3_answer = 
save_task3_answer(set_num,t3_data) 90 | 91 | #~~~~~~~~~ task_4 ~~~~~~~~~~~ 92 | 93 | 94 | 95 | 96 | make_final_json(task1_answer,task2_answer, task3_answer,task4_answer) 97 | print("TOTAL INFERENCE TIME : ", time.time()-start) 98 | 99 | 100 | if __name__ == '__main__': 101 | p=argparse.ArgumentParser() 102 | # path 103 | # p.add_argument("--dataset_dir", type=str, default="/home/agc2021/dataset") # /set_01, /set_02, /set_03, /set_04, /set_05 104 | # p.add_argument("--root_dir", type=str, default="/home/[Team_ID]") 105 | # p.add_argument("--temporary_dir", type=str, default="/home/agc2021/temp") 106 | ### 107 | json_path = "answersheet_3_00_Rony2.json" 108 | p.add_argument("--dataset_dir", type=str, default="/home/shinyeong/final_dataset") # /set_01, /set_02, /set_03, /set_04, /set_05 109 | p.add_argument("--root_dir", type=str, default="./") 110 | p.add_argument("--temporary_dir", type=str, default="../output3") 111 | 112 | ### 113 | p.add_argument("--json_path", type=str, default="answersheet_3_00_Rony2.json") 114 | p.add_argument("--task_num", type=str, default="task3_answer") 115 | p.add_argument("--set_num", type=str, default="all_set") 116 | p.add_argument("--device", default='0', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') 117 | p.add_argument("--test", type = int, default = '3', help = 'number of video,3') 118 | p.add_argument("--release_mode", type=bool, default = True) 119 | 120 | args = p.parse_args() 121 | 122 | main(args) 123 | 124 | 125 | 126 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.15.0 2 | addict==2.4.0 3 | beautifulsoup4==4.10.0 4 | cachetools==4.2.4 5 | certifi==2021.5.30 6 | charset-normalizer==2.0.7 7 | click==7.1.2 8 | colorama==0.4.4 9 | conformer==0.2.5 10 | cycler==0.11.0 11 | Cython==0.29.24 12 | dataclasses==0.8 13 | decorator==4.4.2 14 | easydict==1.9 15 | easyocr==1.4.1 16 | einops==0.3.0 17 | filelock==3.3.2 18 | future==0.18.2 19 | google-auth==2.3.3 20 | google-auth-oauthlib==0.4.6 21 | grpcio==1.41.1 22 | idna==3.3 23 | imageio==2.9.0 24 | importlib-metadata==4.8.1 25 | julius==0.2.6 26 | kiwisolver==1.3.1 27 | Markdown==3.3.4 28 | matplotlib==3.3.4 29 | model-index==0.1.11 30 | networkx==2.5.1 31 | numpy==1.19.5 32 | oauthlib==3.1.1 33 | opencv-python==4.5.4.58 34 | opencv-python-headless==4.5.4.58 35 | ordered-set==4.0.2 36 | packaging==21.2 37 | pandas==1.1.5 38 | Pillow==8.2.0 39 | protobuf==3.19.1 40 | pyasn1==0.4.8 41 | pyasn1-modules==0.2.8 42 | pycocotools==2.0.2 43 | pyparsing==2.4.7 44 | PySocks==1.7.1 45 | python-bidi==0.4.2 46 | python-dateutil==2.8.2 47 | pytz==2021.3 48 | PyWavelets==1.1.1 49 | PyYAML==6.0 50 | requests==2.26.0 51 | rsa==4.7.2 52 | scikit-image==0.17.2 53 | scipy==1.5.4 54 | seaborn==0.11.2 55 | six==1.16.0 56 | soupsieve==2.3 57 | tabulate==0.8.9 58 | tensorboard==2.7.0 59 | tensorboard-data-server==0.6.1 60 | tensorboard-plugin-wit==1.8.0 61 | terminaltables==3.1.0 62 | tifffile==2020.9.3 63 | torchaudio==0.7.0 64 | tqdm==4.62.3 65 | typing-extensions==3.10.0.2 66 | urllib3==1.26.7 67 | Werkzeug==2.0.2 68 | yapf==0.31.0 69 | zipp==3.6.0 70 | -------------------------------------------------------------------------------- /task1/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | from threading import Thread 3 | import sys 4 | import time 5 | import cv2 6 | import numpy 7 | import argparse 8 | import math 9 | from 
glob import glob 10 | import numpy as np 11 | import json 12 | import torch 13 | 14 | from task1.task1_utils import match_pairs, MatchImageSizeTo,ocr 15 | 16 | #os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 17 | #os.environ["CUDA_VISIBLE_DEVICES"]="0" 18 | 19 | 20 | 21 | def task1(frames, imgs, search_radius=30, ocr_batch_size=10, match_batch_size=8): 22 | frame_total = len(frames) 23 | 24 | ######################### Frame-Image Matching ########################### 25 | print('start image matching') 26 | match_results = match_pairs(frames, imgs, match_batch_size, 'cuda') 27 | torch.cuda.empty_cache() 28 | 29 | ######################### mask frames ########################### 30 | print('masking') 31 | vid_mask = np.zeros(frame_total).astype(np.int) 32 | img_idx = [] 33 | for match_res in match_results: 34 | img_idx.append(match_res[0]) 35 | if match_res[0] == -1: 36 | continue 37 | idx = np.arange(search_radius*2+1) - search_radius + match_res[0] 38 | idx = np.clip(idx,0,frame_total-1).astype(np.int) 39 | vid_mask[idx] = 1 40 | masked_frame_idx = np.where(vid_mask==1)[0] 41 | frames = np.stack(frames, axis=0)[vid_mask==1] 42 | 43 | ######################### OCR ########################### 44 | print('start ocr') 45 | texts = [] 46 | text_idx = [] 47 | Iters = math.ceil(masked_frame_idx.shape[0]/ocr_batch_size) 48 | from tqdm import tqdm 49 | with tqdm(total=Iters) as pbar: 50 | for i in range(Iters): 51 | start = i*ocr_batch_size 52 | if i == Iters-1: 53 | end = masked_frame_idx.shape[0] 54 | else: 55 | end = (i+1)*ocr_batch_size 56 | ocr(frames[start:end], start, masked_frame_idx, texts, text_idx) 57 | pbar.update(1) 58 | torch.cuda.empty_cache() 59 | 60 | answer = [] 61 | for i, i_idx in enumerate(img_idx): 62 | ans = 'NONE' 63 | if i_idx == -1: 64 | answer.append(ans) 65 | continue 66 | min_dist = search_radius+1 67 | for j,t_idx in enumerate(text_idx): 68 | if abs(t_idx-i_idx) > search_radius: 69 | continue 70 | if abs(t_idx-i_idx) < min_dist: 71 | min_dist = abs(t_idx-i_idx) 72 | ans = texts[j] 73 | answer.append(ans) 74 | 75 | print(answer) 76 | 77 | return answer 78 | 79 | def task1_main(video_path, img_path, frame_skip): 80 | #print(video_path) 81 | start = time.time() 82 | #parser = argparse.ArgumentParser() 83 | #parser.add_argument('--video_path', default='/home/jaewon/drone/samples', help='video path') 84 | #parser.add_argument('--img_path', default='/home/jaewon/drone/samples', help='image path') 85 | #parser.add_argument('--output_path', default='output.json', help='output path') 86 | #parser.add_argument('--frame_skip', type=int, default=30, help='output path') 87 | #args = parser.parse_args() 88 | 89 | # f = open(args.output_path,'w') 90 | final_result = { 91 | "task1_answer":[{ 92 | "set_1": [], 93 | "set_2": [], 94 | "set_3": [], 95 | "set_4": [], 96 | "set_5": [] 97 | }] 98 | } 99 | 100 | imgs=[] 101 | 102 | img_list = glob(os.path.join(img_path, "*.jpg")) 103 | #print(img_list) 104 | img_list.sort() 105 | for img_ in img_list: 106 | if "rescue" in img_ : 107 | img = cv2.imread(img_, cv2.IMREAD_GRAYSCALE) 108 | img = MatchImageSizeTo()(img) 109 | imgs.append(img) 110 | 111 | vid_list = glob(os.path.join(video_path, "*.mp4")) 112 | vid_list.sort() 113 | #print("--------------------") 114 | #print(vid_list, img_list) 115 | for vid_path in vid_list: 116 | vid_name = vid_path.split('/')[-1].split('.')[0].split('_') 117 | set_num = "set_{}".format(vid_name[0][-1]) 118 | drone_num = "drone_{}".format(vid_name[1][-1]) 119 | 120 | frames = [] 121 | cap = 
cv2.VideoCapture(vid_path) 122 | while (cap.isOpened()): 123 | ret, frame = cap.read() 124 | frame_pos = int(cap.get(cv2.CAP_PROP_POS_FRAMES)) 125 | if(type(frame) == type(None)): 126 | break 127 | if frame_pos % frame_skip != 0 : 128 | continue 129 | frames.append(frame) 130 | cap.release() 131 | result = task1(frames, imgs) 132 | final_result["task1_answer"][0][set_num].append({drone_num:result}) 133 | 134 | print(final_result) 135 | #with open(args.output_path, 'w') as f: 136 | # json.dump(final_result, f) 137 | 138 | print("TASK1 TIME :", time.time()-start) 139 | 140 | return final_result 141 | 142 | -------------------------------------------------------------------------------- /task1/superglue/utils.py: -------------------------------------------------------------------------------- 1 | # %BANNER_BEGIN% 2 | # --------------------------------------------------------------------- 3 | # %COPYRIGHT_BEGIN% 4 | # 5 | # Magic Leap, Inc. ("COMPANY") CONFIDENTIAL 6 | # 7 | # Unpublished Copyright (c) 2020 8 | # Magic Leap, Inc., All Rights Reserved. 9 | # 10 | # NOTICE: All information contained herein is, and remains the property 11 | # of COMPANY. The intellectual and technical concepts contained herein 12 | # are proprietary to COMPANY and may be covered by U.S. and Foreign 13 | # Patents, patents in process, and are protected by trade secret or 14 | # copyright law. Dissemination of this information or reproduction of 15 | # this material is strictly forbidden unless prior written permission is 16 | # obtained from COMPANY. Access to the source code contained herein is 17 | # hereby forbidden to anyone except current COMPANY employees, managers 18 | # or contractors who have executed Confidentiality and Non-disclosure 19 | # agreements explicitly covering such access. 20 | # 21 | # The copyright notice above does not evidence any actual or intended 22 | # publication or disclosure of this source code, which includes 23 | # information that is confidential and/or proprietary, and is a trade 24 | # secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION, 25 | # PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS 26 | # SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS 27 | # STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND 28 | # INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE 29 | # CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS 30 | # TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE, 31 | # USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART. 32 | # 33 | # %COPYRIGHT_END% 34 | # ---------------------------------------------------------------------- 35 | # %AUTHORS_BEGIN% 36 | # 37 | # Originating Authors: Paul-Edouard Sarlin 38 | # Daniel DeTone 39 | # Tomasz Malisiewicz 40 | # 41 | # %AUTHORS_END% 42 | # --------------------------------------------------------------------*/ 43 | # %BANNER_END% 44 | 45 | from pathlib import Path 46 | import time 47 | from collections import OrderedDict 48 | from threading import Thread 49 | import numpy as np 50 | import cv2 51 | import torch 52 | import matplotlib.pyplot as plt 53 | import matplotlib 54 | matplotlib.use('Agg') 55 | 56 | class AverageTimer: 57 | """ Class to help manage printing simple timing of code execution. 
""" 58 | 59 | def __init__(self, smoothing=0.3, newline=False): 60 | self.smoothing = smoothing 61 | self.newline = newline 62 | self.times = OrderedDict() 63 | self.will_print = OrderedDict() 64 | self.reset() 65 | 66 | def reset(self): 67 | now = time.time() 68 | self.start = now 69 | self.last_time = now 70 | for name in self.will_print: 71 | self.will_print[name] = False 72 | 73 | def update(self, name='default'): 74 | now = time.time() 75 | dt = now - self.last_time 76 | if name in self.times: 77 | dt = self.smoothing * dt + (1 - self.smoothing) * self.times[name] 78 | self.times[name] = dt 79 | self.will_print[name] = True 80 | self.last_time = now 81 | 82 | def print(self, text='Timer'): 83 | total = 0. 84 | print('[{}]'.format(text), end=' ') 85 | for key in self.times: 86 | val = self.times[key] 87 | if self.will_print[key]: 88 | print('%s=%.3f' % (key, val), end=' ') 89 | total += val 90 | print('total=%.3f sec {%.1f FPS}' % (total, 1./total), end=' ') 91 | if self.newline: 92 | print(flush=True) 93 | else: 94 | print(end='\r', flush=True) 95 | self.reset() 96 | 97 | 98 | class VideoStreamer: 99 | """ Class to help process image streams. Four types of possible inputs:" 100 | 1.) USB Webcam. 101 | 2.) An IP camera 102 | 3.) A directory of images (files in directory matching 'image_glob'). 103 | 4.) A video file, such as an .mp4 or .avi file. 104 | """ 105 | def __init__(self, basedir, resize, skip, image_glob, max_length=1000000): 106 | self._ip_grabbed = False 107 | self._ip_running = False 108 | self._ip_camera = False 109 | self._ip_image = None 110 | self._ip_index = 0 111 | self.cap = [] 112 | self.camera = True 113 | self.video_file = False 114 | self.listing = [] 115 | self.resize = resize 116 | self.interp = cv2.INTER_AREA 117 | self.i = 0 118 | self.skip = skip 119 | self.max_length = max_length 120 | if isinstance(basedir, int) or basedir.isdigit(): 121 | print('==> Processing USB webcam input: {}'.format(basedir)) 122 | self.cap = cv2.VideoCapture(int(basedir)) 123 | self.listing = range(0, self.max_length) 124 | elif basedir.startswith(('', 'rtsp')): 125 | print('==> Processing IP camera input: {}'.format(basedir)) 126 | self.cap = cv2.VideoCapture(basedir) 127 | self.start_ip_camera_thread() 128 | self._ip_camera = True 129 | self.listing = range(0, self.max_length) 130 | elif Path(basedir).is_dir(): 131 | print('==> Processing image directory input: {}'.format(basedir)) 132 | self.listing = list(Path(basedir).glob(image_glob[0])) 133 | for j in range(1, len(image_glob)): 134 | image_path = list(Path(basedir).glob(image_glob[j])) 135 | self.listing = self.listing + image_path 136 | self.listing.sort() 137 | self.listing = self.listing[::self.skip] 138 | self.max_length = np.min([self.max_length, len(self.listing)]) 139 | if self.max_length == 0: 140 | raise IOError('No images found (maybe bad \'image_glob\' ?)') 141 | self.listing = self.listing[:self.max_length] 142 | self.camera = False 143 | elif Path(basedir).exists(): 144 | print('==> Processing video input: {}'.format(basedir)) 145 | self.cap = cv2.VideoCapture(basedir) 146 | self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 1) 147 | num_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) 148 | self.listing = range(0, num_frames) 149 | self.listing = self.listing[::self.skip] 150 | self.video_file = True 151 | self.max_length = np.min([self.max_length, len(self.listing)]) 152 | self.listing = self.listing[:self.max_length] 153 | else: 154 | raise ValueError('VideoStreamer input \"{}\" not recognized.'.format(basedir)) 155 
| if self.camera and not self.cap.isOpened(): 156 | raise IOError('Could not read camera') 157 | 158 | def load_image(self, impath): 159 | """ Read image as grayscale and resize to img_size. 160 | Inputs 161 | impath: Path to input image. 162 | Returns 163 | grayim: uint8 numpy array sized H x W. 164 | """ 165 | grayim = cv2.imread(impath, 0) 166 | if grayim is None: 167 | raise Exception('Error reading image %s' % impath) 168 | w, h = grayim.shape[1], grayim.shape[0] 169 | w_new, h_new = process_resize(w, h, self.resize) 170 | grayim = cv2.resize( 171 | grayim, (w_new, h_new), interpolation=self.interp) 172 | return grayim 173 | 174 | def next_frame(self): 175 | """ Return the next frame, and increment internal counter. 176 | Returns 177 | image: Next H x W image. 178 | status: True or False depending whether image was loaded. 179 | """ 180 | 181 | if self.i == self.max_length: 182 | return (None, False) 183 | if self.camera: 184 | 185 | if self._ip_camera: 186 | #Wait for first image, making sure we haven't exited 187 | while self._ip_grabbed is False and self._ip_exited is False: 188 | time.sleep(.001) 189 | 190 | ret, image = self._ip_grabbed, self._ip_image.copy() 191 | if ret is False: 192 | self._ip_running = False 193 | else: 194 | ret, image = self.cap.read() 195 | if ret is False: 196 | print('VideoStreamer: Cannot get image from camera') 197 | return (None, False) 198 | w, h = image.shape[1], image.shape[0] 199 | if self.video_file: 200 | self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.listing[self.i]) 201 | 202 | w_new, h_new = process_resize(w, h, self.resize) 203 | image = cv2.resize(image, (w_new, h_new), 204 | interpolation=self.interp) 205 | image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) 206 | else: 207 | image_file = str(self.listing[self.i]) 208 | image = self.load_image(image_file) 209 | self.i = self.i + 1 210 | return (image, True) 211 | 212 | def start_ip_camera_thread(self): 213 | self._ip_thread = Thread(target=self.update_ip_camera, args=()) 214 | self._ip_running = True 215 | self._ip_thread.start() 216 | self._ip_exited = False 217 | return self 218 | 219 | def update_ip_camera(self): 220 | while self._ip_running: 221 | ret, img = self.cap.read() 222 | if ret is False: 223 | self._ip_running = False 224 | self._ip_exited = True 225 | self._ip_grabbed = False 226 | return 227 | 228 | self._ip_image = img 229 | self._ip_grabbed = ret 230 | self._ip_index += 1 231 | #print('IPCAMERA THREAD got frame {}'.format(self._ip_index)) 232 | 233 | 234 | def cleanup(self): 235 | self._ip_running = False 236 | 237 | # --- PREPROCESSING --- 238 | 239 | def process_resize(w, h, resize): 240 | assert(len(resize) > 0 and len(resize) <= 2) 241 | if len(resize) == 1 and resize[0] > -1: 242 | scale = resize[0] / max(h, w) 243 | w_new, h_new = int(round(w*scale)), int(round(h*scale)) 244 | elif len(resize) == 1 and resize[0] == -1: 245 | w_new, h_new = w, h 246 | else: # len(resize) == 2: 247 | w_new, h_new = resize[0], resize[1] 248 | 249 | # Issue warning if resolution is too small or too large. 
250 | if max(w_new, h_new) < 160: 251 | print('Warning: input resolution is very small, results may vary') 252 | elif max(w_new, h_new) > 2000: 253 | print('Warning: input resolution is very large, results may vary') 254 | 255 | return w_new, h_new 256 | 257 | 258 | def frame2tensor(frame, device): 259 | return torch.from_numpy(frame/255.).float()[None, None].to(device) 260 | 261 | 262 | def read_image(path, device, resize, rotation, resize_float): 263 | image = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE) 264 | if image is None: 265 | return None, None, None 266 | w, h = image.shape[1], image.shape[0] 267 | w_new, h_new = process_resize(w, h, resize) 268 | scales = (float(w) / float(w_new), float(h) / float(h_new)) 269 | 270 | if resize_float: 271 | image = cv2.resize(image.astype('float32'), (w_new, h_new)) 272 | else: 273 | image = cv2.resize(image, (w_new, h_new)).astype('float32') 274 | 275 | if rotation != 0: 276 | image = np.rot90(image, k=rotation) 277 | if rotation % 2: 278 | scales = scales[::-1] 279 | 280 | inp = frame2tensor(image, device) 281 | return image, inp, scales 282 | 283 | 284 | # --- GEOMETRY --- 285 | 286 | 287 | def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999): 288 | if len(kpts0) < 5: 289 | return None 290 | 291 | f_mean = np.mean([K0[0, 0], K1[1, 1], K0[0, 0], K1[1, 1]]) 292 | norm_thresh = thresh / f_mean 293 | 294 | kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None] 295 | kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None] 296 | 297 | E, mask = cv2.findEssentialMat( 298 | kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf, 299 | method=cv2.RANSAC) 300 | 301 | assert E is not None 302 | 303 | best_num_inliers = 0 304 | ret = None 305 | for _E in np.split(E, len(E) / 3): 306 | n, R, t, _ = cv2.recoverPose( 307 | _E, kpts0, kpts1, np.eye(3), 1e9, mask=mask) 308 | if n > best_num_inliers: 309 | best_num_inliers = n 310 | ret = (R, t[:, 0], mask.ravel() > 0) 311 | return ret 312 | 313 | 314 | def rotate_intrinsics(K, image_shape, rot): 315 | """image_shape is the shape of the image after rotation""" 316 | assert rot <= 3 317 | h, w = image_shape[:2][::-1 if (rot % 2) else 1] 318 | fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2] 319 | rot = rot % 4 320 | if rot == 1: 321 | return np.array([[fy, 0., cy], 322 | [0., fx, w-1-cx], 323 | [0., 0., 1.]], dtype=K.dtype) 324 | elif rot == 2: 325 | return np.array([[fx, 0., w-1-cx], 326 | [0., fy, h-1-cy], 327 | [0., 0., 1.]], dtype=K.dtype) 328 | else: # if rot == 3: 329 | return np.array([[fy, 0., h-1-cy], 330 | [0., fx, cx], 331 | [0., 0., 1.]], dtype=K.dtype) 332 | 333 | 334 | def rotate_pose_inplane(i_T_w, rot): 335 | rotation_matrices = [ 336 | np.array([[np.cos(r), -np.sin(r), 0., 0.], 337 | [np.sin(r), np.cos(r), 0., 0.], 338 | [0., 0., 1., 0.], 339 | [0., 0., 0., 1.]], dtype=np.float32) 340 | for r in [np.deg2rad(d) for d in (0, 270, 180, 90)] 341 | ] 342 | return np.dot(rotation_matrices[rot], i_T_w) 343 | 344 | 345 | def scale_intrinsics(K, scales): 346 | scales = np.diag([1./scales[0], 1./scales[1], 1.]) 347 | return np.dot(scales, K) 348 | 349 | 350 | def to_homogeneous(points): 351 | return np.concatenate([points, np.ones_like(points[:, :1])], axis=-1) 352 | 353 | 354 | def compute_epipolar_error(kpts0, kpts1, T_0to1, K0, K1): 355 | kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None] 356 | kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None] 357 | kpts0 = to_homogeneous(kpts0) 358 | kpts1 = to_homogeneous(kpts1) 359 | 
360 | t0, t1, t2 = T_0to1[:3, 3] 361 | t_skew = np.array([ 362 | [0, -t2, t1], 363 | [t2, 0, -t0], 364 | [-t1, t0, 0] 365 | ]) 366 | E = t_skew @ T_0to1[:3, :3] 367 | 368 | Ep0 = kpts0 @ E.T # N x 3 369 | p1Ep0 = np.sum(kpts1 * Ep0, -1) # N 370 | Etp1 = kpts1 @ E # N x 3 371 | d = p1Ep0**2 * (1.0 / (Ep0[:, 0]**2 + Ep0[:, 1]**2) 372 | + 1.0 / (Etp1[:, 0]**2 + Etp1[:, 1]**2)) 373 | return d 374 | 375 | 376 | def angle_error_mat(R1, R2): 377 | cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2 378 | cos = np.clip(cos, -1., 1.) # numercial errors can make it out of bounds 379 | return np.rad2deg(np.abs(np.arccos(cos))) 380 | 381 | 382 | def angle_error_vec(v1, v2): 383 | n = np.linalg.norm(v1) * np.linalg.norm(v2) 384 | return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0))) 385 | 386 | 387 | def compute_pose_error(T_0to1, R, t): 388 | R_gt = T_0to1[:3, :3] 389 | t_gt = T_0to1[:3, 3] 390 | error_t = angle_error_vec(t, t_gt) 391 | error_t = np.minimum(error_t, 180 - error_t) # ambiguity of E estimation 392 | error_R = angle_error_mat(R, R_gt) 393 | return error_t, error_R 394 | 395 | 396 | def pose_auc(errors, thresholds): 397 | sort_idx = np.argsort(errors) 398 | errors = np.array(errors.copy())[sort_idx] 399 | recall = (np.arange(len(errors)) + 1) / len(errors) 400 | errors = np.r_[0., errors] 401 | recall = np.r_[0., recall] 402 | aucs = [] 403 | for t in thresholds: 404 | last_index = np.searchsorted(errors, t) 405 | r = np.r_[recall[:last_index], recall[last_index-1]] 406 | e = np.r_[errors[:last_index], t] 407 | aucs.append(np.trapz(r, x=e)/t) 408 | return aucs 409 | 410 | 411 | # --- VISUALIZATION --- 412 | 413 | 414 | def plot_image_pair(imgs, dpi=100, size=6, pad=.5): 415 | n = len(imgs) 416 | assert n == 2, 'number of images must be two' 417 | figsize = (size*n, size*3/4) if size is not None else None 418 | _, ax = plt.subplots(1, n, figsize=figsize, dpi=dpi) 419 | for i in range(n): 420 | ax[i].imshow(imgs[i], cmap=plt.get_cmap('gray'), vmin=0, vmax=255) 421 | ax[i].get_yaxis().set_ticks([]) 422 | ax[i].get_xaxis().set_ticks([]) 423 | for spine in ax[i].spines.values(): # remove frame 424 | spine.set_visible(False) 425 | plt.tight_layout(pad=pad) 426 | 427 | 428 | def plot_keypoints(kpts0, kpts1, color='w', ps=2): 429 | ax = plt.gcf().axes 430 | ax[0].scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps) 431 | ax[1].scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps) 432 | 433 | 434 | def plot_matches(kpts0, kpts1, color, lw=1.5, ps=4): 435 | fig = plt.gcf() 436 | ax = fig.axes 437 | fig.canvas.draw() 438 | 439 | transFigure = fig.transFigure.inverted() 440 | fkpts0 = transFigure.transform(ax[0].transData.transform(kpts0)) 441 | fkpts1 = transFigure.transform(ax[1].transData.transform(kpts1)) 442 | 443 | fig.lines = [matplotlib.lines.Line2D( 444 | (fkpts0[i, 0], fkpts1[i, 0]), (fkpts0[i, 1], fkpts1[i, 1]), zorder=1, 445 | transform=fig.transFigure, c=color[i], linewidth=lw) 446 | for i in range(len(kpts0))] 447 | ax[0].scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps) 448 | ax[1].scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps) 449 | 450 | 451 | def make_matching_plot(image0, image1, kpts0, kpts1, mkpts0, mkpts1, 452 | color, text, path, show_keypoints=False, 453 | fast_viz=False, opencv_display=False, 454 | opencv_title='matches', small_text=[]): 455 | 456 | if fast_viz: 457 | make_matching_plot_fast(image0, image1, kpts0, kpts1, mkpts0, mkpts1, 458 | color, text, path, show_keypoints, 10, 459 | opencv_display, opencv_title, small_text) 460 | return 461 | 462 | plot_image_pair([image0, 
image1]) 463 | if show_keypoints: 464 | plot_keypoints(kpts0, kpts1, color='k', ps=4) 465 | plot_keypoints(kpts0, kpts1, color='w', ps=2) 466 | plot_matches(mkpts0, mkpts1, color) 467 | 468 | fig = plt.gcf() 469 | txt_color = 'k' if image0[:100, :150].mean() > 200 else 'w' 470 | fig.text( 471 | 0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes, 472 | fontsize=15, va='top', ha='left', color=txt_color) 473 | 474 | txt_color = 'k' if image0[-100:, :150].mean() > 200 else 'w' 475 | fig.text( 476 | 0.01, 0.01, '\n'.join(small_text), transform=fig.axes[0].transAxes, 477 | fontsize=5, va='bottom', ha='left', color=txt_color) 478 | 479 | plt.savefig(str(path), bbox_inches='tight', pad_inches=0) 480 | plt.close() 481 | 482 | 483 | def make_matching_plot_fast(image0, image1, kpts0, kpts1, mkpts0, 484 | mkpts1, color, text, path=None, 485 | show_keypoints=False, margin=10, 486 | opencv_display=False, opencv_title='', 487 | small_text=[]): 488 | H0, W0 = image0.shape 489 | H1, W1 = image1.shape 490 | H, W = max(H0, H1), W0 + W1 + margin 491 | 492 | out = 255*np.ones((H, W), np.uint8) 493 | out[:H0, :W0] = image0 494 | out[:H1, W0+margin:] = image1 495 | out = np.stack([out]*3, -1) 496 | 497 | if show_keypoints: 498 | kpts0, kpts1 = np.round(kpts0).astype(int), np.round(kpts1).astype(int) 499 | white = (255, 255, 255) 500 | black = (0, 0, 0) 501 | for x, y in kpts0: 502 | cv2.circle(out, (x, y), 2, black, -1, lineType=cv2.LINE_AA) 503 | cv2.circle(out, (x, y), 1, white, -1, lineType=cv2.LINE_AA) 504 | for x, y in kpts1: 505 | cv2.circle(out, (x + margin + W0, y), 2, black, -1, 506 | lineType=cv2.LINE_AA) 507 | cv2.circle(out, (x + margin + W0, y), 1, white, -1, 508 | lineType=cv2.LINE_AA) 509 | 510 | mkpts0, mkpts1 = np.round(mkpts0).astype(int), np.round(mkpts1).astype(int) 511 | color = (np.array(color[:, :3])*255).astype(int)[:, ::-1] 512 | for (x0, y0), (x1, y1), c in zip(mkpts0, mkpts1, color): 513 | c = c.tolist() 514 | cv2.line(out, (x0, y0), (x1 + margin + W0, y1), 515 | color=c, thickness=1, lineType=cv2.LINE_AA) 516 | # display line end-points as circles 517 | cv2.circle(out, (x0, y0), 2, c, -1, lineType=cv2.LINE_AA) 518 | cv2.circle(out, (x1 + margin + W0, y1), 2, c, -1, 519 | lineType=cv2.LINE_AA) 520 | 521 | # Scale factor for consistent visualization across scales. 522 | sc = min(H / 640., 2.0) 523 | 524 | # Big text. 525 | Ht = int(30 * sc) # text height 526 | txt_color_fg = (255, 255, 255) 527 | txt_color_bg = (0, 0, 0) 528 | for i, t in enumerate(text): 529 | cv2.putText(out, t, (int(8*sc), Ht*(i+1)), cv2.FONT_HERSHEY_DUPLEX, 530 | 1.0*sc, txt_color_bg, 2, cv2.LINE_AA) 531 | cv2.putText(out, t, (int(8*sc), Ht*(i+1)), cv2.FONT_HERSHEY_DUPLEX, 532 | 1.0*sc, txt_color_fg, 1, cv2.LINE_AA) 533 | 534 | # Small text. 
535 | Ht = int(18 * sc) # text height 536 | for i, t in enumerate(reversed(small_text)): 537 | cv2.putText(out, t, (int(8*sc), int(H-Ht*(i+.6))), cv2.FONT_HERSHEY_DUPLEX, 538 | 0.5*sc, txt_color_bg, 2, cv2.LINE_AA) 539 | cv2.putText(out, t, (int(8*sc), int(H-Ht*(i+.6))), cv2.FONT_HERSHEY_DUPLEX, 540 | 0.5*sc, txt_color_fg, 1, cv2.LINE_AA) 541 | 542 | if path is not None: 543 | cv2.imwrite(str(path), out) 544 | 545 | if opencv_display: 546 | cv2.imshow(opencv_title, out) 547 | cv2.waitKey(1) 548 | 549 | return out 550 | 551 | 552 | def error_colormap(x): 553 | return np.clip( 554 | np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)], -1), 0, 1) 555 | -------------------------------------------------------------------------------- /task1/task1_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import argparse 3 | import random 4 | import numpy as np 5 | import matplotlib.cm as cm 6 | import torch 7 | import os 8 | import math 9 | import cv2 10 | from glob import glob 11 | 12 | import easyocr 13 | from task1.superglue.superpoint import SuperPoint 14 | from task1.superglue.superglue import SuperGlue 15 | from task1.superglue.utils import (compute_pose_error, compute_epipolar_error, 16 | estimate_pose, make_matching_plot, 17 | error_colormap, pose_auc, read_image, 18 | rotate_intrinsics, rotate_pose_inplane, 19 | scale_intrinsics) 20 | 21 | class MatchImageSizeTo(object): 22 | def __init__(self, size=1080): 23 | self.size=size 24 | 25 | def __call__(self, img): 26 | H, W = img.shape 27 | 28 | if H>=W: 29 | W_size = int(W/H * self.size * (1920/1450)) 30 | # W_size = int(W/H * self.size) 31 | img_new = cv2.resize(img, (W_size, self.size)) 32 | else: 33 | H_size = int(H/W * self.size * (1450/1920)) 34 | # H_size = int(H/W * self.size) 35 | img_new = cv2.resize(img, (self.size, H_size)) 36 | 37 | return img_new 38 | 39 | def ocr(frames, frame_idx_start, masked_frame_idx, texts, text_idx): 40 | reader = easyocr.Reader(['ko'], gpu=True) 41 | frame_num = len(frames) 42 | results = reader.readtext_batched(frames, 43 | batch_size=frame_num, 44 | output_format='dict', 45 | blocklist='!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZㆍ', 46 | min_size = 5, 47 | text_threshold=0.6) 48 | for i,result in enumerate(results): 49 | res = [] 50 | for res_ in result: 51 | word = res_['text'] 52 | if len(word) <= 2: 53 | continue 54 | if (word.endswith('실') or word.endswith('과')) and word[:3].isdigit(): 55 | res.append({'text': word, 'confident':res_['confident']}) 56 | continue 57 | if len(res) == 0: 58 | continue 59 | text = max(res, key=lambda x: x['confident'])['text'] 60 | texts.append(text) 61 | text_idx.append(masked_frame_idx[frame_idx_start+i]) 62 | 63 | def match_pairs(vid_, imgs, vid_batch, device, 64 | match_num_rate_threshold=0.02, 65 | superglue='indoor', 66 | max_keypoints = 1024, 67 | keypoint_threshold = 0.0, 68 | nms_radius = 4, 69 | sinkhorn_iterations = 15, 70 | match_threshold = 0.2): 71 | """ 72 | Args: 73 | vid_: list of numpy vid frames, range 0~255, shape H x W x 3 , BGR 74 | imgs: list of numpy images, range 0~255, shape H x W , Grayscale 75 | vid_batch: batch size for video 76 | Return: 77 | result: list of tuples (frame idx, match_rate) 78 | """ 79 | 80 | vid = [cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) for frame in vid_] 81 | 82 | torch.set_grad_enabled(False) 83 | 84 | config = { 85 | 'superpoint': { 86 | 'nms_radius': nms_radius, 87 | 'keypoint_threshold': 
keypoint_threshold, 88 | 'max_keypoints': max_keypoints 89 | }, 90 | 'superglue': { 91 | 'weights': superglue, 92 | 'sinkhorn_iterations': sinkhorn_iterations, 93 | 'match_threshold': match_threshold, 94 | } 95 | } 96 | 97 | superpoint = SuperPoint(config.get('superpoint', {})).eval().to(device) 98 | superglue = SuperGlue(config.get('superglue', {})).eval().to(device) 99 | 100 | T = len(vid) 101 | N = len(imgs) 102 | 103 | imgs = [torch.from_numpy(imgs[i]/255.).float()[None,None] for i in range(N)] # (1,1,H,W) 104 | imgs_kp = [] 105 | match_num_threshold = [] 106 | 107 | for img in imgs: 108 | img = img.to(device) 109 | kp = superpoint({'image': img}) # 'keypoints', 'scores', 'descriptors' 110 | kp = {**{k+'0': v for k, v in kp.items()}} 111 | for k in kp: 112 | if isinstance(kp[k], (list,tuple)): 113 | kp[k] = torch.stack(kp[k]) # (1,K,2), (1,K), (1,D,K) 114 | imgs_kp.append(kp) 115 | match_num_threshold.append(int(kp['keypoints0'].shape[1]*match_num_rate_threshold)) 116 | 117 | result = [[-1,0] for i in range(N)] 118 | vid_size = vid[0].shape[-2:] 119 | Iters = math.ceil(T/vid_batch) 120 | start = 0 121 | 122 | from tqdm import tqdm 123 | with tqdm(total=Iters) as pbar: 124 | for i in range(Iters): 125 | start = i * vid_batch 126 | if i == Iters-1: 127 | end = T 128 | else: 129 | end = (i+1) * vid_batch 130 | frames = [torch.from_numpy(vid[i]/255.).float()[None] for i in range(start,end)] #(1,H,W) 131 | frames = torch.stack(frames).to(device) # (B,1,H,W) 132 | vid_kp = superpoint({'image':frames}) 133 | vid_kp = {**{k+'1': v for k, v in vid_kp.items()}} 134 | for k in vid_kp: 135 | if isinstance(vid_kp[k], (list,tuple)): 136 | vid_kp[k] = torch.stack(vid_kp[k]) # (B,K,2), (B,K), (B,D,K) 137 | 138 | for n, img_kp_ in enumerate(imgs_kp): 139 | img_size = imgs[n].shape[-2:] 140 | img_kp = {} 141 | for k in img_kp_: 142 | if len(img_kp_[k].shape)==2: 143 | img_kp[k] = img_kp_[k].repeat((end-start),1) # (B,K,2), (B,K), (B,D,K) 144 | else: 145 | img_kp[k] = img_kp_[k].repeat((end-start),1,1) # (B,K,2), (B,K), (B,D,K) 146 | 147 | data = {**vid_kp, **img_kp, 'image0_shape': img_size, 'image1_shape': vid_size} 148 | pred = superglue(data) # matches0, matches1, matching_scores0, matching_scores1 149 | pred = {k:v.cpu().numpy() for k,v in pred.items()} # all (B,~1024) 150 | match_num = np.sum(pred['matches0']>-1, axis=1) # (B,) 151 | max_idx = np.argmax(match_num) 152 | 153 | if match_num[max_idx] < match_num_threshold[n]: 154 | continue 155 | elif match_num[max_idx] < result[n][1]: 156 | continue 157 | else: 158 | result[n][1] = match_num[max_idx] 159 | result[n][0] = start + max_idx 160 | pbar.update(1) 161 | 162 | return result -------------------------------------------------------------------------------- /task2/final_2nd/__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlvlab/drone_ai_challenge/391b48f640eb84bfa01caa677c34b8b7c7c1eb7d/task2/final_2nd/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- /task2/final_2nd/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlvlab/drone_ai_challenge/391b48f640eb84bfa01caa677c34b8b7c7c1eb7d/task2/final_2nd/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /task2/final_2nd/lib/Task2.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | 4 | import os 5 | import sys 6 | import numpy as np 7 | import julius 8 | 9 | 10 | from task2.final_2nd.lib.marg_utils import * 11 | 12 | def task2_inference(data_path, set_nums): 13 | gpu = 0 14 | # post processing 15 | threshold = 0.2 16 | smooth = 13 17 | min_frame = 18 18 | 19 | module = __import__('task2.final_2nd.lib.marg_model', fromlist=['']) 20 | generator = module.model() 21 | generator.to('cuda:{}'.format(gpu)) 22 | generator.load_state_dict(torch.load('task2/final_2nd/weights/model.pt', map_location='cuda:{}'.format(gpu))) 23 | generator.eval() 24 | 25 | win = np.load('task2/final_2nd/lib/window.npy') 26 | win = torch.from_numpy(win).cuda(gpu) 27 | mel_fbank = np.load('task2/final_2nd/lib/mel_fbank.npy') 28 | mel_fbank = torch.from_numpy(mel_fbank).cuda(gpu) 29 | 30 | drone_num = 3 31 | answer_list = [] 32 | for i in range(set_nums): 33 | sub_answer_list = [] 34 | for j in range(drone_num): 35 | folder_name = 'set_0' + str(i+1) 36 | # file_name = 'set0' + str(i+1) + '_drone0' + str(j+1) + '_ch1.wav' 37 | # audio_48k, sr = torchaudio.load(os.path.join(data_path, folder_name, file_name)) 38 | 39 | audio_48k = 0 40 | for k in range(2): 41 | file_name = 'set0' + str(i+1) + '_drone0' + str(j+1) + '_ch' + str(k+1) + '.wav' 42 | audio_48ks, sr = torchaudio.load(os.path.join(data_path, folder_name, file_name)) 43 | audio_48k += audio_48ks 44 | audio_48k /= 2 45 | 46 | audio_16k = julius.resample_frac(audio_48k, sr, 16000).numpy()[0] 47 | with torch.no_grad(): 48 | m, f, b = inference(generator, audio_16k, win, mel_fbank, gpu=gpu) 49 | mf, ff, bf = postprocessing(m, f, b, threshold=threshold, smooth=smooth, min_frame=min_frame) 50 | sub_answer_list.append([list(mf), list(ff), list(bf)]) 51 | answer_list.append(sub_answer_list) 52 | out_str = new_answer_list_to_json(answer_list) 53 | # print(answer_list) 54 | # return answer_list, out_str 55 | return out_str 56 | 57 | # audio1, sr = librosa.load('/home/home/juheon/gc_2021/validation_data_new/set01_drone01_mono_16k.wav', sr=None, mono=True) 58 | # audio2, sr = librosa.load('/home/home/juheon/gc_2021/validation_data_new/set01_drone02_mono_16k.wav', sr=None, mono=True) 59 | # audio3, sr = librosa.load('/home/home/juheon/gc_2021/validation_data_new/set01_drone03_mono_16k.wav', sr=None, mono=True) 60 | 61 | # # model inference 62 | # with torch.no_grad(): 63 | # m1, f1, b1 = inference(generator, audio1, win, mel_fbank, gpu=gpu) 64 | # m2, f2, b2 = inference(generator, audio2, win, mel_fbank, gpu=gpu) 65 | # m3, f3, b3 = inference(generator, audio3, win, mel_fbank, gpu=gpu) 66 | 67 | 68 | 69 | # m1f, f1f, b1f = postprocessing(m1, f1, b1, threshold=threshold, smooth=smooth, min_frame=min_frame) 70 | # m2f, f2f, b2f = postprocessing(m2, f2, b2, threshold=threshold, smooth=smooth, min_frame=min_frame) 71 | # m3f, f3f, b3f = postprocessing(m3, f3, b3, threshold=threshold, smooth=smooth, min_frame=min_frame) 72 | 73 | # # write answer 74 | # answer_list = [ 75 | # [ 76 | # [list(m1f), list(f1f), list(b1f)], 77 | # [list(m2f), list(f2f), list(b2f)], 78 | # [list(m3f), list(f3f), list(b3f)], 79 | # ], 80 | # ] 81 | 82 | # # validation 83 | # gt_list = [ 84 | # [ 85 | # [[[187, 191], [194, 197]], [[199, 203]], [[51, 56], [202, 208]]], 86 | # [[[147, 151], [162, 165]], [[154, 158]], [[143, 148]]], 87 | # [[[197, 200], [204, 208]], [[198, 202]], [[210, 215]]], 88 | # ], 89 | # ] 90 | # s, d, i, er, correct = evaluation(gt_list, 
answer_list) 91 | 92 | 93 | -------------------------------------------------------------------------------- /task2/final_2nd/lib/__pycache__/Task2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlvlab/drone_ai_challenge/391b48f640eb84bfa01caa677c34b8b7c7c1eb7d/task2/final_2nd/lib/__pycache__/Task2.cpython-36.pyc -------------------------------------------------------------------------------- /task2/final_2nd/lib/__pycache__/Task2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlvlab/drone_ai_challenge/391b48f640eb84bfa01caa677c34b8b7c7c1eb7d/task2/final_2nd/lib/__pycache__/Task2.cpython-37.pyc -------------------------------------------------------------------------------- /task2/final_2nd/lib/__pycache__/marg_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlvlab/drone_ai_challenge/391b48f640eb84bfa01caa677c34b8b7c7c1eb7d/task2/final_2nd/lib/__pycache__/marg_model.cpython-36.pyc -------------------------------------------------------------------------------- /task2/final_2nd/lib/__pycache__/marg_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlvlab/drone_ai_challenge/391b48f640eb84bfa01caa677c34b8b7c7c1eb7d/task2/final_2nd/lib/__pycache__/marg_model.cpython-37.pyc -------------------------------------------------------------------------------- /task2/final_2nd/lib/__pycache__/marg_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlvlab/drone_ai_challenge/391b48f640eb84bfa01caa677c34b8b7c7c1eb7d/task2/final_2nd/lib/__pycache__/marg_utils.cpython-36.pyc -------------------------------------------------------------------------------- /task2/final_2nd/lib/__pycache__/marg_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlvlab/drone_ai_challenge/391b48f640eb84bfa01caa677c34b8b7c7c1eb7d/task2/final_2nd/lib/__pycache__/marg_utils.cpython-37.pyc -------------------------------------------------------------------------------- /task2/final_2nd/lib/marg_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from conformer import ConformerBlock 5 | 6 | 7 | 8 | class Enhancer(nn.Module): 9 | def __init__(self, dim=256, n_blocks=15, cf=5120): 10 | super(Enhancer, self).__init__() 11 | self.dim = dim 12 | self.cf = cf # prenet output c * f (64 * 51) 13 | self.n_blocks = n_blocks 14 | self.prenet = nn.Sequential( 15 | nn.Conv2d(1, self.dim//8, kernel_size=3, stride=1, padding=1), 16 | nn.ReLU(), 17 | nn.Conv2d(self.dim//8, self.dim//8, kernel_size=3, stride=1, padding=1), 18 | nn.ReLU(), 19 | nn.Conv2d(self.dim//8, self.dim//4, kernel_size=3, stride=1, padding=1), 20 | nn.ReLU(), 21 | ) 22 | self.conformer_block = nn.ModuleList() 23 | for i in range(self.n_blocks): 24 | self.conformer_block.append(ConformerBlock(dim=self.dim, dim_head=64, heads=4, ff_mult=4, conv_expansion_factor=2, conv_kernel_size=9, attn_dropout=0.1)) 25 | self.proj_in = nn.Conv1d(self.cf, self.dim, 1, 1) 26 | self.proj_out = nn.Conv1d(self.dim, 80, 1, 1) 27 | def forward(self, x): 28 | # x [b, f, t] 29 
| x = self.prenet(x) 30 | b, c, f, t = x.size() 31 | x = x.permute(0,3,1,2) # [b, t, c, f] 32 | x = x.contiguous().view(b, t, c*f) # [b, t, f] 33 | x = x.permute(0,2,1) # [b, f, t] 34 | x = self.proj_in(x) 35 | x = x.permute(0,2,1) 36 | for layer in self.conformer_block: 37 | x = layer(x) 38 | x = x.permute(0,2,1) 39 | mask = self.proj_out(x) 40 | return mask.unsqueeze(1) 41 | 42 | 43 | class Classifier(nn.Module): 44 | def __init__(self, dim=256, n_blocks=15, cf=640): 45 | super(Classifier, self).__init__() 46 | self.dim = dim 47 | self.cf = cf # prenet output c * f (64 * 51) 48 | self.n_blocks = n_blocks 49 | self.prenet = nn.Sequential( 50 | nn.Conv2d(2, self.dim//8, kernel_size=3, stride=2, padding=1), 51 | nn.ReLU(), 52 | nn.Conv2d(self.dim//8, self.dim//8, kernel_size=3, stride=2, padding=1), 53 | nn.ReLU(), 54 | nn.Conv2d(self.dim//8, self.dim//4, kernel_size=3, stride=2, padding=1), 55 | nn.ReLU(), 56 | ) 57 | self.conformer_block = nn.ModuleList() 58 | for i in range(self.n_blocks): 59 | self.conformer_block.append(ConformerBlock(dim=self.dim, dim_head=64, heads=4, ff_mult=4, conv_expansion_factor=2, conv_kernel_size=9, attn_dropout=0.1)) 60 | self.proj_in = nn.Conv1d(self.cf, self.dim, 1, 1) 61 | self.proj_out = nn.Conv1d(self.dim, 3, 1, 1) 62 | def forward(self, x): 63 | # x [b, f, t] 64 | x = self.prenet(x) 65 | b, c, f, t = x.size() 66 | x = x.permute(0,3,1,2) # [b, t, c, f] 67 | x = x.contiguous().view(b, t, c*f) # [b, t, f] 68 | x = x.permute(0,2,1) # [b, f, t] 69 | x = self.proj_in(x) 70 | x = x.permute(0,2,1) 71 | for layer in self.conformer_block: 72 | x = layer(x) 73 | x = x.permute(0,2,1) 74 | logits = self.proj_out(x) 75 | return logits 76 | 77 | 78 | 79 | 80 | 81 | 82 | class model(nn.Module): 83 | def __init__(self): 84 | super(model, self).__init__() 85 | self.enhancer = Enhancer() 86 | self.classifier = Classifier() 87 | 88 | def forward(self, x): 89 | x_log = torch.log(x.clamp(min=1e-5)) 90 | mask = self.enhancer(x_log) 91 | # mask = torch.sigmoid(mask) 92 | logit = self.classifier(torch.cat((x,mask),axis=1)) 93 | return logit, mask 94 | 95 | -------------------------------------------------------------------------------- /task2/final_2nd/lib/marg_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import json 3 | import numpy as np 4 | 5 | def convert_intseclist_to_strseclist(int_sec_result): 6 | str_sec_list = [] 7 | for int_sec in int_sec_result: 8 | if int_sec == None: 9 | strsec = 'NONE' 10 | str_sec_list.append(strsec) 11 | else: 12 | min_str = str(int(int_sec // 60)) 13 | sec_str = str(int(int_sec % 60)) 14 | strsec = min_str.zfill(2) + ':' + sec_str.zfill(2) 15 | str_sec_list.append(strsec) 16 | 17 | return str_sec_list 18 | 19 | def convert_strseclist_to_intseclist(str_sec_result): 20 | int_sec_list = [] 21 | for str_sec in str_sec_result: 22 | if str_sec == 'NONE': 23 | intsec = None 24 | int_sec_list.append(intsec) 25 | else: 26 | # min_str = str(int(int_sec // 60)) 27 | # sec_str = str(int(int_sec % 60)) 28 | # strsec = min_str.zfill(2) + ':' + sec_str.zfill(2) 29 | # str_sec_list.append(strsec) 30 | int_sec_min = int(str_sec.split(':')[0]) 31 | int_sec_sec = int(str_sec.split(':')[1]) 32 | int_sec = int_sec_min * 60 + int_sec_sec 33 | int_sec_list.append(int_sec) 34 | return int_sec_list 35 | 36 | def organize_result_data(result_data): 37 | set = 'set_'+str(result_data['set_num']) 38 | structured_data = {set:[]} 39 | 40 | for drone_num, inference_list in enumerate(result_data['result']): 41 | 
drone_name = 'drone_'+str(drone_num+1) 42 | drone_result_dic = {drone_name:None} 43 | class_result_dic = {} 44 | for class_index, class_result in enumerate(inference_list): 45 | class_key_list = ['M', 'W', 'C'] 46 | class_key = class_key_list[class_index] 47 | str_sec_list = convert_intseclist_to_strseclist(class_result) 48 | class_result_dic.update({class_key:str_sec_list}) 49 | drone_result_dic[drone_name]=[class_result_dic] 50 | structured_data[set].append(drone_result_dic) 51 | 52 | structured_data 53 | return structured_data 54 | 55 | def new_answer_list_to_json(list): 56 | final_output_dic = {"task2_answer":[{}]} 57 | for set_index, set in enumerate(list): 58 | temp_dict = {'set_num': set_index+1, 'result':set} 59 | set_result = organize_result_data(temp_dict) 60 | final_output_dic['task2_answer'][0].update(set_result) 61 | return final_output_dic 62 | 63 | def output_dict_to_list(task2_result): 64 | drone_key_list = ['drone_1', 'drone_2', 'drone_3'] 65 | class_key_list = ['M', 'W', 'C'] 66 | answer_list = [] 67 | json_list = task2_result['task2_answer'][0] 68 | for i, set_num in enumerate(json_list): 69 | set_list = json_list[set_num] 70 | set_answer = [] 71 | for j, drone_list in enumerate(set_list): 72 | class_list = drone_list[drone_key_list[j]] 73 | drone_answer = [] 74 | for k, class_key in enumerate(class_key_list): 75 | time_list = class_list[0][class_key] 76 | time_answer = convert_strseclist_to_intseclist(time_list) 77 | drone_answer.append(time_answer) 78 | set_answer.append(drone_answer) 79 | answer_list.append(set_answer) 80 | return answer_list 81 | 82 | def write_json(result_dic, path='task2.json'): 83 | with open(path, 'w') as outfile: 84 | json.dump(result_dic, outfile, indent=2) 85 | 86 | def stft(wave, win, mel_fbank): 87 | stft_cplx = torch.stft(wave[:,:-1], 800, hop_length=200, win_length=800, window=win, center=True, pad_mode='reflect') 88 | stft_mag = torch.sqrt(stft_cplx[...,0:1]**2 + stft_cplx[...,1:2]**2)[...,0] 89 | mel_mag = torch.matmul(mel_fbank, stft_mag) 90 | return mel_mag 91 | 92 | def moving_average(a, n=3) : 93 | ret = np.cumsum(a, dtype=float) 94 | ret[n:] = ret[n:] - ret[:-n] 95 | return ret[n - 1:] / n 96 | 97 | def consecutive(data, stepsize=1): 98 | return np.split(data, np.where(np.diff(data) != stepsize)[0]+1) 99 | 100 | def consecutive_merge(data, threshold=10): 101 | merge = [] 102 | if len(data) == 1 and len(data[0])==0: 103 | return np.asarray(merge) 104 | else: 105 | temp = [data[0][0], data[0][-1]] 106 | for interval in data[1:]: 107 | if temp[-1] + threshold > interval[0]: 108 | temp[1] = interval[-1] 109 | else: 110 | merge.append(temp) 111 | temp = [interval[0], interval[-1]] 112 | merge.append(temp) 113 | return np.asarray(merge) 114 | 115 | def seg_to_answer(segment, prob, min_frame=3): 116 | answer = [] 117 | frame = np.zeros_like(prob) 118 | for item in segment: 119 | start, end = item 120 | if end-start > min_frame: 121 | weighted_center = (round)((prob[start:end] * np.arange(start, end)).sum() / prob[start:end].sum() / 10) 122 | answer.append(weighted_center) 123 | frame[start:end] = 2 124 | frame[(int)(weighted_center*10)-3:(int)(weighted_center*10)+3] = 1 125 | if len(answer) == 0: 126 | answer = [None] 127 | return np.asarray(answer), frame 128 | 129 | def inference(model, audio, win, mel_fbank, window=320, hop=160, gpu=0): 130 | num_overlap = (window-hop)//8 131 | audio = audio/(np.max(np.abs(audio))+1e-5) 132 | overlap = torch.linspace(0, 1, num_overlap).unsqueeze(0).unsqueeze(0).repeat(1,3,1).cuda(gpu) 133 | wave = 
torch.from_numpy(audio).unsqueeze(0).cuda(gpu) 134 | spec = stft(wave, win, mel_fbank).unsqueeze(1) 135 | length = spec.size()[-1] 136 | pad = ((length // hop + 1) * hop) - length 137 | spec = torch.cat((spec, torch.zeros(1, 1, 80, pad).cuda(gpu)), axis=-1) 138 | num_iters = (length-window)//hop + 2 139 | temp_spec = [spec[:, :, :, i*hop:i*hop+window] for i in range(num_iters)] 140 | temp_spec = torch.cat(temp_spec, axis=0) 141 | with torch.no_grad(): 142 | logits, recon = model(temp_spec) 143 | logits_sig_total = torch.sigmoid(logits) 144 | logits_seq = None 145 | for i in range(num_iters): 146 | logits_sig = logits_sig_total[i:i+1, :, :] 147 | if logits_seq is None: 148 | logits_seq = logits_sig 149 | else: 150 | logits_seq[:,:,-num_overlap:] = (1-overlap) * logits_seq[:,:,-num_overlap:] + overlap * logits_sig[:,:,:num_overlap] 151 | logits_seq = torch.cat((logits_seq, logits_sig[:,:,num_overlap:]), axis=-1) 152 | logits_seq = logits_seq[:,:,:(wave.size()[-1]//1600)] 153 | logits_seq = logits_seq.detach().cpu().numpy()[0] 154 | m, f, b = logits_seq 155 | return m, f, b 156 | 157 | def postprocessing(m, f, b, threshold=0.8, smooth=5, min_frame=3, merge_frame=10): 158 | # smoothing 159 | m_smooth = moving_average(np.pad(m, (smooth//2,smooth//2), mode='edge'), n=smooth) 160 | f_smooth = moving_average(np.pad(f, (smooth//2,smooth//2), mode='edge'), n=smooth) 161 | b_smooth = moving_average(np.pad(b, (smooth//2,smooth//2), mode='edge'), n=smooth) 162 | # threshold 163 | m_threshold = np.asarray(m_smooth>threshold, dtype='int') 164 | f_threshold = np.asarray(f_smooth>threshold, dtype='int') 165 | b_threshold = np.asarray(b_smooth>threshold, dtype='int') 166 | # index 167 | m_index = np.where(m_threshold==1)[0] 168 | f_index = np.where(f_threshold==1)[0] 169 | b_index = np.where(b_threshold==1)[0] 170 | # consecutive segment 171 | m_consecutive = consecutive(m_index) 172 | f_consecutive = consecutive(f_index) 173 | b_consecutive = consecutive(b_index) 174 | # merge if two segment close 175 | m_segment = consecutive_merge(m_consecutive, threshold=merge_frame) 176 | f_segment = consecutive_merge(f_consecutive, threshold=merge_frame) 177 | b_segment = consecutive_merge(b_consecutive, threshold=merge_frame) 178 | # middle point 179 | m_answer, m_frame = seg_to_answer(m_segment, m, min_frame) 180 | f_answer, f_frame = seg_to_answer(f_segment, f, min_frame) 181 | b_answer, b_frame = seg_to_answer(b_segment, b, min_frame) 182 | # plot 183 | m_answer = np.asarray(m_answer) 184 | f_answer = np.asarray(f_answer) 185 | b_answer = np.asarray(b_answer) 186 | return m_answer, f_answer, b_answer 187 | 188 | def cw_metrics_cal(gt, pred): 189 | correct = 0 190 | deletion = 0 191 | insertion = 0 192 | include_list = [] 193 | for ii in range(len(gt)): 194 | is_include = 0 195 | if gt[ii][0] == None: 196 | if pred[0] == None: 197 | correct += 1 198 | else: 199 | if pred[0] == None: 200 | deletion += 1 201 | else: 202 | gt_start, gt_end = gt[ii][0], gt[ii][1] 203 | for jj in range(len(pred)): 204 | if gt_start <= pred[jj] and pred[jj] <= gt_end: 205 | include_list.append(pred[jj]) 206 | is_include += 1 207 | if is_include == 0: 208 | deletion += 1 209 | elif is_include > 1: 210 | insertion = is_include - 1 211 | elif is_include == 1: 212 | correct += 1 213 | if pred[0] == None: 214 | substitution = 0 215 | else: 216 | substitution = len(pred) - len(list(set(include_list))) 217 | return substitution, deletion, insertion, correct 218 | 219 | def evaluation(answer_list, set_nums): 220 | gt_list = [] 221 | gt =[ 222 | 
[[[187, 191], [194, 197]], [[199, 203]], [[51, 56], [202, 208]]], 223 | [[[147, 151], [162, 165]], [[154, 158]], [[143, 148]]], 224 | [[[197, 200], [204, 208]], [[198, 202]], [[210, 215]]], 225 | ] 226 | for set_num in range(set_nums): 227 | gt_list.append(gt) 228 | set_num = len(gt_list) 229 | total_s, total_d, total_i, total_n, total_correct = 0, 0, 0, 0, 0 230 | for i in range(set_num): 231 | sw_s, sw_d, sw_i, sw_n, sw_correct = 0, 0, 0, 0, 0 232 | for j in range(3): # for drones 1 ~ 3 233 | dw_s, dw_d, dw_i, dw_n, dw_correct = 0, 0, 0, 0, 0 234 | for k in range(3): # for class man, woman, child 235 | cw_s, cw_d, cw_i, cw_correct = cw_metrics_cal(gt_list[i][j][k], answer_list[i][j][k]) 236 | dw_s += cw_s 237 | dw_d += cw_d 238 | dw_i += cw_i 239 | dw_n += len(gt_list[i][j][k]) 240 | dw_correct += cw_correct 241 | dw_er = (dw_s + dw_d + dw_i) / dw_n 242 | # print('Set', str(i), 'Drone', str(j), 's, d, i, er, correct:', dw_s, dw_d, dw_i, np.round(dw_er, 2), dw_correct) 243 | sw_s += dw_s 244 | sw_d += dw_d 245 | sw_i += dw_i 246 | sw_n += dw_n 247 | sw_er = (sw_s + sw_d + sw_i) / sw_n 248 | sw_correct += dw_correct 249 | total_s += sw_s 250 | total_d += sw_d 251 | total_i += sw_i 252 | total_n += sw_n 253 | total_er = (total_s + total_d + total_i) / total_n 254 | total_correct += sw_correct 255 | # print('Subtotal Set', str(i), 's, d, i, er, correct:', sw_s, sw_d, sw_i, np.round(sw_er, 2), sw_correct) 256 | print('Total', 's, d, i, er, correct:', total_s, total_d, total_i, np.round(total_er, 2), total_correct) 257 | return total_s, total_d, total_i, total_er, total_correct -------------------------------------------------------------------------------- /task2/final_2nd/lib/mel_fbank.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlvlab/drone_ai_challenge/391b48f640eb84bfa01caa677c34b8b7c7c1eb7d/task2/final_2nd/lib/mel_fbank.npy -------------------------------------------------------------------------------- /task2/final_2nd/lib/window.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlvlab/drone_ai_challenge/391b48f640eb84bfa01caa677c34b8b7c7c1eb7d/task2/final_2nd/lib/window.npy -------------------------------------------------------------------------------- /task2/final_2nd/main.py: -------------------------------------------------------------------------------- 1 | from task2.lib.Task2 import * 2 | 3 | # data_path = '/home/agc2021/dataset/' 4 | data_path = '../temp_data' 5 | 6 | # Task 2 7 | set_nums = 5 8 | json_str = task2_inference(data_path, set_nums) 9 | 10 | ''' Below code is for sanity check!!''' 11 | ''' Expected print for set_nums = 5 ''' 12 | ''' Total s, d, i, er, correct: 15 15 0 0.46 50 ''' 13 | # from lib.marg_utils import * 14 | # import json 15 | # json_dict = json.loads(json_str) 16 | # answer_list = json_to_answer(json_dict, set_nums) 17 | # evaluation(answer_list, set_nums) -------------------------------------------------------------------------------- /task3/conf/task3.yaml: -------------------------------------------------------------------------------- 1 | yolo_weights: task3/lib/yolov5/weights/yolov5m.pt 2 | deep_sort_weights: task3/lib/deep_sort/deep_sort/deep/checkpoint/ckpt.t7 3 | 4 | img_size: 640 5 | conf_thres: 0.5 6 | iou_thres: 0.6 7 | fourcc: mp4v 8 | 9 | show_vid: False 10 | save_vid: True 11 | save_txt: False 12 | save_jpg: True 13 | frame_skip : 2 14 | 15 | classes: 0 16 | agnostic_nms: True 17 | augment: True 
18 | 19 | config_deepsort: task3/lib/deep_sort/configs/deep_sort.yaml 20 | half: True 21 | 22 | reid_thresh: 9.5 23 | ff_rm_thresh: 5.8 24 | -------------------------------------------------------------------------------- /task4/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | import time 5 | #task3 6 | import task3.lib 7 | import task3.lib.predict as pred 8 | 9 | #task2 10 | from task2.final_1st.lib.Task2 import * 11 | 12 | #task4 13 | from task4.main_inference import * 14 | 15 | #task1 16 | from task1.run import * 17 | 18 | 19 | def make_final_json(task1_answer,task2_answer, task3_answer,task4_answer): 20 | final_json = dict() 21 | final_json["task1_answer"] = task1_answer 22 | final_json["task2_answer"] = task2_answer 23 | final_json["task3_answer"] = task3_answer 24 | final_json["task4_answer"] = task4_answer 25 | 26 | with open(json_path, 'w', encoding='utf-8') as make_file: 27 | json.dump(final_json, make_file,ensure_ascii=False,indent=3 ) 28 | 29 | def save_task3_answer(set_num, pred_data): 30 | 31 | data["task3_answer"].append({f"{set_num}": pred_data}) 32 | 33 | return data 34 | 35 | 36 | 37 | def main(args): 38 | start = time.time() 39 | if os.path.exists(args.json_path) == True: 40 | os.remove(args.json_path) 41 | 42 | data ={} 43 | data["task3_answer"] = [] 44 | 45 | #n_set = 5 if args.set_num == "all_sets" else 1 46 | #print(args.set_num) 47 | # task_2 - already implemented for all 5 sets 48 | # change data_path to the competition dataset path later 49 | data_path = '/home/shinyeong/final_dataset/' 50 | #~~~~~~~~~ task_2 ~~~~~~~~~~~ 51 | 52 | start_task2 = time.time() 53 | 54 | task2_answer = task2_inference(data_path,5) 55 | print("TASK2 : ",time.time()-start_task2) 56 | print(task2_answer) 57 | 58 | ''' 59 | #~~~~~~~~~ task_4 ~~~~~~~~~~~ 60 | #data_path = '/home/shinyeong/final_dataset/' 61 | start_task4 = time.time() 62 | task4_answer = task4_main(data_path) 63 | print("TASK4 : ",time.time() - start_task4) 64 | print(task4_answer) 65 | 66 | 67 | ##############the dataset path must be given as an absolute path!!!!!################ 68 | task1_answer = [{ 69 | "set_1": [], 70 | "set_2": [], 71 | "set_3": [], 72 | "set_4": [], 73 | "set_5": [] 74 | }] 75 | for i in range(5): #n_set : number of set 76 | args.set_num = f"set_0" + str(i+1) 77 | #print("count : ", args.set_num) 78 | set_num = f"set_0" + str(i+1) 79 | #~~~~~~~~~ task_1 ~~~~~~~~~~~ 80 | 81 | task1_video_path = data_path + set_num #final_dataset/set_01 82 | task1_img_path = data_path + set_num 83 | task1_frame_skip = 15 84 | task1_main(task1_video_path, task1_img_path, task1_frame_skip, task1_answer) 85 | 86 | 87 | #~~~~~~~~~ task_3 ~~~~~~~~~~~ 88 | 89 | t3_data = [] 90 | t3 = pred.func_task3(args) 91 | t3_res_pred_move, t3_res_pred_stay, t3_res_pred_total = t3.run() 92 | t3_data.append(t3_res_pred_move) 93 | t3_data.append(t3_res_pred_stay) 94 | t3_data.append(t3_res_pred_total) 95 | task3_answer = save_task3_answer(set_num,t3_data) 96 | 97 | #~~~~~~~~~ task_4 ~~~~~~~~~~~ 98 | 99 | 100 | 101 | 102 | make_final_json(task1_answer,task2_answer, task3_answer,task4_answer) 103 | print("TOTAL INFERENCE TIME : ", time.time()-start) 104 | ''' 105 | 106 | if __name__ == '__main__': 107 | p=argparse.ArgumentParser() 108 | # path 109 | # p.add_argument("--dataset_dir", type=str, default="/home/agc2021/dataset") # /set_01, /set_02, /set_03, /set_04, /set_05 110 | # p.add_argument("--root_dir", type=str, default="/home/[Team_ID]") 111 | # p.add_argument("--temporary_dir", type=str, default="/home/agc2021/temp")
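# The three commented-out defaults above appear to be the challenge-server paths
# (/home/agc2021/...); the active defaults below point to a local development copy of
# the dataset. make_final_json() at the top of this file merges the per-task results
# into a single answer sheet of the form {"task1_answer": ..., "task2_answer": ...,
# "task3_answer": ..., "task4_answer": ...} and writes it to json_path.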
default="/home/agc2021/temp") 112 | ### 113 | json_path = "answersheet_3_00_Rony2.json" 114 | p.add_argument("--dataset_dir", type=str, default="/home/shinyeong/final_dataset") # /set_01, /set_02, /set_03, /set_04, /set_05 115 | p.add_argument("--root_dir", type=str, default="./") 116 | p.add_argument("--temporary_dir", type=str, default="../output3") 117 | 118 | ### 119 | p.add_argument("--json_path", type=str, default="answersheet_3_00_Rony2.json") 120 | p.add_argument("--task_num", type=str, default="task3_answer") 121 | p.add_argument("--set_num", type=str, default="all_set") 122 | p.add_argument("--device", default='0', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') 123 | p.add_argument("--test", type = int, default = '3', help = 'number of video,3') 124 | p.add_argument("--release_mode", type=bool, default = True) 125 | 126 | args = p.parse_args() 127 | 128 | main(args) 129 | 130 | 131 | 132 | -------------------------------------------------------------------------------- /task4/main_inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from mmdet.apis import init_detector, inference_detector 3 | import mmcv 4 | import cv2 5 | import os 6 | import glob 7 | import numpy as np 8 | #parser = argparse.ArgumentParser() 9 | #parser.add_argument('--path', type=str, default='task3/dataset_path/') 10 | #args = parser.parse_args() 11 | def task4_main(path): 12 | # ##### person ##### 13 | config_file_person = "task4/configs/faster_rcnn/faster_rcnn_r50_caffe_fpn_mstrain_1x_coco-person.py" 14 | checkpoint_file_person = "task4/checkpoints/faster_rcnn_r50_fpn_1x_coco-person_20201216_175929-d022e227.pth" 15 | # config_file_person = "configs/faster_rcnn/faster_rcnn_r50_caffe_fpn_mstrain_1x_coco-person.py" 16 | # checkpoint_file_person = "checkpoints/faster_rcnn_r50_fpn_1x_coco-person_20201216_175929-d022e227.pth" 17 | # ################## 18 | # ##### triage ##### 19 | config_file_triage = "task4/configs/triage_config/triage_config.py" 20 | checkpoint_file_triage = "task4/work_dirs/0to3_500/ver7/epoch_41.pth" 21 | # config_file_triage = "configs/triage_config/triage_config.py" 22 | # checkpoint_file_triage = "work_dirs/0to3_500/ver7/epoch_41.pth" 23 | # ################## 24 | # # build the model from a config file and a checkpoint file 25 | model_person = init_detector(config_file_person, checkpoint_file_person, device="cuda:0") 26 | model_triage = init_detector(config_file_triage, checkpoint_file_triage, device="cuda:0") 27 | # set(번호)_drone(번호)_triage(번호).jpg 28 | # set(1~5)_drone(1~3)_triage(1~3).jpg 29 | # im_folder = "dataset_path/set0" 30 | im_folder = path 31 | # 제출 레이블 형식 : set_num(1~5), drone_num(1~3), frame_name[사망, 긴급, 응급, 비응급] 32 | #import pdb; pdb.set_trace() 33 | person_results = [] 34 | set_keys = ["set_1", "set_2", "set_3", "set_4", "set_5"] 35 | task4_answer = dict.fromkeys(set_keys) 36 | # #### Extract Region of Person #### 37 | for set_n in range(1,6): 38 | set_dict = dict() 39 | set_name = "set_"+ str(set_n) 40 | set_dir= im_folder + "set_0" + str(set_n) + "/" 41 | # print("set_dir : ",set_dir) 42 | for filename in glob.glob(set_dir): # filename : dataset_path/set_01/ 43 | # print("filename: ", filename) 44 | drone_1, drone_2, drone_3 = dict(), dict(), dict() 45 | file_list = os.listdir(filename) 46 | file_list = sorted(file_list) 47 | triage_list=[] 48 | # set_dict_list = [] 49 | for x in file_list: 50 | if 'triage' in x : 51 | #file_list = os.listdir(filename) # ["set02_drone03_triage02.jpg", 
"set02_drone01_triage01.jpg", "set02_drone01_triage03.jpg", "set02_drone03_triage01.jpg", "set02_drone01_triage02.jpg", "set02_drone02_triage01.jpg", "set02_drone02_triage02.jpg"] 52 | triage_list.append(x) 53 | drone_dict_list = dict() 54 | for file_idx in range(len(triage_list)): 55 | # print("file_idx: ", file_idx) 56 | answer_sheet = [0 for i in range(4)] 57 | ori_file = set_dir + triage_list[file_idx] 58 | # # print(file_list[i]) # set01_drone01_triage01.jpg 59 | # # print(set_dir + file_list[i]) # dataset_path/set01/set01_drone01_triage01.jpg 60 | result_person = inference_detector(model_person, ori_file) # person 61 | result_triage = inference_detector(model_triage, ori_file) # person 62 | labels = [ 63 | np.full(bbox.shape[0], idx, dtype=np.int32) 64 | for idx, bbox in enumerate(result_triage) 65 | ] 66 | labels = np.concatenate(labels) 67 | bboxes_person = np.vstack(result_person) 68 | bboxes_triage = np.vstack(result_triage) 69 | # class-based NMS 70 | tag_is_in_person = [False for t in range(bboxes_triage.shape[0])] 71 | for p in range(bboxes_person.shape[0]): 72 | person_pos = bboxes_person[p][:4] # each person has one or no triage tag. 73 | is_exist_pos = [] 74 | max_score = -1 75 | if bboxes_person[p][-1] < 0.5: continue 76 | for q in range(bboxes_triage.shape[0]): 77 | triage_pos = bboxes_triage[q][:4] 78 | if(triage_pos[0] >= person_pos[0]-50 and triage_pos[1] >= person_pos[1]-50 and triage_pos[2] <= person_pos[2]+50 and triage_pos[3] <= person_pos[3]+50): 79 | is_exist_pos.append(q) 80 | tag_is_in_person[q] = True 81 | if bboxes_triage[q][-1] > max_score: 82 | max_score = bboxes_triage[q][-1] 83 | for k in range(len(is_exist_pos)): 84 | if bboxes_triage[is_exist_pos[k]][-1] < max_score: 85 | bboxes_triage[is_exist_pos[k]][-1] = 0 86 | for t in range(bboxes_triage.shape[0]): 87 | if tag_is_in_person[t] == False: 88 | bboxes_triage[t][-1] = 0 89 | scores = bboxes_triage[:, -1] 90 | score_thr = 0.3 91 | inds = scores > score_thr 92 | labels = labels[inds] 93 | for i in range(len(labels)): 94 | answer_sheet[labels[i]] += 1 95 | drone_num = int(triage_list[file_idx].split("drone")[1][:2]) 96 | drone = "drone_" + str(drone_num) 97 | img_key = triage_list[file_idx].split(".jpg")[0] 98 | if drone_num == 1: 99 | drone_1[img_key] = answer_sheet 100 | elif drone_num == 2: 101 | drone_2[img_key] = answer_sheet 102 | elif drone_num == 3: 103 | drone_3[img_key] = answer_sheet 104 | drone_1_list, drone_2_list, drone_3_list = [], [], [] 105 | drone_1_list.append(drone_1) 106 | drone_2_list.append(drone_2) 107 | drone_3_list.append(drone_3) 108 | set_drone1, set_drone2, set_drone3 = dict(), dict(), dict() 109 | set_drone1["drone_1"] = drone_1_list 110 | set_drone2["drone_2"] = drone_2_list 111 | set_drone3["drone_3"] = drone_3_list 112 | set_drone_list = [] 113 | set_drone_list.append(set_drone1) 114 | set_drone_list.append(set_drone2) 115 | set_drone_list.append(set_drone3) 116 | task4_answer[set_name] = set_drone_list 117 | # print("\n \ntask4_answer[set_name] :\n ", task4_answer[set_name]) 118 | # print(task4_answer) 119 | #final_answer = dict() 120 | #final_answer["task4_answer"] = task4_answer 121 | return task4_answer 122 | if __name__ == '__main__': 123 | # path = 'dataset_path/' 124 | task4_main(path) -------------------------------------------------------------------------------- /task4/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) OpenMMLab. All rights reserved. 
3 | import os 4 | import os.path as osp 5 | import shutil 6 | import sys 7 | import warnings 8 | from setuptools import find_packages, setup 9 | 10 | import torch 11 | from torch.utils.cpp_extension import (BuildExtension, CppExtension, 12 | CUDAExtension) 13 | 14 | 15 | def readme(): 16 | with open('README.md', encoding='utf-8') as f: 17 | content = f.read() 18 | return content 19 | 20 | 21 | version_file = 'mmdet/version.py' 22 | 23 | 24 | def get_version(): 25 | with open(version_file, 'r') as f: 26 | exec(compile(f.read(), version_file, 'exec')) 27 | return locals()['__version__'] 28 | 29 | 30 | def make_cuda_ext(name, module, sources, sources_cuda=[]): 31 | 32 | define_macros = [] 33 | extra_compile_args = {'cxx': []} 34 | 35 | if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1': 36 | define_macros += [('WITH_CUDA', None)] 37 | extension = CUDAExtension 38 | extra_compile_args['nvcc'] = [ 39 | '-D__CUDA_NO_HALF_OPERATORS__', 40 | '-D__CUDA_NO_HALF_CONVERSIONS__', 41 | '-D__CUDA_NO_HALF2_OPERATORS__', 42 | ] 43 | sources += sources_cuda 44 | else: 45 | print(f'Compiling {name} without CUDA') 46 | extension = CppExtension 47 | 48 | return extension( 49 | name=f'{module}.{name}', 50 | sources=[os.path.join(*module.split('.'), p) for p in sources], 51 | define_macros=define_macros, 52 | extra_compile_args=extra_compile_args) 53 | 54 | 55 | def parse_requirements(fname='requirements.txt', with_version=True): 56 | """Parse the package dependencies listed in a requirements file but strips 57 | specific versioning information. 58 | 59 | Args: 60 | fname (str): path to requirements file 61 | with_version (bool, default=False): if True include version specs 62 | 63 | Returns: 64 | List[str]: list of requirements items 65 | 66 | CommandLine: 67 | python -c "import setup; print(setup.parse_requirements())" 68 | """ 69 | import sys 70 | from os.path import exists 71 | import re 72 | require_fpath = fname 73 | 74 | def parse_line(line): 75 | """Parse information from a line in a requirements text file.""" 76 | if line.startswith('-r '): 77 | # Allow specifying requirements in other files 78 | target = line.split(' ')[1] 79 | for info in parse_require_file(target): 80 | yield info 81 | else: 82 | info = {'line': line} 83 | if line.startswith('-e '): 84 | info['package'] = line.split('#egg=')[1] 85 | elif '@git+' in line: 86 | info['package'] = line 87 | else: 88 | # Remove versioning from the package 89 | pat = '(' + '|'.join(['>=', '==', '>']) + ')' 90 | parts = re.split(pat, line, maxsplit=1) 91 | parts = [p.strip() for p in parts] 92 | 93 | info['package'] = parts[0] 94 | if len(parts) > 1: 95 | op, rest = parts[1:] 96 | if ';' in rest: 97 | # Handle platform specific dependencies 98 | 99 | version, platform_deps = map(str.strip, 100 | rest.split(';')) 101 | info['platform_deps'] = platform_deps 102 | else: 103 | version = rest # NOQA 104 | info['version'] = (op, version) 105 | yield info 106 | 107 | def parse_require_file(fpath): 108 | with open(fpath, 'r') as f: 109 | for line in f.readlines(): 110 | line = line.strip() 111 | if line and not line.startswith('#'): 112 | for info in parse_line(line): 113 | yield info 114 | 115 | def gen_packages_items(): 116 | if exists(require_fpath): 117 | for info in parse_require_file(require_fpath): 118 | parts = [info['package']] 119 | if with_version and 'version' in info: 120 | parts.extend(info['version']) 121 | if not sys.version.startswith('3.4'): 122 | # apparently package_deps are broken in 3.4 123 | platform_deps = 
info.get('platform_deps') 124 | if platform_deps is not None: 125 | parts.append(';' + platform_deps) 126 | item = ''.join(parts) 127 | yield item 128 | 129 | packages = list(gen_packages_items()) 130 | return packages 131 | 132 | 133 | def add_mim_extension(): 134 | """Add extra files that are required to support MIM into the package. 135 | 136 | These files will be added by creating a symlink to the originals if the 137 | package is installed in `editable` mode (e.g. pip install -e .), or by 138 | copying from the originals otherwise. 139 | """ 140 | 141 | # parse installment mode 142 | if 'develop' in sys.argv: 143 | # installed by `pip install -e .` 144 | mode = 'symlink' 145 | elif 'sdist' in sys.argv or 'bdist_wheel' in sys.argv: 146 | # installed by `pip install .` 147 | # or create source distribution by `python setup.py sdist` 148 | mode = 'copy' 149 | else: 150 | return 151 | 152 | filenames = ['tools', 'configs', 'demo', 'model-index.yml'] 153 | repo_path = osp.dirname(__file__) 154 | mim_path = osp.join(repo_path, 'mmdet', '.mim') 155 | os.makedirs(mim_path, exist_ok=True) 156 | 157 | for filename in filenames: 158 | if osp.exists(filename): 159 | src_path = osp.join(repo_path, filename) 160 | tar_path = osp.join(mim_path, filename) 161 | 162 | if osp.isfile(tar_path) or osp.islink(tar_path): 163 | os.remove(tar_path) 164 | elif osp.isdir(tar_path): 165 | shutil.rmtree(tar_path) 166 | 167 | if mode == 'symlink': 168 | src_relpath = osp.relpath(src_path, osp.dirname(tar_path)) 169 | os.symlink(src_relpath, tar_path) 170 | elif mode == 'copy': 171 | if osp.isfile(src_path): 172 | shutil.copyfile(src_path, tar_path) 173 | elif osp.isdir(src_path): 174 | shutil.copytree(src_path, tar_path) 175 | else: 176 | warnings.warn(f'Cannot copy file {src_path}.') 177 | else: 178 | raise ValueError(f'Invalid mode {mode}') 179 | 180 | 181 | if __name__ == '__main__': 182 | add_mim_extension() 183 | setup( 184 | name='mmdet', 185 | version=get_version(), 186 | description='OpenMMLab Detection Toolbox and Benchmark', 187 | long_description=readme(), 188 | long_description_content_type='text/markdown', 189 | author='MMDetection Contributors', 190 | author_email='openmmlab@gmail.com', 191 | keywords='computer vision, object detection', 192 | 193 | packages=find_packages(exclude=('configs', 'tools', 'demo')), 194 | include_package_data=True, 195 | classifiers=[ 196 | 'Development Status :: 5 - Production/Stable', 197 | 'License :: OSI Approved :: Apache Software License', 198 | 'Operating System :: OS Independent', 199 | 'Programming Language :: Python :: 3', 200 | 'Programming Language :: Python :: 3.6', 201 | 'Programming Language :: Python :: 3.7', 202 | 'Programming Language :: Python :: 3.8', 203 | 'Programming Language :: Python :: 3.9', 204 | ], 205 | license='Apache License 2.0', 206 | setup_requires=parse_requirements('requirements/build.txt'), 207 | tests_require=parse_requirements('requirements/tests.txt'), 208 | install_requires=parse_requirements('requirements/runtime.txt'), 209 | extras_require={ 210 | 'all': parse_requirements('requirements.txt'), 211 | 'tests': parse_requirements('requirements/tests.txt'), 212 | 'build': parse_requirements('requirements/build.txt'), 213 | 'optional': parse_requirements('requirements/optional.txt'), 214 | }, 215 | ext_modules=[], 216 | cmdclass={'build_ext': BuildExtension}, 217 | zip_safe=False) 218 | --------------------------------------------------------------------------------