├── models ├── __init__.py ├── model.py ├── resnext.py └── resnet.py ├── utils ├── __init__.py ├── target_transforms.py ├── mean.py ├── utils.py ├── myutils.py ├── temporal_transforms.py ├── label_interpolation_v1.m └── spatial_transforms.py ├── datasets ├── __init__.py ├── yt_seg.py ├── quva.py └── ucf_aug.py ├── data └── ori_data │ ├── ucf526 │ └── ucf526_annotations.zip │ ├── QUVA │ └── video2img.m │ └── YT_seg │ └── video2img.m ├── dataset.py ├── README.md ├── opts.py ├── main.py ├── train_2stream.py └── val_2stream.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/ori_data/ucf526/ucf526_annotations.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaodomgdomg/Deep-Temporal-Repetition-Counting/HEAD/data/ori_data/ucf526/ucf526_annotations.zip -------------------------------------------------------------------------------- /data/ori_data/QUVA/video2img.m: -------------------------------------------------------------------------------- 1 | files = dir('data/ori_data/QUVA/videos'); 2 | save_path = 'data/ori_data/QUVA/imgs/val'; 3 | 4 | for i = 3:numel(files 5 | i 6 | obj = VideoReader([files(i).folder '/' files(i).name]); 7 | obj_numberofframe = obj.NumberOfFrame; 8 | mkdir([save_path '/' files(i).name(1:end-4)]); 9 | for j = 1:obj_numberofframe 10 | imgs = read(obj,j); 11 | imwrite(imgs, [save_path '/' files(i).name(1:end-4) '/' num2str(j,'%06d') '.jpg']); 12 | end 13 | clear obj frame 14 | end 15 | -------------------------------------------------------------------------------- /utils/target_transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | 4 | 5 | class Compose(object): 6 | 7 | def __init__(self, transforms): 8 | self.transforms = transforms 9 | 10 | def __call__(self, target): 11 | dst = [] 12 | for t in self.transforms: 13 | dst.append(t(target)) 14 | return dst 15 | 16 | 17 | class ClassLabel(object): 18 | 19 | def __call__(self, target): 20 | return target['label'] 21 | 22 | 23 | class VideoID(object): 24 | 25 | def __call__(self, target): 26 | return target['video_id'] 27 | -------------------------------------------------------------------------------- /data/ori_data/YT_seg/video2img.m: -------------------------------------------------------------------------------- 1 | files = dir('data/ori_data/YT_seg/videos'); 2 | save_path = 'data/ori_data/YT_seg/imgs/val'; 3 | 4 | for i = 3:numel(files) 5 | i 6 | obj = VideoReader([files(i).folder '/' files(i).name]); 7 | obj_numberofframe = obj.NumberOfFrame; 8 | 9 | mkdir([save_path '/' files(i).name(1:end-4)]); 10 | for j = 1:obj_numberofframe 11 | %imgs = read(obj,j); 12 | % imwrite(imgs, [save_path '/' files(i).name(1:end-4) '/' num2str(j,'%06d') '.jpg']); 13 | end 14 | fd_frames = fopen([save_path '/' files(i).name(1:end-4) '/' 'n_frames'], 'w'); 15 | fprintf(fd_frames, '%d\n', obj_numberofframe); 16 | fclose(fd_frames); 17 | 18 | ytseg.duration(i-2) = obj.Duration; 19 | 20 | clear obj frame 21 | end 22 | -------------------------------------------------------------------------------- /utils/mean.py: -------------------------------------------------------------------------------- 1 | def get_mean(norm_value=255, dataset='std'): 2 | assert dataset in ['quva', 'std'] 3 | 4 | if dataset == 'quva': 5 | return [ 6 | 114.7748 / norm_value, 107.7354 / norm_value, 99.4750 / norm_value 7 | ] 8 | elif dataset == 'std': 9 | return [127.0 / norm_value, 127.0 / norm_value, 127.0 / norm_value] 10 | # return [0.0 / norm_value, 0.0 / norm_value, 0.0 / norm_value] 11 | 12 | def get_std(norm_value=255, dataset = 'std'): 13 | assert dataset in ['quva', 'std'] 14 | 15 | if dataset == 'quva': 16 | return [ 17 | 114.7748 / norm_value, 107.7354 / norm_value, 99.4750 / norm_value 18 | ] 19 | elif dataset == 'std': 20 | return [127.0 / norm_value, 127.0 / norm_value, 127.0 / norm_value] 21 | # return [1.0 / norm_value, 1.0 / norm_value, 1.0 / norm_value] 22 | 23 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from models import resnet, resnext 5 | 6 | 7 | def generate_model(opt): 8 | assert opt.model in [ 9 | 'resnet', 'resnext' 10 | ] 11 | 12 | if opt.model == 'resnet': 13 | assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200] 14 | 15 | from models.resnet import get_fine_tuning_parameters 16 | 17 | if opt.model_depth == 10: 18 | model = resnet.resnet10(opt=opt) 19 | elif opt.model_depth == 18: 20 | model = resnet.resnet18(opt=opt) 21 | elif opt.model_depth == 34: 22 | model = resnet.resnet34(opt=opt) 23 | elif opt.model_depth == 50: 24 | model = resnet.resnet50(opt=opt) 25 | elif opt.model_depth == 101: 26 | model = resnet.resnet101(opt=opt) 27 | elif opt.model_depth == 152: 28 | model = resnet.resnet152(opt=opt) 29 | elif opt.model_depth == 200: 30 | model = resnet.resnet200(opt=opt) 31 | elif opt.model == 'resnext': 32 | assert opt.model_depth in [50, 101, 152] 33 | 34 | from models.resnext import get_fine_tuning_parameters 35 | 36 | if opt.model_depth == 50: 37 | model = resnext.resnet50(opt=opt) 38 | elif opt.model_depth == 101: 39 | model = resnext.resnet101(opt=opt) 40 | elif opt.model_depth == 152: 41 | model = resnext.resnet152(opt=opt) 42 | 43 | 44 | if not opt.no_cuda: 45 | model = model.cuda() 46 | 47 | 48 | 49 | 50 | return model, model.parameters() 51 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | # from datasets.kinetics import Kinetics 2 | from datasets.quva import QUVA 3 | from datasets.ucf_aug import UCF_AUG 4 | from datasets.yt_seg import YT_SEG 5 | 6 | def get_training_set(opt, spatial_transform, target_transform): 7 | assert opt.train_dataset in ['ucf_aug'] 8 | 9 | if opt.train_dataset == 'ucf_aug': 10 | training_data = UCF_AUG( 11 | opt.dataset_path, 12 | 'train', 13 | sample_duration=opt.sample_duration, 14 | opt=opt, 15 | n_samples_for_each_video=10, 16 | spatial_transform=spatial_transform) 17 | 18 | return training_data 19 | 20 | 21 | def get_validation_set(dataset, spatial_transform, target_transform, opt): 22 | assert dataset in ['quva', 'ucf_aug', 'yt_seg'] 23 | 24 | if dataset == 'quva': 25 | validation_data = QUVA( 26 | opt.dataset_path, 27 | 'val', 28 | sample_duration=opt.sample_duration, 29 | n_samples_for_each_video=1, 30 | spatial_transform=spatial_transform) 31 | elif dataset == 'ucf_aug': 32 | validation_data = UCF_AUG( 33 | opt.dataset_path, 34 | 'val', 35 | sample_duration=opt.sample_duration, 36 | opt=opt, 37 | n_samples_for_each_video=1, 38 | spatial_transform=spatial_transform) 39 | elif dataset == 'yt_seg': 40 | validation_data = YT_SEG( 41 | opt.dataset_path, 42 | 'val', 43 | sample_duration=opt.sample_duration, 44 | n_samples_for_each_video=1, 45 | spatial_transform=spatial_transform) 46 | 47 | return validation_data 48 | 49 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import torch 3 | import math 4 | import numpy as np 5 | 6 | class AverageMeter(object): 7 | """Computes and stores the average and current value""" 8 | 9 | def __init__(self): 10 | self.reset() 11 | 12 | def reset(self): 13 | self.val = 0 14 | self.avg = 0 15 | self.sum = 0 16 | self.count = 0 17 | self.lists = [] 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count 24 | for _ in range(0, n): 25 | self.lists.append(val) 26 | 27 | def std(self): 28 | std_sum = 0.0 29 | for i in range(0, len(self.lists)): 30 | std_sum = std_sum + pow(self.lists[i]-self.avg, 2) 31 | std_sum = pow(std_sum / self.count, 0.5) 32 | return std_sum 33 | 34 | 35 | class Logger(object): 36 | 37 | def __init__(self, path, header): 38 | self.log_file = open(path, 'w') 39 | self.logger = csv.writer(self.log_file, delimiter='\t') 40 | 41 | self.logger.writerow(header) 42 | self.header = header 43 | 44 | def __del(self): 45 | self.log_file.close() 46 | 47 | def log(self, values): 48 | write_values = [] 49 | for col in self.header: 50 | assert col in values 51 | write_values.append(values[col]) 52 | 53 | self.logger.writerow(write_values) 54 | self.log_file.flush() 55 | 56 | 57 | def load_value_file(file_path): 58 | with open(file_path, 'r') as input_file: 59 | value = float(input_file.read().rstrip('\n\r')) 60 | 61 | return value 62 | 63 | 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /utils/myutils.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import torch 3 | import math 4 | import numpy as np 5 | import random 6 | 7 | 8 | def update_inputs_2stream(sample_inputs, state, sample_len, opt): 9 | inputs = torch.zeros([3, opt.basic_duration, opt.sample_size, opt.sample_size], dtype=torch.float).cuda() 10 | 11 | padding = torch.zeros([3, opt.sample_size, opt.sample_size], dtype=torch.float).cuda() 12 | padding.index_fill_(0, torch.tensor([0]).cuda(), opt.mean[0]) 13 | padding.index_fill_(0, torch.tensor([1]).cuda(), opt.mean[1]) 14 | padding.index_fill_(0, torch.tensor([2]).cuda(), opt.mean[2]) 15 | 16 | frame_indices = torch.zeros([opt.basic_duration], dtype=torch.float).cuda() 17 | 18 | sduration = int(opt.basic_duration / 2) 19 | 20 | pos = int(state[0] - (state[1] - state[0] + 1) * opt.l_context_ratio) 21 | pos2 = int(state[1]) 22 | steps = (pos2-pos+1)*1.0/(sduration-1) 23 | for j in range(0, sduration): 24 | p = round(pos-0.5+j*steps) 25 | p = int(max(pos, min(pos2, p))) 26 | 27 | frame_indices[j] = p 28 | if p < 0 or p >= sample_len: 29 | inputs[:,j,:,:] = padding 30 | else: 31 | inputs[:,j,:,:] = sample_inputs[:,p,:,:] 32 | 33 | 34 | pos = int(state[1] + 1) 35 | pos2 = int(state[2] + (state[2] - pos + 1) * (opt.r_context_ratio - 1)) 36 | steps = (pos2-pos+1)*1.0/(sduration-1) 37 | for j in range(0, sduration): 38 | p = round(pos-0.5+j*steps) 39 | p = int(max(pos, min(pos2, p))) 40 | 41 | frame_indices[sduration + j] = p 42 | if p < 0 or p >= sample_len: 43 | inputs[:,sduration + j,:,:] = padding 44 | else: 45 | inputs[:,sduration + j,:,:] = sample_inputs[:,p,:,:] 46 | 47 | 48 | return inputs, frame_indices 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CVPR2020-Deep-Temporal-Repetition-Counting 2 | This is the official implementation of CVPR2020 paper: ["Context-aware and Scale-insensitive Temporal Repetition Counting"](https://openaccess.thecvf.com/content_CVPR_2020/html/Zhang_Context-Aware_and_Scale-Insensitive_Temporal_Repetition_Counting_CVPR_2020_paper.html) 3 | 4 | This code is implemented based on the project["3D ResNets for Action Recognition"](https://github.com/kenshohara/3D-ResNets-PyTorch). 5 | 6 | ## Requirements 7 | 8 | * [PyTorch](http://pytorch.org/) (ver. 1.0) 9 | * Python 2 10 | 11 | ## Dataset Preparation 12 | 13 | ### UCFRep 14 | * Please download the UCF101 dataset [here](http://crcv.ucf.edu/data/UCF101.php). 15 | * Convert UCF101 videos from avi to png files, put the png files to data/ori_data/ucf526/imgs/train 16 | * Create soft link with following commands: 17 | ```bash 18 | cd data/ori_data/ucf526/imgs 19 | ln -s train val 20 | ``` 21 | * Please download the anotations ([Google Drive](https://drive.google.com/file/d/1c0v51oP44lY_PhpJp8KYAwDaQmxj2zcs/view?usp=sharing),or [Baidu Netdisk](https://pan.baidu.com/s/1nHQZ8P-JZPTo4IRlcOBoHA) code:n5za), and put it to data/ori_data/ucf526/annotations 22 | 23 | ### QUVA 24 | * Please download the QUVA dataset in: http://tomrunia.github.io/projects/repetition/ 25 | * Put the label files to data/ori_data/QUVA/annotations/val 26 | * Convert QUVA videos to png files, put the png files to data/ori_data/QUVA/imgs 27 | 28 | ### YTsegments 29 | * Please download the YTSeg dataset in: https://github.com/ofirlevy/repcount 30 | * Put the label files to data/ori_data/YT_seg/annotations 31 | * Convert YTsegments videos to png files, put the png files to data/ori_data/YT_seg/imgs 32 | 33 | ## Running the code 34 | ### Training 35 | Train from scratch 36 | ```bash 37 | python main.py 38 | ``` 39 | If you want to finetune the model pretrained on Kinetics, first you need to download the pretrained model in [here](https://github.com/kenshohara/3D-ResNets-PyTorch) and run: 40 | ```bash 41 | python main.py --pretrain_path = pretrained_model_path 42 | ``` 43 | 44 | ### Testing 45 | You can also run the trained model provide by ours ([Google Drive](https://drive.google.com/drive/folders/1EM_G1yYBIB35sfOdgBsJwh6fTioAE_uX) or [Baidu Netdisk](https://pan.baidu.com/s/1iqwsVZDeBBdxBq3iWhPBNA) code:na81): 46 | ```bash 47 | python main.py --no_train --resume_path = trained_model_path 48 | ``` 49 | 50 | ## Citation 51 | If you use this code or pre-trained models, please cite the following: 52 | 53 | ```bibtex 54 | @InProceedings{Zhang_2020_CVPR, 55 | author = {Zhang, Huaidong and Xu, Xuemiao and Han, Guoqiang and He, Shengfeng}, 56 | title = {Context-Aware and Scale-Insensitive Temporal Repetition Counting}, 57 | booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 58 | month = {June}, 59 | year = {2020} 60 | } 61 | ``` 62 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /utils/temporal_transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | 4 | 5 | class LoopPadding(object): 6 | 7 | def __init__(self, size): 8 | self.size = size 9 | 10 | def __call__(self, frame_indices): 11 | out = frame_indices 12 | 13 | for index in out: 14 | if len(out) >= self.size: 15 | break 16 | out.append(index) 17 | 18 | return out 19 | 20 | 21 | class TemporalBeginCrop(object): 22 | """Temporally crop the given frame indices at a beginning. 23 | 24 | If the number of frames is less than the size, 25 | loop the indices as many times as necessary to satisfy the size. 26 | 27 | Args: 28 | size (int): Desired output size of the crop. 29 | """ 30 | 31 | def __init__(self, size): 32 | self.size = size 33 | 34 | def __call__(self, frame_indices): 35 | out = frame_indices[:self.size] 36 | 37 | for index in out: 38 | if len(out) >= self.size: 39 | break 40 | out.append(index) 41 | 42 | return out 43 | 44 | 45 | class TemporalCenterCrop(object): 46 | """Temporally crop the given frame indices at a center. 47 | 48 | If the number of frames is less than the size, 49 | loop the indices as many times as necessary to satisfy the size. 50 | 51 | Args: 52 | size (int): Desired output size of the crop. 53 | """ 54 | 55 | def __init__(self, size): 56 | self.size = size 57 | 58 | def __call__(self, frame_indices): 59 | """ 60 | Args: 61 | frame_indices (list): frame indices to be cropped. 62 | Returns: 63 | list: Cropped frame indices. 64 | """ 65 | 66 | center_index = len(frame_indices) // 2 67 | begin_index = max(0, center_index - (self.size // 2)) 68 | end_index = min(begin_index + self.size, len(frame_indices)) 69 | 70 | out = frame_indices[begin_index:end_index] 71 | 72 | for index in out: 73 | if len(out) >= self.size: 74 | break 75 | out.append(index) 76 | 77 | return out 78 | 79 | 80 | class TemporalRandomCrop(object): 81 | """Temporally crop the given frame indices at a random location. 82 | 83 | If the number of frames is less than the size, 84 | loop the indices as many times as necessary to satisfy the size. 85 | 86 | Args: 87 | size (int): Desired output size of the crop. 88 | """ 89 | 90 | def __init__(self, size): 91 | self.size = size 92 | 93 | def __call__(self, frame_indices): 94 | """ 95 | Args: 96 | frame_indices (list): frame indices to be cropped. 97 | Returns: 98 | list: Cropped frame indices. 99 | """ 100 | 101 | rand_end = max(0, len(frame_indices) - self.size - 1) 102 | begin_index = random.randint(0, rand_end) 103 | end_index = min(begin_index + self.size, len(frame_indices)) 104 | 105 | out = frame_indices[begin_index:end_index] 106 | 107 | for index in out: 108 | if len(out) >= self.size: 109 | break 110 | out.append(index) 111 | 112 | return out 113 | -------------------------------------------------------------------------------- /utils/label_interpolation_v1.m: -------------------------------------------------------------------------------- 1 | 2 | 3 | clear all; 4 | 5 | root = ''; 6 | load_root = fullfile(root, 'repetition_label_c01only'); 7 | save_root = fullfile(root, 'repetition_label_c01only_interpolation'); 8 | 9 | dir_path = dir(load_root); 10 | 11 | dataset_fail_num = 0; 12 | dataset_num = 0; 13 | dataset_cnt = 0; 14 | 15 | ptr=1; 16 | for i = 11:numel(dir_path) 17 | 18 | fid_read = fopen(fullfile(load_root, [dir_path(i).name]), 'r'); 19 | 20 | 21 | for j = 1:25 22 | 23 | file_name = fscanf(fid_read, '%s ', 1); 24 | tmp = fscanf(fid_read, '%d'); 25 | 26 | cnt_num = numel(tmp); 27 | if cnt_num == 1 28 | dataset_fail_num = dataset_fail_num + 1; 29 | else 30 | tmp2 = tmp(2:end)-tmp(1:end-1); 31 | 32 | % estimation of dataset distribution 33 | dataset_num = dataset_num + 1; 34 | dataset_cnt = dataset_cnt + numel(tmp); 35 | dataset.duration(ptr) = tmp(end)-tmp(1)+1; 36 | dataset.count(ptr) = numel(tmp); 37 | dataset.length_variation(ptr) = (max(tmp2)-min(tmp2))/mean(tmp2); 38 | 39 | clear label; 40 | 41 | label.duration = dataset.duration(ptr); 42 | label.start_frame = tmp(1); 43 | label.end_frame = tmp(end); 44 | label.temporal_bound_num = dataset.count(ptr); 45 | label.temporal_bound = tmp; 46 | 47 | for k = 1:numel(tmp)-1 48 | for p = tmp(k):tmp(k+1) 49 | seg_m = tmp(k+1)-tmp(k); 50 | if seg_m <= 0 51 | raise('error'); 52 | end 53 | seg_l = -1; 54 | seg_r = -1; 55 | 56 | if k == numel(tmp)-1 57 | label.offset_next_estimate(p-tmp(1)+1) = -1; 58 | else 59 | seg_r = tmp(k+2)-tmp(k+1); 60 | if max(seg_m, seg_r) - min(seg_m, seg_r) > 2 && max(seg_m, seg_r) / min(seg_m, seg_r) > 1.3 61 | label.offset_next_estimate(p-tmp(1)+1) = -1; 62 | else 63 | offset_next = (p-tmp(k)) / seg_m * seg_r + tmp(k+1) - p; 64 | offset_next = int32(round(offset_next)); 65 | label.offset_next_estimate(p-tmp(1)+1) = offset_next; 66 | end 67 | end 68 | 69 | if k == 1 70 | label.offset_pre_estimate(p-tmp(1)+1) = -1; 71 | else 72 | seg_l = tmp(k)-tmp(k-1); 73 | if max(seg_m, seg_l) - min(seg_m, seg_l) > 2 && max(seg_m, seg_l) / min(seg_m, seg_l) > 1.3 74 | label.offset_pre_estimate(p-tmp(1)+1) = -1; 75 | else 76 | offset_next = p - tmp(k-1) - (p-tmp(k)) / seg_m * seg_l; 77 | offset_next = int32(round(offset_next)); 78 | label.offset_pre_estimate(p-tmp(1)+1) = offset_next; 79 | end 80 | end 81 | 82 | end 83 | end 84 | 85 | % fullfile(save_root, [file_name '.mat']) 86 | savepath = fullfile(save_root, 'mat', [file_name '.mat']); 87 | save(savepath, 'label'); 88 | % save savepath 89 | data{i,j} = label; 90 | ptr = ptr+1; 91 | 92 | end 93 | end 94 | fclose(fid_read); 95 | 96 | end 97 | -------------------------------------------------------------------------------- /datasets/yt_seg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | from PIL import Image 4 | import os 5 | import math 6 | import functools 7 | import json 8 | import copy 9 | import numpy as np 10 | import pickle 11 | 12 | from utils.utils import load_value_file 13 | 14 | 15 | def pil_loader(path): 16 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 17 | with open(path, 'rb') as f: 18 | with Image.open(f) as img: 19 | return img.convert('RGB') 20 | 21 | 22 | def accimage_loader(path): 23 | try: 24 | import accimage 25 | return accimage.Image(path) 26 | except IOError: 27 | # Potentially a decoding problem, fall back to PIL.Image 28 | return pil_loader(path) 29 | 30 | 31 | def get_default_image_loader(): 32 | from torchvision import get_image_backend 33 | if get_image_backend() == 'accimage': 34 | return accimage_loader 35 | else: 36 | return pil_loader 37 | 38 | 39 | def video_loader(video_dir_path, frame_indices, image_loader): 40 | video = [] 41 | for i in frame_indices: 42 | image_path = os.path.join(video_dir_path, '{:06d}.jpg'.format(i)) 43 | if os.path.exists(image_path): 44 | video.append(image_loader(image_path)) 45 | else: 46 | return video 47 | 48 | return video 49 | 50 | 51 | def get_default_video_loader(): 52 | image_loader = get_default_image_loader() 53 | return functools.partial(video_loader, image_loader=image_loader) 54 | 55 | 56 | def make_dataset(dataset_path, subset, sample_duration, n_samples_for_each_video): 57 | dataset_path = os.path.join(dataset_path, 'YT_seg') 58 | video_path = os.path.join(dataset_path, 'imgs') 59 | 60 | video_names = os.listdir(os.path.join(video_path,subset)) 61 | video_names.sort() 62 | 63 | annotation_path = os.path.join(dataset_path, 'annotations') 64 | annotations = pickle.load( open( os.path.join(annotation_path,'vidGtData.p'), "rb" ) ) 65 | 66 | dataset = [] 67 | max_n_frames = 0 68 | 69 | for i in range(len(video_names)): 70 | if (i+1) % 50 == 0 or i+1 == len(video_names): 71 | print('{} dataset loading [{}/{}]'.format(subset, i+1, len(video_names))) 72 | 73 | video_path_i = os.path.join(video_path, subset, video_names[i]) 74 | # print(video_path_i) 75 | 76 | if not os.path.exists(video_path_i): 77 | continue 78 | 79 | n_frames_file_path = os.path.join(video_path_i, 'n_frames') 80 | n_frames = int(load_value_file(n_frames_file_path)) 81 | max_n_frames = max(max_n_frames, n_frames) 82 | if n_frames <= 0: 83 | continue 84 | 85 | begin_t = 1 86 | end_t = n_frames 87 | sample = { 88 | 'video': video_path_i, 89 | 'segment': [begin_t, end_t], 90 | 'n_frames': n_frames, 91 | 'video_id': video_names[i][0:3], 92 | 'label': annotations[i] 93 | } 94 | 95 | if n_samples_for_each_video == 1: 96 | sample['frame_indices'] = list(range(1, n_frames + 1)) 97 | dataset.append(sample) 98 | else: 99 | if n_samples_for_each_video > 1: 100 | step = max(1, 101 | math.ceil((n_frames - 1 - sample_duration) / 102 | (n_samples_for_each_video - 1))) 103 | step = int(step) 104 | else: 105 | raise('error, n_samples_for_each_video should >=1\n') 106 | # step = sample_duration 107 | for j in range(1, n_frames-sample_duration, step): 108 | sample_j = copy.deepcopy(sample) 109 | sample_j['frame_indices'] = list( 110 | range(j, min(n_frames + 1, j + sample_duration))) 111 | dataset.append(sample_j) 112 | 113 | 114 | return dataset, max_n_frames 115 | 116 | 117 | class YT_SEG(data.Dataset): 118 | 119 | def __init__(self, 120 | dataset_path, 121 | subset, 122 | sample_duration, 123 | n_samples_for_each_video=10, 124 | spatial_transform=None, 125 | target_transform=None, 126 | get_loader=get_default_video_loader): 127 | self.data, self.max_n_frames = make_dataset(dataset_path, subset, sample_duration, n_samples_for_each_video) 128 | 129 | self.spatial_transform = spatial_transform 130 | self.target_transform = target_transform 131 | self.loader = get_loader() 132 | self.mean = [0.0,0.0,0.0] 133 | self.var = [0.0,0.0,0.0] 134 | self.readed_num = 0.0 135 | 136 | def __getitem__(self, index): 137 | path = self.data[index]['video'] 138 | 139 | frame_indices = self.data[index]['frame_indices'] 140 | 141 | clip = self.loader(path, frame_indices) 142 | if self.spatial_transform is not None: 143 | self.spatial_transform.randomize_parameters() 144 | clip = [self.spatial_transform(img) for img in clip] 145 | clip = torch.stack(clip, 0).permute(1, 0, 2, 3) 146 | 147 | target = self.data[index]['label'] 148 | 149 | 150 | sample_len = clip.size(1) 151 | if clip.size(1) != self.max_n_frames: 152 | clip_zeros = torch.zeros([clip.size(0), self.max_n_frames - clip.size(1), clip.size(2), clip.size(3)], dtype=torch.float) 153 | clip = torch.cat([clip, clip_zeros], dim=1) 154 | 155 | return clip, -1, -1, target, sample_len 156 | 157 | def __len__(self): 158 | return len(self.data) 159 | -------------------------------------------------------------------------------- /datasets/quva.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | from PIL import Image 4 | import os 5 | import math 6 | import functools 7 | import json 8 | import copy 9 | import numpy as np 10 | 11 | from utils.utils import load_value_file 12 | 13 | 14 | def pil_loader(path): 15 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 16 | with open(path, 'rb') as f: 17 | with Image.open(f) as img: 18 | return img.convert('RGB') 19 | 20 | 21 | def accimage_loader(path): 22 | try: 23 | import accimage 24 | return accimage.Image(path) 25 | except IOError: 26 | # Potentially a decoding problem, fall back to PIL.Image 27 | return pil_loader(path) 28 | 29 | 30 | def get_default_image_loader(): 31 | from torchvision import get_image_backend 32 | if get_image_backend() == 'accimage': 33 | return accimage_loader 34 | else: 35 | return pil_loader 36 | 37 | 38 | def video_loader(video_dir_path, frame_indices, image_loader): 39 | video = [] 40 | for i in frame_indices: 41 | image_path = os.path.join(video_dir_path, 'image_{:05d}.jpg'.format(i)) 42 | if os.path.exists(image_path): 43 | video.append(image_loader(image_path)) 44 | else: 45 | return video 46 | 47 | return video 48 | 49 | 50 | def get_default_video_loader(): 51 | image_loader = get_default_image_loader() 52 | return functools.partial(video_loader, image_loader=image_loader) 53 | 54 | 55 | def get_video_names_and_annotations(dataset_path, subset): 56 | annotation_path = os.path.join(dataset_path, 'annotations') 57 | 58 | video_names = [] 59 | annotations = [] 60 | 61 | lists = os.listdir(os.path.join(annotation_path,subset)) 62 | lists.sort() 63 | 64 | for i in range(len(lists)): 65 | anno = np.load(os.path.join(annotation_path, subset, lists[i])) 66 | video_names.append(lists[i][0:-4]) 67 | annotations.append(anno) 68 | 69 | return video_names, annotations 70 | 71 | 72 | def make_dataset(dataset_path, subset, sample_duration, n_samples_for_each_video): 73 | dataset_path = os.path.join(dataset_path, 'QUVA') 74 | video_path = os.path.join(dataset_path, 'imgs') 75 | 76 | video_names, annotations = get_video_names_and_annotations(dataset_path, subset) 77 | 78 | dataset = [] 79 | max_n_frames = 0 80 | 81 | 82 | for i in range(len(video_names)): 83 | if (i+1) % 50 == 0 or i+1 == len(video_names): 84 | print('{} dataset loading [{}/{}]'.format(subset, i+1, len(video_names))) 85 | 86 | video_path_i = os.path.join(video_path, subset, video_names[i]) 87 | 88 | if not os.path.exists(video_path_i): 89 | continue 90 | 91 | n_frames_file_path = os.path.join(video_path_i, 'n_frames') 92 | n_frames = int(load_value_file(n_frames_file_path)) 93 | max_n_frames = max(max_n_frames, n_frames) 94 | if n_frames <= 0: 95 | continue 96 | 97 | begin_t = 1 98 | end_t = n_frames 99 | sample = { 100 | 'video': video_path_i, 101 | 'segment': [begin_t, end_t], 102 | 'n_frames': n_frames, 103 | 'video_id': video_names[i][0:3], 104 | 'label': annotations[i] 105 | } 106 | 107 | if n_samples_for_each_video == 1: 108 | sample['frame_indices'] = list(range(1, n_frames + 1)) 109 | dataset.append(sample) 110 | else: 111 | if n_samples_for_each_video > 1: 112 | step = max(1, 113 | math.ceil((n_frames - 1 - sample_duration) / 114 | (n_samples_for_each_video - 1))) 115 | step = int(step) 116 | else: 117 | raise('error, n_samples_for_each_video should >=1\n') 118 | # step = sample_duration 119 | for j in range(1, n_frames-sample_duration, step): 120 | sample_j = copy.deepcopy(sample) 121 | sample_j['frame_indices'] = list( 122 | range(j, min(n_frames + 1, j + sample_duration))) 123 | dataset.append(sample_j) 124 | 125 | return dataset, max_n_frames 126 | 127 | 128 | class QUVA(data.Dataset): 129 | 130 | def __init__(self, 131 | dataset_path, 132 | subset, 133 | sample_duration, 134 | n_samples_for_each_video=10, 135 | spatial_transform=None, 136 | target_transform=None, 137 | get_loader=get_default_video_loader): 138 | self.data, self.max_n_frames = make_dataset(dataset_path, subset, sample_duration, n_samples_for_each_video) 139 | 140 | self.spatial_transform = spatial_transform 141 | self.target_transform = target_transform 142 | self.loader = get_loader() 143 | 144 | self.mean = [0.0,0.0,0.0] 145 | self.var = [0.0,0.0,0.0] 146 | self.readed_num = 0.0 147 | 148 | def __getitem__(self, index): 149 | 150 | path = self.data[index]['video'] 151 | 152 | frame_indices = self.data[index]['frame_indices'] 153 | 154 | clip = self.loader(path, frame_indices) 155 | if self.spatial_transform is not None: 156 | self.spatial_transform.randomize_parameters() 157 | clip = [self.spatial_transform(img) for img in clip] 158 | clip = torch.stack(clip, 0).permute(1, 0, 2, 3) 159 | 160 | target = self.data[index]['label'] 161 | 162 | sample_len = clip.size(1) 163 | if clip.size(1) != self.max_n_frames: 164 | clip_zeros = torch.zeros([clip.size(0), self.max_n_frames - clip.size(1), clip.size(2), clip.size(3)], dtype=torch.float) 165 | clip = torch.cat([clip, clip_zeros], dim=1) 166 | 167 | return clip, -1, -1, len(target), sample_len 168 | 169 | def __len__(self): 170 | return len(self.data) 171 | -------------------------------------------------------------------------------- /opts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import os 4 | 5 | def parse_opts(): 6 | 7 | learning_policy = '2stream' 8 | validate_policy = '2stream' 9 | parser = argparse.ArgumentParser() 10 | 11 | # paths 12 | parser.add_argument( 13 | '--root_path', 14 | default= './data/', 15 | type=str, 16 | help='Root directory path of data') 17 | parser.add_argument( 18 | '--dataset_path', 19 | default='ori_data/', 20 | type=str, 21 | help='Directory path of Videos') 22 | parser.add_argument( 23 | '--result_path', 24 | default='results/' +time.strftime('%m%d-%H:%M_',time.localtime(time.time()))+learning_policy, 25 | type=str, 26 | help='Result directory path') 27 | parser.add_argument( 28 | '--train_dataset', 29 | default='ucf_aug', 30 | type=str, 31 | help='') 32 | parser.add_argument( 33 | '--val_dataset', 34 | default= ['ucf_aug','quva', 'yt_seg' ], 35 | type=str, 36 | help='Used dataset (yt_seg | quva | ucf_aug)') 37 | 38 | # button 39 | parser.add_argument( 40 | '--no_train', 41 | action='store_true', 42 | help='If true, training is not performed.') 43 | parser.set_defaults(no_train=False) 44 | parser.add_argument( 45 | '--no_val', 46 | action='store_true', 47 | help='If true, validation is not performed.') 48 | parser.set_defaults(no_val=False) 49 | 50 | # training argument 51 | parser.add_argument( 52 | '--sample_duration', 53 | default=300, 54 | type=int, 55 | help='Temporal duration of training sample') 56 | parser.add_argument( 57 | '--mean_std_dataset', 58 | default='quva', 59 | type=str, 60 | help= 61 | '') 62 | parser.add_argument( 63 | '--sample_size', 64 | default=112, 65 | type=int, 66 | help='Height and width of inputs') 67 | parser.add_argument( 68 | '--batch_size', default=24, type=int, help='Batch Size') #32 69 | parser.add_argument( 70 | '--val_batch_size', default=5, type=int, help='Batch Size') #32 71 | parser.add_argument( 72 | '--n_epochs', 73 | default=100, 74 | type=int, 75 | help='Number of total epochs to run') 76 | parser.add_argument( 77 | '--lr_patience', 78 | default=1, 79 | type=int, 80 | help='Patience of LR scheduler. See documentation of ReduceLROnPlateau.' 81 | ) 82 | parser.add_argument( 83 | '--begin_epoch', 84 | default=1, 85 | type=int, 86 | help= 87 | 'Training begins at this epoch. Previous trained model indicated by resume_path is loaded.' 88 | ) 89 | parser.add_argument( 90 | '--pretrain_path', 91 | default= '', # 'weights/resnext-101-kinetics.pth', 92 | type=str) 93 | parser.add_argument( 94 | '--train_from_scratch', 95 | action='store_true') 96 | parser.set_defaults(train_from_scratch=True) 97 | parser.add_argument( 98 | '--resume_path', 99 | default= '', # 'weights/resnext101_ucf526.pth', 100 | type=str, 101 | help='Save data (.pth) of previous training') 102 | parser.add_argument( 103 | '--checkpoint', 104 | default=1, 105 | type=int, 106 | help='Trained model is saved at every this epochs.') 107 | 108 | # learning policy 109 | parser.add_argument( 110 | '--learning_policy', 111 | default=learning_policy, 112 | type=str, 113 | help='') 114 | parser.add_argument( 115 | '--validate_policy', 116 | default=validate_policy, 117 | type=str) 118 | parser.add_argument( 119 | '--optimizer', 120 | default='adam', 121 | type=str, 122 | help='Currently only support [adam, sgd]') 123 | parser.add_argument( 124 | '--learning_rate', 125 | default=0.00005, 126 | type=float, 127 | help= 128 | 'Initial learning rate (divided by 10 while training by lr scheduler)') 129 | parser.add_argument('--momentum', default=0.9, type=float, help='Momentum') 130 | parser.add_argument( 131 | '--weight_decay', default=1e-3, type=float, help='Weight Decay') 132 | parser.add_argument('--nesterov', action='store_true', help='Nesterov momentum') 133 | parser.set_defaults(nesterov=False) 134 | parser.add_argument('--dampening', default=0.9, type=float, help='dampening of SGD') 135 | 136 | 137 | # network argument 138 | parser.add_argument( 139 | '--basic_duration', 140 | default=32, # 48 141 | type=float, 142 | help='Temporal duration of network input') 143 | parser.add_argument( 144 | '--l_context_ratio', 145 | default=1.0, 146 | type=float, 147 | help='') 148 | parser.add_argument( 149 | '--r_context_ratio', 150 | default=2.0, 151 | type=float, 152 | help='') 153 | parser.add_argument( 154 | '--norm_value', 155 | default=255, 156 | type=int, 157 | help= 158 | 'If 1, range of inputs is [0-255]. If 255, range of inputs is [0-1].') 159 | parser.add_argument( 160 | '--model', 161 | default='resnext', 162 | type=str, 163 | help='(resnet | resnext') 164 | parser.add_argument( 165 | '--model_depth', 166 | default=101, 167 | type=int, 168 | help='Depth of resnet (10 | 18 | 34 | 50 | 101)') 169 | parser.add_argument( 170 | '--resnet_shortcut', 171 | default='B', 172 | type=str, 173 | help='Shortcut type of resnet (A | B)') 174 | parser.add_argument( 175 | '--resnext_cardinality', 176 | default=32, 177 | type=int, 178 | help='ResNeXt cardinality') 179 | parser.add_argument( 180 | '--n_classes', 181 | default=7, 182 | type=int, 183 | help= 184 | '[count, enlarge, narrow, miss]' 185 | ) 186 | parser.add_argument( 187 | '--anchors', 188 | default=[0.5, 0.67, 0.8, 1.0, 1.25, 1.5, 2.0], 189 | type=float 190 | ) 191 | parser.add_argument( 192 | '--iou_ubound', 193 | default=0.5, 194 | type=float 195 | ) 196 | parser.add_argument( 197 | '--iou_lbound', 198 | default=0.5, 199 | type=float 200 | ) 201 | 202 | # hardware argument 203 | parser.add_argument( 204 | '--no_cuda', action='store_true', help='If true, cuda is not used.') 205 | parser.set_defaults(no_cuda=False) 206 | parser.add_argument( 207 | '--n_threads', 208 | default=1, 209 | type=int, 210 | help='Number of threads for multi-thread loading') 211 | parser.add_argument( 212 | '--manual_seed', default=1, type=int, help='Manually set random seed') 213 | 214 | # reserved argument 215 | parser.add_argument( 216 | '--initial_scale', 217 | default=1.0, 218 | type=float, 219 | help='Initial scale for multiscale cropping') 220 | parser.add_argument( 221 | '--n_scales', 222 | default=1, 223 | type=int, 224 | help='Number of scales for multiscale cropping') 225 | parser.add_argument( 226 | '--scale_step', 227 | default=0.9457416090031758, 228 | type=float, 229 | help='Scale step for multiscale cropping') 230 | parser.add_argument( 231 | '--train_crop', 232 | default='center', 233 | type=str, 234 | help= 235 | 'Spatial cropping method in training. random is uniform. corner is selection from 4 corners and 1 center. (random | corner | center)' 236 | ) 237 | 238 | args = parser.parse_args() 239 | 240 | return args 241 | -------------------------------------------------------------------------------- /datasets/ucf_aug.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | from PIL import Image 4 | import os 5 | import math 6 | import functools 7 | import json 8 | import copy 9 | import numpy as np 10 | import random 11 | from scipy.io import loadmat 12 | 13 | from utils.utils import load_value_file 14 | 15 | 16 | def pil_loader(path): 17 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 18 | with open(path, 'rb') as f: 19 | with Image.open(f) as img: 20 | return img.convert('RGB') 21 | 22 | 23 | def accimage_loader(path): 24 | try: 25 | import accimage 26 | return accimage.Image(path) 27 | except IOError: 28 | # Potentially a decoding problem, fall back to PIL.Image 29 | return pil_loader(path) 30 | 31 | 32 | def get_default_image_loader(): 33 | from torchvision import get_image_backend 34 | if get_image_backend() == 'accimage': 35 | return accimage_loader 36 | else: 37 | return pil_loader 38 | 39 | 40 | def video_loader(video_dir_path, frame_indices, image_loader): 41 | video = [] 42 | for i in frame_indices: 43 | image_path = os.path.join(video_dir_path, '{:06d}.png'.format(i)) 44 | if os.path.exists(image_path): 45 | video.append(image_loader(image_path)) 46 | else: 47 | return video 48 | 49 | return video 50 | 51 | 52 | def get_default_video_loader(): 53 | image_loader = get_default_image_loader() 54 | return functools.partial(video_loader, image_loader=image_loader) 55 | 56 | 57 | def get_video_names_and_annotations(dataset_path, subset): 58 | annotation_path = os.path.join(dataset_path, 'annotations') 59 | 60 | video_names = [] 61 | annotations = [] 62 | 63 | lists = os.listdir(os.path.join(annotation_path,subset)) 64 | lists.sort() 65 | 66 | for i in range(len(lists)): 67 | anno = loadmat(os.path.join(annotation_path, subset, lists[i])) 68 | anno = anno['label'][0,0] 69 | 70 | 71 | video_names.append(lists[i][0:-4]) 72 | annotations.append(anno) 73 | 74 | return video_names, annotations 75 | 76 | 77 | def make_dataset(dataset_path, subset, sample_duration, n_samples_for_each_video, opt): 78 | dataset_path = os.path.join(dataset_path, 'ucf526') 79 | video_path = os.path.join(dataset_path, 'imgs') 80 | 81 | video_names, annotations = get_video_names_and_annotations(dataset_path, subset) 82 | 83 | dataset = [] 84 | max_n_frames = 0 85 | 86 | 87 | for i in range(len(video_names)): 88 | if (i+1) % 50 == 0 or i+1 == len(video_names): 89 | print('{} dataset loading [{}/{}]'.format(subset, i+1, len(video_names))) 90 | 91 | video_path_i = os.path.join(video_path, subset, video_names[i][2:-8], video_names[i]) 92 | if not os.path.exists(video_path_i): 93 | print('error', video_path_i) 94 | continue 95 | 96 | n_frames = int(annotations[i]['duration'][0,0]) 97 | begin_t = int(annotations[i]['start_frame'][0,0]) 98 | end_t = int(annotations[i]['end_frame'][0,0]) 99 | next_t = annotations[i]['offset_next_estimate'][0,:] 100 | pre_t = annotations[i]['offset_pre_estimate'][0,:] 101 | bound_t = annotations[i]['temporal_bound'][:,0] 102 | 103 | if n_frames <= 0 or len(bound_t) < 3: 104 | continue 105 | 106 | if n_samples_for_each_video == 1: 107 | sample = { 108 | 'video': video_path_i, 109 | 'segment': [begin_t, end_t], 110 | 'n_frames': n_frames, 111 | 'video_id': video_names[i], 112 | 'label_next': next_t, 113 | 'label_pre': pre_t, 114 | 'frame_indices': list(range(begin_t, end_t + 1)), 115 | 'counts': len(bound_t) - 1 116 | } 117 | 118 | max_n_frames = max(max_n_frames, sample['n_frames']) 119 | dataset.append(sample) 120 | else: 121 | if n_samples_for_each_video < 1: 122 | raise('error, n_samples_for_each_video should >=1\n') 123 | 124 | sample = { 125 | 'video': video_path_i, 126 | 'video_id': video_names[i], 127 | } 128 | 129 | 130 | for j in range(0, n_samples_for_each_video): 131 | sample_j = copy.deepcopy(sample) 132 | 133 | begin_j_p = random.randint(0, len(bound_t)-3) 134 | end_j_p = random.randint(begin_j_p + 2, len(bound_t)) 135 | 136 | counts = end_j_p - begin_j_p 137 | if begin_j_p == 0: 138 | begin_j = begin_t 139 | else: 140 | begin_j = random.randint(bound_t[begin_j_p-1], bound_t[begin_j_p]) 141 | 142 | if end_j_p == len(bound_t): 143 | end_j = end_t 144 | else: 145 | end_j = random.randint(bound_t[end_j_p-1], bound_t[end_j_p]) 146 | end_j = min(end_j, begin_j + sample_duration - 1) 147 | 148 | 149 | 150 | sample_j['segment'] = [begin_j, end_j] 151 | sample_j['n_frames'] = end_j - begin_j + 1 152 | sample_j['label_next'] = next_t[begin_j - begin_t: end_j - begin_t + 1] 153 | sample_j['label_pre'] = pre_t[begin_j - begin_t: end_j - begin_t + 1] 154 | sample_j['frame_indices'] = list(range(begin_j, end_j + 1)) 155 | sample_j['counts'] = counts 156 | 157 | max_n_frames = max(max_n_frames, sample_j['n_frames']) 158 | 159 | dataset.append(sample_j) 160 | 161 | max_n_frames = max(max_n_frames, sample_duration) 162 | print('[size of dataset, max_n_frames] = ', len(dataset), max_n_frames) 163 | 164 | return dataset, max_n_frames 165 | 166 | 167 | class UCF_AUG(data.Dataset): 168 | 169 | 170 | def __init__(self, 171 | dataset_path, 172 | subset, 173 | sample_duration, 174 | opt, 175 | n_samples_for_each_video=10, 176 | spatial_transform=None, 177 | get_loader=get_default_video_loader): 178 | print('dataset_path, subset, sample_duration, n_samples_for_each_video: ', dataset_path, subset, sample_duration, n_samples_for_each_video) 179 | self.data, self.max_n_frames = make_dataset(dataset_path, subset, sample_duration, n_samples_for_each_video, opt) 180 | 181 | self.spatial_transform = spatial_transform 182 | self.loader = get_loader() 183 | self.mean = [0.0,0.0,0.0] 184 | self.var = [0.0,0.0,0.0] 185 | self.readed_num = 0.0 186 | 187 | def __getitem__(self, index): 188 | path = self.data[index]['video'] 189 | 190 | frame_indices = self.data[index]['frame_indices'] 191 | 192 | clip = self.loader(path, frame_indices) 193 | if self.spatial_transform is not None: 194 | self.spatial_transform.randomize_parameters() 195 | clip = [self.spatial_transform(img) for img in clip] 196 | clip = torch.stack(clip, 0).permute(1, 0, 2, 3) 197 | 198 | if clip.size(1) != self.max_n_frames: 199 | clip_zeros = torch.zeros([clip.size(0), self.max_n_frames - clip.size(1), clip.size(2), clip.size(3)], dtype=torch.float) 200 | clip = torch.cat([clip, clip_zeros], dim=1) 201 | 202 | label_next = np.zeros(self.max_n_frames, dtype=np.int32) 203 | label_pre = np.zeros(self.max_n_frames, dtype=np.int32) 204 | 205 | for i in range(0, self.data[index]['n_frames']): 206 | label_next[i] = self.data[index]['label_next'][i] 207 | label_pre[i] = self.data[index]['label_pre'][i] 208 | 209 | return clip, label_next, label_pre, self.data[index]['counts'], self.data[index]['n_frames'] 210 | 211 | def __len__(self): 212 | return len(self.data) 213 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import numpy as np 5 | import torch 6 | from torch import nn 7 | from torch import optim 8 | from torch.optim import lr_scheduler 9 | from collections import OrderedDict 10 | 11 | from opts import parse_opts 12 | from models.model import generate_model 13 | from utils.mean import get_mean, get_std 14 | from utils.spatial_transforms import ( 15 | Compose, Normalize, Scale_shorterside, Scale_longerside, CenterCrop, CornerCrop, MultiScaleCornerCrop, 16 | MultiScaleRandomCrop, RandomHorizontalFlip, ToTensor) 17 | from utils.temporal_transforms import LoopPadding, TemporalRandomCrop 18 | from utils.target_transforms import ClassLabel, VideoID 19 | from utils.target_transforms import Compose as TargetCompose 20 | 21 | from dataset import get_training_set, get_validation_set 22 | from utils.utils import Logger 23 | 24 | 25 | if __name__ == '__main__': 26 | import sys 27 | print(sys.version) 28 | print(torch.__version__) 29 | 30 | opt = parse_opts() 31 | if opt.root_path != '': 32 | opt.dataset_path = os.path.join(opt.root_path, opt.dataset_path) 33 | opt.result_path = os.path.join(opt.root_path, opt.result_path) 34 | if not os.path.exists(opt.result_path): 35 | os.makedirs(opt.result_path) 36 | if opt.resume_path: 37 | opt.resume_path = os.path.join(opt.root_path, opt.resume_path) 38 | if opt.pretrain_path: 39 | opt.pretrain_path = os.path.join(opt.root_path, opt.pretrain_path) 40 | 41 | opt.scales = [opt.initial_scale] 42 | for i in range(1, opt.n_scales): 43 | opt.scales.append(opt.scales[-1] * opt.scale_step) 44 | opt.arch = '{}-{}'.format(opt.model, opt.model_depth) 45 | opt.mean = get_mean(opt.norm_value, dataset=opt.mean_std_dataset) 46 | opt.std = get_std(opt.norm_value, dataset=opt.mean_std_dataset) 47 | print(opt) 48 | with open(os.path.join(opt.result_path, 'opts.json'), 'w') as opt_file: 49 | json.dump(vars(opt), opt_file) 50 | 51 | torch.manual_seed(opt.manual_seed) 52 | 53 | 54 | criterion = nn.CrossEntropyLoss() 55 | if not opt.no_cuda: 56 | criterion = criterion.cuda() 57 | 58 | norm_method = Normalize(opt.mean, opt.std) 59 | 60 | model, parameters = generate_model(opt) 61 | print(model) 62 | 63 | if not opt.no_train: 64 | assert opt.train_crop in ['random', 'corner', 'center'] 65 | if opt.train_crop == 'random': 66 | crop_method = MultiScaleRandomCrop(opt.scales, opt.sample_size) 67 | elif opt.train_crop == 'corner': 68 | crop_method = MultiScaleCornerCrop(opt.scales, opt.sample_size) 69 | elif opt.train_crop == 'center': 70 | crop_method = MultiScaleCornerCrop( 71 | opt.scales, opt.sample_size, crop_positions=['c']) 72 | 73 | spatial_transform = Compose([ 74 | Scale_longerside(opt.sample_size), 75 | CenterCrop(opt.sample_size), 76 | RandomHorizontalFlip(), 77 | ToTensor(opt.norm_value), norm_method 78 | ]) 79 | 80 | 81 | target_transform = ClassLabel() 82 | training_data = get_training_set(opt, spatial_transform, target_transform) 83 | 84 | train_loader = torch.utils.data.DataLoader( 85 | training_data, 86 | batch_size=opt.batch_size, 87 | shuffle=True, 88 | num_workers=opt.n_threads, 89 | pin_memory=True) 90 | 91 | 92 | 93 | if opt.learning_policy == '2stream': 94 | train_logger = Logger( 95 | os.path.join(opt.result_path, 'train.log'), 96 | ['epoch', 'loss', 'loss_cls', 'loss_box', 'OBOA', 'MAE', 'MAEP', 'MAEN', 'lr']) 97 | train_batch_logger = Logger( 98 | os.path.join(opt.result_path, 'train_batch.log'), 99 | ['epoch', 'batch', 'iter', 'loss', 'loss_cls', 'loss_box', 'OBOA', 'MAE', 'MAEP', 'MAEN', 'lr']) 100 | from train_2stream import train_epoch 101 | 102 | 103 | if opt.nesterov: 104 | dampening = 0 105 | else: 106 | dampening = opt.dampening 107 | 108 | 109 | finetune_parameters = [] 110 | 111 | 112 | 113 | ignored_params = list(map(id, finetune_parameters)) 114 | base_parameters = filter(lambda p: id(p) not in ignored_params,model.parameters()) 115 | 116 | if opt.optimizer == 'sgd': 117 | optimizer = optim.SGD( 118 | parameters, 119 | lr=opt.learning_rate, 120 | momentum=opt.momentum, 121 | dampening=dampening, 122 | weight_decay=opt.weight_decay, 123 | nesterov=opt.nesterov) 124 | 125 | elif opt.optimizer == 'adam': 126 | if opt.train_from_scratch == True: 127 | optimizer = optim.Adam([ 128 | {'params': base_parameters}, 129 | {'params': finetune_parameters, 'lr': opt.learning_rate*2}], 130 | lr=opt.learning_rate, 131 | weight_decay=opt.weight_decay) 132 | else: 133 | optimizer = optim.Adam( 134 | finetune_parameters, 135 | lr=opt.learning_rate*2, 136 | weight_decay=opt.weight_decay) 137 | 138 | scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[5,15], gamma=0.1) 139 | 140 | 141 | if not opt.no_val: 142 | spatial_transform = Compose([ 143 | Scale_longerside(opt.sample_size), 144 | CenterCrop(opt.sample_size), 145 | ToTensor(opt.norm_value), norm_method 146 | ]) 147 | target_transform = ClassLabel() 148 | 149 | val_loader = {} 150 | for j in range(0, len(opt.val_dataset)): 151 | validation_data = get_validation_set(opt.val_dataset[j], spatial_transform, target_transform, opt) 152 | 153 | val_loader[j] = torch.utils.data.DataLoader( 154 | validation_data, 155 | batch_size=opt.val_batch_size, 156 | shuffle=False, 157 | num_workers=opt.n_threads, 158 | pin_memory=True) 159 | 160 | 161 | if opt.validate_policy == '2stream': 162 | val_logger = {} 163 | for j in range(0, len(val_loader)): 164 | val_logger[j] = Logger( 165 | os.path.join(opt.result_path, 'val_'+opt.val_dataset[j]+'.log'), 166 | ['epoch', 'OBOA', 'MAE', 'MAE_std', 'MAEP', 'MAEN']) 167 | from val_2stream import val_epoch 168 | 169 | 170 | if opt.pretrain_path: 171 | print('loading pretrained checkpoint {}'.format(opt.pretrain_path)) 172 | pretrain = torch.load(opt.pretrain_path) 173 | pretrain = pretrain['state_dict'] 174 | new_state_dict = OrderedDict() 175 | 176 | for k, v in pretrain.items(): 177 | name = k[7:] # remove `module.` 178 | new_state_dict[name] = v 179 | 180 | model.load_state_dict(new_state_dict, strict=False) 181 | 182 | 183 | if opt.resume_path: 184 | print('loading checkpoint {}'.format(opt.resume_path)) 185 | checkpoint = torch.load(opt.resume_path) 186 | # assert opt.arch == checkpoint['arch'] 187 | 188 | opt.begin_epoch = checkpoint['epoch'] 189 | model.load_state_dict(checkpoint['state_dict'], strict=True) 190 | if not opt.no_train: 191 | optimizer.load_state_dict(checkpoint['optimizer']) 192 | 193 | del checkpoint 194 | torch.cuda.empty_cache() 195 | 196 | 197 | print('run') 198 | for i in range(opt.begin_epoch, opt.n_epochs + 1): 199 | 200 | if not opt.no_train: 201 | if opt.learning_policy == '2stream': 202 | train_epoch(i, train_loader, model, optimizer, opt, train_logger, train_batch_logger) 203 | 204 | if not opt.no_val: 205 | for j in range(0, len(val_loader)): 206 | validation_loss = val_epoch(i, val_loader[j], model, opt, val_logger[j], opt.val_dataset[j]) 207 | if opt.no_train: 208 | break 209 | 210 | -------------------------------------------------------------------------------- /models/resnext.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import math 6 | from functools import partial 7 | 8 | __all__ = ['ResNeXt', 'resnet50', 'resnet101'] 9 | 10 | 11 | def conv3x3x3(in_planes, out_planes, stride=1): 12 | # 3x3x3 convolution with padding 13 | return nn.Conv3d( 14 | in_planes, 15 | out_planes, 16 | kernel_size=3, 17 | stride=stride, 18 | padding=1, 19 | bias=False) 20 | 21 | 22 | def downsample_basic_block(x, planes, stride): 23 | out = F.avg_pool3d(x, kernel_size=1, stride=stride) 24 | zero_pads = torch.Tensor( 25 | out.size(0), planes - out.size(1), out.size(2), out.size(3), 26 | out.size(4)).zero_() 27 | if isinstance(out.data, torch.cuda.FloatTensor): 28 | zero_pads = zero_pads.cuda() 29 | 30 | out = Variable(torch.cat([out.data, zero_pads], dim=1)) 31 | 32 | return out 33 | 34 | 35 | class ResNeXtBottleneck(nn.Module): 36 | expansion = 2 37 | 38 | def __init__(self, inplanes, planes, cardinality, stride=1, 39 | downsample=None, kernels=3): 40 | super(ResNeXtBottleneck, self).__init__() 41 | mid_planes = cardinality * int(planes / 32) 42 | self.conv1 = nn.Conv3d(inplanes, mid_planes, kernel_size=1, bias=False) 43 | self.bn1 = nn.BatchNorm3d(mid_planes) 44 | self.conv2 = nn.Conv3d( 45 | mid_planes, 46 | mid_planes, 47 | kernel_size=kernels, 48 | stride=stride, 49 | padding=1, 50 | groups=cardinality, 51 | bias=False) 52 | self.bn2 = nn.BatchNorm3d(mid_planes) 53 | self.conv3 = nn.Conv3d( 54 | mid_planes, planes * self.expansion, kernel_size=1, bias=False) 55 | self.bn3 = nn.BatchNorm3d(planes * self.expansion) 56 | self.relu = nn.ReLU(inplace=True) 57 | self.downsample = downsample 58 | self.stride = stride 59 | 60 | def forward(self, x): 61 | residual = x 62 | 63 | out = self.conv1(x) 64 | out = self.bn1(out) 65 | out = self.relu(out) 66 | 67 | out = self.conv2(out) 68 | out = self.bn2(out) 69 | out = self.relu(out) 70 | 71 | out = self.conv3(out) 72 | out = self.bn3(out) 73 | 74 | if self.downsample is not None: 75 | residual = self.downsample(x) 76 | 77 | out += residual 78 | out = self.relu(out) 79 | 80 | return out 81 | 82 | 83 | class ResNeXt(nn.Module): 84 | 85 | def __init__(self, block, layers, opt): 86 | # default 87 | shortcut_type='B' 88 | cardinality=32 89 | num_classes=400 90 | 91 | # user paras 92 | num_classes=opt.n_classes 93 | shortcut_type=opt.resnet_shortcut 94 | cardinality=opt.resnext_cardinality 95 | sample_size=opt.sample_size 96 | sample_duration=opt.basic_duration 97 | self.learning_policy=opt.learning_policy 98 | self.num_classes = opt.n_classes 99 | self.inplanes = 64 100 | super(ResNeXt, self).__init__() 101 | 102 | down_stride_1 = (1, 2, 2) 103 | down_stride_2 = (2, 2, 2) 104 | 105 | self.conv1 = nn.Conv3d( 106 | 3, 107 | 64, 108 | kernel_size=7, 109 | stride=down_stride_1, 110 | padding=(3, 3, 3), 111 | bias=False) 112 | self.bn1 = nn.BatchNorm3d(64) 113 | self.relu = nn.ReLU(inplace=True) 114 | 115 | base_c = 128 116 | 117 | self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=down_stride_2, padding=1) 118 | self.layer1 = self._make_layer( 119 | block, base_c, layers[0], shortcut_type, cardinality) 120 | self.layer2 = self._make_layer( 121 | block, base_c*2, layers[1], shortcut_type, cardinality, stride=down_stride_2) 122 | self.layer3 = self._make_layer( 123 | block, base_c*4, layers[2], shortcut_type, cardinality, stride=down_stride_2) 124 | self.layer4 = self._make_layer( 125 | block, base_c*8, layers[3], shortcut_type, cardinality, stride=down_stride_2) 126 | 127 | last_duration = int(1) 128 | last_size = int(math.ceil(sample_size / 32.0)) 129 | self.avgpool = nn.AvgPool3d( 130 | (last_duration, last_size, last_size), stride=1) 131 | 132 | self.t_all = int(sample_duration / 16.0) 133 | self.dims = int(base_c*8 * block.expansion) 134 | 135 | self.fc_emd = self.dims * self.t_all / 2 136 | 137 | if self.learning_policy == '2stream': 138 | self.fc_cls_1 = nn.Linear(self.fc_emd, 2*num_classes).cuda() 139 | self.fc_box_1 = nn.Linear(self.fc_emd, num_classes).cuda() 140 | self.fc_cls_2 = nn.Linear(self.fc_emd, 2*num_classes).cuda() 141 | self.fc_box_2 = nn.Linear(self.fc_emd, num_classes).cuda() 142 | 143 | self.fc = [] 144 | self.others = [] 145 | for m in self.modules(): 146 | if isinstance(m, nn.Conv3d): 147 | m.weight = nn.init.kaiming_normal(m.weight, mode='fan_out') 148 | elif isinstance(m, nn.BatchNorm3d): 149 | m.weight.data.fill_(1) 150 | m.bias.data.zero_() 151 | 152 | 153 | def _make_layer(self, 154 | block, 155 | planes, 156 | blocks, 157 | shortcut_type, 158 | cardinality, 159 | stride=1, 160 | kernels=3): 161 | downsample = None 162 | if stride != 1 or self.inplanes != planes * block.expansion: 163 | if shortcut_type == 'A': 164 | downsample = partial( 165 | downsample_basic_block, 166 | planes=planes * block.expansion, 167 | stride=stride) 168 | else: 169 | downsample = nn.Sequential( 170 | nn.Conv3d( 171 | self.inplanes, 172 | planes * block.expansion, 173 | kernel_size=1, 174 | stride=stride, 175 | bias=False), nn.BatchNorm3d(planes * block.expansion)) 176 | 177 | layers = [] 178 | layers.append( 179 | block(self.inplanes, planes, cardinality, stride, downsample)) 180 | self.inplanes = planes * block.expansion 181 | for i in range(1, blocks): 182 | layers.append(block(self.inplanes, planes, cardinality)) 183 | 184 | return nn.Sequential(*layers) 185 | 186 | def forward(self, x): 187 | print_flag = False 188 | if print_flag: 189 | print('x1 ', x.size()) 190 | x = self.conv1(x) 191 | if print_flag: 192 | print('x2 ', x.size()) 193 | x = self.bn1(x) 194 | x = self.relu(x) 195 | x1 = self.maxpool(x) 196 | x2 = self.layer1(x1) 197 | x3 = self.layer2(x2) 198 | x4 = self.layer3(x3) 199 | x5 = self.layer4(x4) 200 | x = self.avgpool(x5) 201 | 202 | new_x_1 = x[:,:,0:self.t_all/2,:,:] 203 | new_x_2 = x[:,:,self.t_all/2:self.t_all,:,:] 204 | 205 | new_x_1 = new_x_1.reshape(-1, self.fc_emd) 206 | new_x_2 = new_x_2.reshape(-1, self.fc_emd) 207 | 208 | if self.learning_policy == '2stream': 209 | pred_cls_1 = self.fc_cls_1(new_x_1) 210 | pred_cls_1 = pred_cls_1.reshape(-1, 2, self.num_classes) 211 | pred_box_1 = self.fc_box_1(new_x_1) 212 | 213 | pred_cls_2 = self.fc_cls_2(new_x_2) 214 | pred_cls_2 = pred_cls_2.reshape(-1, 2, self.num_classes) 215 | pred_box_2 = self.fc_box_2(new_x_2) 216 | return pred_cls_1, pred_box_1, pred_cls_2, pred_box_2 217 | 218 | 219 | 220 | 221 | def get_fine_tuning_parameters(model, ft_begin_index): 222 | if ft_begin_index == 0: 223 | return model.parameters() 224 | 225 | ft_module_names = [] 226 | for i in range(ft_begin_index, 5): 227 | ft_module_names.append('layer{}'.format(i)) 228 | ft_module_names.append('fc') 229 | 230 | parameters = [] 231 | for k, v in model.named_parameters(): 232 | for ft_module in ft_module_names: 233 | if ft_module in k: 234 | parameters.append({'params': v}) 235 | break 236 | else: 237 | parameters.append({'params': v, 'lr': 0.0}) 238 | 239 | return parameters 240 | 241 | 242 | def resnet50(**kwargs): 243 | """Constructs a ResNet-50 model. 244 | """ 245 | model = ResNeXt(ResNeXtBottleneck, [3, 4, 6, 3], **kwargs) 246 | return model 247 | 248 | 249 | def resnet101(**kwargs): 250 | """Constructs a ResNet-101 model. 251 | """ 252 | model = ResNeXt(ResNeXtBottleneck, [3, 4, 23, 3], **kwargs) 253 | return model 254 | 255 | 256 | def resnet152(**kwargs): 257 | """Constructs a ResNet-101 model. 258 | """ 259 | model = ResNeXt(ResNeXtBottleneck, [3, 8, 36, 3], **kwargs) 260 | return model 261 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import math 6 | from functools import partial 7 | 8 | __all__ = [ 9 | 'ResNet', 'resnet10', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 10 | 'resnet152', 'resnet200' 11 | ] 12 | 13 | 14 | def conv3x3x3(in_planes, out_planes, stride=1): 15 | # 3x3x3 convolution with padding 16 | return nn.Conv3d( 17 | in_planes, 18 | out_planes, 19 | kernel_size=3, 20 | stride=stride, 21 | padding=1, 22 | bias=False) 23 | 24 | 25 | def downsample_basic_block(x, planes, stride): 26 | out = F.avg_pool3d(x, kernel_size=1, stride=stride) 27 | zero_pads = torch.Tensor( 28 | out.size(0), planes - out.size(1), out.size(2), out.size(3), 29 | out.size(4)).zero_() 30 | if isinstance(out.data, torch.cuda.FloatTensor): 31 | zero_pads = zero_pads.cuda() 32 | 33 | out = Variable(torch.cat([out.data, zero_pads], dim=1)) 34 | 35 | return out 36 | 37 | 38 | class BasicBlock(nn.Module): 39 | expansion = 1 40 | 41 | def __init__(self, inplanes, planes, stride=1, downsample=None): 42 | super(BasicBlock, self).__init__() 43 | self.conv1 = conv3x3x3(inplanes, planes, stride) 44 | self.bn1 = nn.BatchNorm3d(planes) 45 | self.relu = nn.ReLU(inplace=True) 46 | self.conv2 = conv3x3x3(planes, planes) 47 | self.bn2 = nn.BatchNorm3d(planes) 48 | self.downsample = downsample 49 | self.stride = stride 50 | 51 | def forward(self, x): 52 | residual = x 53 | 54 | out = self.conv1(x) 55 | out = self.bn1(out) 56 | out = self.relu(out) 57 | 58 | out = self.conv2(out) 59 | out = self.bn2(out) 60 | 61 | if self.downsample is not None: 62 | residual = self.downsample(x) 63 | 64 | out += residual 65 | out = self.relu(out) 66 | 67 | return out 68 | 69 | 70 | class Bottleneck(nn.Module): 71 | expansion = 4 72 | 73 | def __init__(self, inplanes, planes, stride=1, downsample=None): 74 | super(Bottleneck, self).__init__() 75 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False) 76 | self.bn1 = nn.BatchNorm3d(planes) 77 | self.conv2 = nn.Conv3d( 78 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 79 | self.bn2 = nn.BatchNorm3d(planes) 80 | self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False) 81 | self.bn3 = nn.BatchNorm3d(planes * 4) 82 | self.relu = nn.ReLU(inplace=True) 83 | self.downsample = downsample 84 | self.stride = stride 85 | 86 | def forward(self, x): 87 | residual = x 88 | 89 | out = self.conv1(x) 90 | out = self.bn1(out) 91 | out = self.relu(out) 92 | 93 | out = self.conv2(out) 94 | out = self.bn2(out) 95 | out = self.relu(out) 96 | 97 | out = self.conv3(out) 98 | out = self.bn3(out) 99 | 100 | if self.downsample is not None: 101 | residual = self.downsample(x) 102 | 103 | out += residual 104 | out = self.relu(out) 105 | 106 | return out 107 | 108 | 109 | class ResNet(nn.Module): 110 | 111 | def __init__(self, block, layers, opt): 112 | down_stride_1 = (1, 2, 2) 113 | down_stride_2 = (2, 2, 2) 114 | 115 | self.inplanes = 64 116 | self.learning_policy = opt.learning_policy 117 | self.num_classes = opt.n_classes 118 | num_classes = opt.n_classes 119 | shortcut_type = opt.resnet_shortcut 120 | sample_size = opt.sample_size 121 | sample_duration = opt.basic_duration 122 | 123 | super(ResNet, self).__init__() 124 | self.conv1 = nn.Conv3d( 125 | 3, 126 | 64, 127 | kernel_size=7, 128 | stride=down_stride_1, 129 | padding=(3, 3, 3), 130 | bias=False) 131 | self.bn1 = nn.BatchNorm3d(64) 132 | self.relu = nn.ReLU(inplace=True) 133 | 134 | self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=down_stride_2, padding=1) 135 | self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type) 136 | self.layer2 = self._make_layer( 137 | block, 128, layers[1], shortcut_type, stride=down_stride_2) 138 | self.layer3 = self._make_layer( 139 | block, 256, layers[2], shortcut_type, stride=down_stride_2) 140 | self.layer4 = self._make_layer( 141 | block, 512, layers[3], shortcut_type, stride=down_stride_1) 142 | 143 | last_duration = int(1) 144 | last_size = int(math.ceil(sample_size / 32.0)) 145 | self.maxpool_final = nn.MaxPool3d( 146 | (last_duration, last_size, last_size), stride=1) 147 | 148 | self.t_all = int(sample_duration / 8.0) 149 | self.dims = int(512 * block.expansion) 150 | 151 | self.fc_emd = self.dims * self.t_all 152 | 153 | if self.learning_policy == '2stream': 154 | self.fc_cls_1 = nn.Linear(self.fc_emd, 2*num_classes).cuda() 155 | self.fc_box_1 = nn.Linear(self.fc_emd, num_classes).cuda() 156 | self.fc_cls_2 = nn.Linear(self.fc_emd, 2*num_classes).cuda() 157 | self.fc_box_2 = nn.Linear(self.fc_emd, num_classes).cuda() 158 | 159 | 160 | for m in self.modules(): 161 | if isinstance(m, nn.Conv3d): 162 | m.weight = nn.init.kaiming_normal_(m.weight, mode='fan_out') 163 | elif isinstance(m, nn.BatchNorm3d): 164 | m.weight.data.fill_(1) 165 | m.bias.data.zero_() 166 | 167 | def _make_layer(self, block, planes, blocks, shortcut_type, stride=1): 168 | downsample = None 169 | if stride != 1 or self.inplanes != planes * block.expansion: 170 | if shortcut_type == 'A': 171 | downsample = partial( 172 | downsample_basic_block, 173 | planes=planes * block.expansion, 174 | stride=stride) 175 | else: 176 | downsample = nn.Sequential( 177 | nn.Conv3d( 178 | self.inplanes, 179 | planes * block.expansion, 180 | kernel_size=1, 181 | stride=stride, 182 | bias=False), nn.BatchNorm3d(planes * block.expansion)) 183 | 184 | layers = [] 185 | layers.append(block(self.inplanes, planes, stride, downsample)) 186 | self.inplanes = planes * block.expansion 187 | for i in range(1, blocks): 188 | layers.append(block(self.inplanes, planes)) 189 | 190 | return nn.Sequential(*layers) 191 | 192 | def forward(self, x): 193 | x = self.conv1(x) 194 | x = self.bn1(x) 195 | x = self.relu(x) 196 | x = self.maxpool(x) 197 | x = self.layer1(x) 198 | x = self.layer2(x) 199 | x = self.layer3(x) 200 | x = self.layer4(x) 201 | x = self.maxpool_final(x) 202 | 203 | 204 | if self.learning_policy == '2stream': 205 | new_x = x.reshape(-1, self.fc_emd) 206 | pred_cls_1 = self.fc_cls_1(new_x) 207 | pred_cls_1 = pred_cls_1.reshape(-1, 2, self.num_classes) 208 | pred_box_1 = self.fc_box_1(new_x) 209 | 210 | pred_cls_2 = self.fc_cls_2(new_x) 211 | pred_cls_2 = pred_cls_2.reshape(-1, 2, self.num_classes) 212 | pred_box_2 = self.fc_box_2(new_x) 213 | return pred_cls_1, pred_box_1, pred_cls_2, pred_box_2 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | def get_fine_tuning_parameters(model, ft_begin_index): 223 | if ft_begin_index == 0: 224 | return model.parameters() 225 | 226 | ft_module_names = [] 227 | for i in range(ft_begin_index, 5): 228 | ft_module_names.append('layer{}'.format(i)) 229 | ft_module_names.append('fc') 230 | 231 | parameters = [] 232 | for k, v in model.named_parameters(): 233 | for ft_module in ft_module_names: 234 | if ft_module in k: 235 | parameters.append({'params': v}) 236 | break 237 | else: 238 | parameters.append({'params': v, 'lr': 0.0}) 239 | 240 | return parameters 241 | 242 | 243 | def resnet10(**kwargs): 244 | """Constructs a ResNet-18 model. 245 | """ 246 | model = ResNet(BasicBlock, [1, 1, 1, 1], **kwargs) 247 | return model 248 | 249 | 250 | def resnet18(**kwargs): 251 | """Constructs a ResNet-18 model. 252 | """ 253 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 254 | return model 255 | 256 | 257 | def resnet34(**kwargs): 258 | """Constructs a ResNet-34 model. 259 | """ 260 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 261 | return model 262 | 263 | 264 | def resnet50(**kwargs): 265 | """Constructs a ResNet-50 model. 266 | """ 267 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 268 | return model 269 | 270 | 271 | def resnet101(**kwargs): 272 | """Constructs a ResNet-101 model. 273 | """ 274 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 275 | return model 276 | 277 | 278 | def resnet152(**kwargs): 279 | """Constructs a ResNet-101 model. 280 | """ 281 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 282 | return model 283 | 284 | 285 | def resnet200(**kwargs): 286 | """Constructs a ResNet-101 model. 287 | """ 288 | model = ResNet(Bottleneck, [3, 24, 36, 3], **kwargs) 289 | return model 290 | -------------------------------------------------------------------------------- /train_2stream.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import time 6 | import os 7 | import sys 8 | import numpy as np 9 | import random 10 | import math 11 | 12 | from utils.utils import AverageMeter 13 | from utils.myutils import update_inputs_2stream 14 | 15 | train_opt = {} 16 | train_opt['early_stop'] = 10 17 | train_opt['iter_terminal_num'] = 1000 18 | 19 | val_opt = {} 20 | val_opt['min_scale'] = 0.03 21 | val_opt['max_scale'] = 0.35 22 | val_opt['init_scale_num'] = 30 23 | val_opt['abandon_second_box'] = False 24 | 25 | def update_labels(label, state, sample_len, opt): 26 | if (state[0] >= sample_len): 27 | tmp = -1 28 | else: 29 | tmp = label[state[0]] 30 | gt_cls = torch.zeros([opt.n_classes], dtype=torch.float).cuda() 31 | gt_box = torch.zeros([opt.n_classes], dtype=torch.float).cuda() 32 | if tmp == -1: 33 | for i in range(0, opt.n_classes): 34 | gt_cls[i] = -1 35 | gt_box[i] = -1 36 | else: 37 | tmp = tmp + 1 38 | for i in range(0, opt.n_classes): 39 | anchor = opt.anchors[i] * state[1] 40 | 41 | I = min(tmp, anchor) 42 | U = max(tmp, anchor) 43 | IOU = float(I) / U 44 | 45 | if IOU >= opt.iou_ubound: 46 | gt_cls[i] = 1 47 | gt_box[i] = math.log(tmp / anchor) 48 | elif IOU <= opt.iou_lbound: 49 | gt_cls[i] = 0 50 | gt_box[i] = -1 51 | else: 52 | gt_cls[i] = -1 53 | gt_box[i] = -1 54 | 55 | return gt_cls, gt_box 56 | 57 | def action_step(state, action_1, action_2, step, sample_len, opt): 58 | lp, mp, rp = state 59 | 60 | seg_len_1 = (mp - lp + 1) * action_1 61 | seg_len_2 = (rp - mp) * action_2 62 | 63 | mp = int(mp + step) 64 | lp = int(mp - seg_len_1 + 1) 65 | rp = int(mp + seg_len_2) 66 | 67 | state = (lp, mp, rp) 68 | 69 | done_flag = mp >= sample_len 70 | fail_flag = (mp - lp + 1) < 4 or (rp - mp) < 4 71 | 72 | return state, done_flag, fail_flag 73 | 74 | def state_init(epoch, label_next, label_pre, label_counts, sample_len, opt): 75 | if label_next[0] == -1: 76 | lp2, rp2 = 0, sample_len / label_counts - 1 77 | else: 78 | lp2, rp2 = 0, label_next[0] - 1 79 | 80 | lp = lp2 + int(random.random() * 1.0 * (rp2 - lp2 + 1)) 81 | 82 | magic = random.random() 83 | if magic < 0.25: 84 | seg_ratio = math.pow(2, (random.random()-0.5)*2) 85 | seg_len = (rp2 - lp2 + 1) * seg_ratio 86 | elif magic > 0.75: 87 | seg_ratio = random.randint(-1, 1) 88 | if seg_ratio == -2: 89 | seg_len = (rp2 - lp2 + 1) * (0.33+(random.random()-0.5)*0.1) 90 | else: 91 | seg_len = (rp2 - lp2 + 1) * (math.pow(2, seg_ratio)+(random.random()-0.5)*0.1) 92 | else: 93 | k = random.randint(0, val_opt['init_scale_num']) 94 | powers_level = (val_opt['max_scale'] / val_opt['min_scale']) ** (float(k)/(val_opt['init_scale_num']-1)) 95 | seg_len = sample_len * val_opt['min_scale'] * powers_level 96 | 97 | rp = lp + int(seg_len-1) 98 | rp = max(rp, lp) 99 | 100 | if rp >= sample_len: 101 | lp, rp = 0, sample_len / label_counts - 1 102 | 103 | if rp * 2 + 1 >= sample_len: 104 | lp, rp = 0, sample_len / 2 - 1 105 | 106 | return lp, rp 107 | 108 | def train_epoch(epoch, data_loader, model, optimizer, opt, 109 | epoch_logger, batch_logger): 110 | 111 | print('train at epoch {}'.format(epoch)) 112 | 113 | model.train() 114 | 115 | batch_time = AverageMeter() 116 | data_time = AverageMeter() 117 | losses = AverageMeter() 118 | losses_cls = AverageMeter() 119 | losses_box = AverageMeter() 120 | maes = AverageMeter() 121 | maeps = AverageMeter() 122 | maens = AverageMeter() 123 | oboas = AverageMeter() 124 | 125 | end_time = time.time() 126 | 127 | CrossEntropyLoss = nn.CrossEntropyLoss(ignore_index = -1).cuda() 128 | SmoothL1Loss = nn.SmoothL1Loss().cuda() 129 | 130 | 131 | 132 | for i, (sample_inputs, label_next, label_pre, label_counts, sample_len) in enumerate(data_loader): 133 | 134 | if train_opt['iter_terminal_num'] != -1 and i > train_opt['iter_terminal_num']: 135 | break 136 | 137 | data_time.update(time.time() - end_time) 138 | 139 | batch_size = sample_inputs.size(0) 140 | 141 | # targets init 142 | label_next = label_next.numpy() 143 | label_pre = label_pre.numpy() 144 | label_counts = label_counts.numpy() 145 | sample_len = sample_len.numpy() 146 | total_steps = 0 147 | 148 | # track state init 149 | lp = np.zeros(batch_size, dtype=np.int) 150 | mp = np.zeros(batch_size, dtype=np.int) 151 | rp = np.zeros(batch_size, dtype=np.int) 152 | counts = np.zeros(batch_size, dtype=np.float) 153 | pre_counts = np.zeros(batch_size, dtype=np.float) 154 | end_flag = np.zeros(batch_size, dtype=np.int) 155 | for j in range(0, batch_size): 156 | while rp[j] == 0 or rp[j] >= sample_len[j]: 157 | lp[j], mp[j] = state_init(epoch, label_next[j], label_pre[j], label_counts[j], sample_len[j], opt) 158 | rp[j] = mp[j] + (mp[j] - lp[j] + 1) 159 | 160 | 161 | 162 | while 1: 163 | inputs = torch.zeros([batch_size, 3, opt.basic_duration, opt.sample_size, opt.sample_size], dtype=torch.float).cuda() 164 | # network input initilization 165 | for j in range(0, batch_size): 166 | inputs[j], _ = update_inputs_2stream(sample_inputs[j], [lp[j], mp[j], rp[j]], sample_len[j], opt) 167 | 168 | # prepare label 169 | gt_cls_1 = torch.zeros([batch_size, opt.n_classes], dtype=torch.long).cuda() 170 | gt_box_1 = torch.zeros([batch_size, opt.n_classes], dtype=torch.float).cuda() 171 | gt_cls_2 = torch.zeros([batch_size, opt.n_classes], dtype=torch.long).cuda() 172 | gt_box_2 = torch.zeros([batch_size, opt.n_classes], dtype=torch.float).cuda() 173 | 174 | for j in range(0, batch_size): 175 | gt_cls_1[j], gt_box_1[j] = update_labels(label_pre[j], [mp[j], mp[j]-lp[j]+1], sample_len[j], opt) 176 | gt_cls_2[j], gt_box_2[j] = update_labels(label_next[j], [mp[j]+1, rp[j]-mp[j]], sample_len[j], opt) 177 | 178 | # do the forward 179 | inputs = Variable(inputs) 180 | pred_cls_1, pred_box_1, pred_cls_2, pred_box_2 = model(inputs) 181 | 182 | for j in range(0, batch_size): 183 | for k in range(0, opt.n_classes): 184 | if gt_box_1[j][k] == -1: 185 | gt_box_1[j][k] = pred_box_1[j][k].detach() 186 | 187 | if gt_box_2[j][k] == -1: 188 | gt_box_2[j][k] = pred_box_2[j][k].detach() 189 | 190 | # loss calculate 191 | if val_opt['abandon_second_box'] == True: 192 | loss_cls = CrossEntropyLoss(pred_cls_1, gt_cls_1) * 1.0 193 | loss_box = SmoothL1Loss(pred_box_1, gt_box_1) * 50.0 194 | else: 195 | loss_cls = CrossEntropyLoss(pred_cls_1, gt_cls_1) * 1.0 + CrossEntropyLoss(pred_cls_2, gt_cls_2) * 1.0 196 | loss_box = SmoothL1Loss(pred_box_1, gt_box_1) * 50.0 + SmoothL1Loss(pred_box_2, gt_box_2) * 50.0 # 10 is from the faster-rcnn imple 197 | loss = loss_cls + loss_box 198 | 199 | losses_cls.update(loss_cls.item(), inputs.size(0)) 200 | losses_box.update(loss_box.item(), inputs.size(0)) 201 | losses.update(loss.item(), inputs.size(0)) 202 | 203 | # optimization 204 | optimizer.zero_grad() 205 | loss.backward() 206 | optimizer.step() 207 | 208 | pred_box_1 = torch.clamp(pred_box_1, min=-2.0, max=2.0) 209 | pred_box_2 = torch.clamp(pred_box_2, min=-2.0, max=2.0) 210 | 211 | # track state update 212 | for j in range(0, batch_size): 213 | magic_step = 5 + random.random() * 15 214 | step = int(max(sample_len[j]/magic_step, 1)) 215 | 216 | max_score, action_1 = -1e6, -1 217 | for k in range(0, opt.n_classes): 218 | box_exp = math.exp(pred_box_1[j][k]) 219 | pred_seg = box_exp * opt.anchors[k] 220 | penalty = 1 221 | score = F.softmax(pred_cls_1, dim=1)[j][1][k] * penalty 222 | if score > max_score: 223 | max_score, action_1 = score, pred_seg 224 | 225 | max_score, action_2 = -1e6, -1 226 | for k in range(0, opt.n_classes): 227 | box_exp = math.exp(pred_box_2[j][k]) 228 | pred_seg = box_exp * opt.anchors[k] 229 | penalty = 1 230 | score = F.softmax(pred_cls_2, dim=1)[j][1][k] * penalty 231 | if score > max_score: 232 | max_score, action_2 = score, pred_seg 233 | if val_opt['abandon_second_box'] == True: 234 | action_2 = action_1 235 | 236 | new_state, done_flag, fail_flag = action_step([lp[j], mp[j], rp[j]], action_1, action_2, step, sample_len[j], opt) 237 | lp[j], mp[j], rp[j] = new_state 238 | 239 | if fail_flag or done_flag: 240 | rp[j] = 0 241 | while rp[j] == 0 or rp[j] >= sample_len[j]: 242 | lp[j], mp[j] = state_init(epoch, label_next[j], label_pre[j], label_counts[j], sample_len[j], opt) 243 | rp[j] = mp[j] + (mp[j] - lp[j] + 1) 244 | pre_counts[j] = 0 245 | counts[j] = pre_counts[j] + float(sample_len[j]-lp[j]+1e-6) / float(mp[j]-lp[j]+1) 246 | else: 247 | pre_counts[j] = pre_counts[j] + step / float(mp[j]-lp[j]+1) 248 | counts[j] = pre_counts[j] + float(sample_len[j]-lp[j]+1e-6) / float(mp[j]-lp[j]+1) 249 | 250 | 251 | 252 | if done_flag: 253 | end_flag[j] = 1 254 | 255 | # terminal condition 256 | total_steps += 1 257 | if sum(end_flag) == batch_size or total_steps > train_opt['early_stop']: 258 | for j in range(0, batch_size): 259 | if counts[j] == 0: 260 | counts[j] = float(sample_len[j]) / float(mp[j]-lp[j]+1) 261 | 262 | mae = float(abs(counts[j] - label_counts[j]))/ float(label_counts[j]) 263 | if abs(counts[j] - label_counts[j]) > 1: 264 | oboa = 0.0 265 | else: 266 | oboa = 1.0 267 | 268 | maes.update(mae) 269 | if counts[j] > label_counts[j]: 270 | maeps.update(mae) 271 | elif counts[j] < label_counts[j]: 272 | maens.update(mae) 273 | oboas.update(oboa) 274 | break 275 | 276 | 277 | batch_time.update(time.time() - end_time) 278 | end_time = time.time() 279 | 280 | batch_logger.log({ 281 | 'epoch': epoch, 282 | 'batch': i + 1, 283 | 'iter': (epoch - 1) * len(data_loader) + (i + 1), 284 | 'loss': losses.val, 285 | 'loss_cls': losses_cls.val, 286 | 'loss_box': losses_box.val, 287 | 'OBOA': oboas.val, 288 | 'MAE': maes.val, 289 | 'MAEP': maeps.val, 290 | 'MAEN': maens.val, 291 | 'lr': optimizer.param_groups[0]['lr'] 292 | }) 293 | 294 | print('Epoch: [{0}][{1}/{2}]\t' 295 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 296 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 297 | 'Loss_cls {loss_cls.val:.4f} ({loss_cls.avg:.4f})\t' 298 | 'Loss_box {loss_box.val:.4f} ({loss_box.avg:.4f})\t' 299 | 'OBOA {oboa.val:.4f} ({oboa.avg:.4f})\t' 300 | 'MAE {maes.val:.4f} ({maes.avg:.4f})\t' 301 | 'MAEP {maeps.val:.4f} ({maeps.avg:.4f})\t' 302 | 'MAEN {maens.val:.4f} ({maens.avg:.4f})\t' 303 | 'total_steps {total_steps: d}'.format( 304 | epoch, 305 | i + 1, 306 | len(data_loader), 307 | batch_time=batch_time, 308 | loss=losses, 309 | loss_cls=losses_cls, 310 | loss_box=losses_box, 311 | oboa=oboas, 312 | maes=maes, 313 | maeps=maeps, 314 | maens=maens, 315 | total_steps=total_steps)) 316 | 317 | epoch_logger.log({ 318 | 'epoch': epoch, 319 | 'loss': losses.avg, 320 | 'loss_cls': losses_cls.avg, 321 | 'loss_box': losses_box.avg, 322 | 'OBOA': oboas.avg, 323 | 'MAE': maes.avg, 324 | 'MAEP': maeps.avg, 325 | 'MAEN': maens.avg, 326 | 'lr': optimizer.param_groups[0]['lr'] 327 | }) 328 | 329 | 330 | if epoch % opt.checkpoint == 0: 331 | save_file_path = os.path.join(opt.result_path, 332 | 'save_{}.pth'.format(epoch)) 333 | states = { 334 | 'epoch': epoch, 335 | 'opt': opt, 336 | 'state_dict': model.state_dict(), 337 | 'optimizer': optimizer.state_dict(), 338 | } 339 | torch.save(states, save_file_path) 340 | 341 | 342 | 343 | -------------------------------------------------------------------------------- /utils/spatial_transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | import numbers 4 | import collections 5 | import numpy as np 6 | import torch 7 | from PIL import Image, ImageOps 8 | try: 9 | import accimage 10 | except ImportError: 11 | accimage = None 12 | 13 | 14 | class Compose(object): 15 | """Composes several transforms together. 16 | Args: 17 | transforms (list of ``Transform`` objects): list of transforms to compose. 18 | Example: 19 | >>> transforms.Compose([ 20 | >>> transforms.CenterCrop(10), 21 | >>> transforms.ToTensor(), 22 | >>> ]) 23 | """ 24 | 25 | def __init__(self, transforms): 26 | self.transforms = transforms 27 | 28 | def __call__(self, img): 29 | for t in self.transforms: 30 | img = t(img) 31 | return img 32 | 33 | def randomize_parameters(self): 34 | for t in self.transforms: 35 | t.randomize_parameters() 36 | 37 | 38 | class ToTensor(object): 39 | """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor. 40 | Converts a PIL.Image or numpy.ndarray (H x W x C) in the range 41 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. 42 | """ 43 | 44 | def __init__(self, norm_value=255): 45 | self.norm_value = norm_value 46 | 47 | def __call__(self, pic): 48 | """ 49 | Args: 50 | pic (PIL.Image or numpy.ndarray): Image to be converted to tensor. 51 | Returns: 52 | Tensor: Converted image. 53 | """ 54 | if isinstance(pic, np.ndarray): 55 | # handle numpy array 56 | img = torch.from_numpy(pic.transpose((2, 0, 1))) 57 | # backward compatibility 58 | return img.float().div(self.norm_value) 59 | 60 | if accimage is not None and isinstance(pic, accimage.Image): 61 | nppic = np.zeros( 62 | [pic.channels, pic.height, pic.width], dtype=np.float32) 63 | pic.copyto(nppic) 64 | return torch.from_numpy(nppic) 65 | 66 | # handle PIL Image 67 | if pic.mode == 'I': 68 | img = torch.from_numpy(np.array(pic, np.int32, copy=False)) 69 | elif pic.mode == 'I;16': 70 | img = torch.from_numpy(np.array(pic, np.int16, copy=False)) 71 | else: 72 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 73 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK 74 | if pic.mode == 'YCbCr': 75 | nchannel = 3 76 | elif pic.mode == 'I;16': 77 | nchannel = 1 78 | else: 79 | nchannel = len(pic.mode) 80 | img = img.view(pic.size[1], pic.size[0], nchannel) 81 | # put it from HWC to CHW format 82 | # yikes, this transpose takes 80% of the loading time/CPU 83 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 84 | if isinstance(img, torch.ByteTensor): 85 | return img.float().div(self.norm_value) 86 | else: 87 | return img 88 | 89 | def randomize_parameters(self): 90 | pass 91 | 92 | 93 | class Normalize(object): 94 | """Normalize an tensor image with mean and standard deviation. 95 | Given mean: (R, G, B) and std: (R, G, B), 96 | will normalize each channel of the torch.*Tensor, i.e. 97 | channel = (channel - mean) / std 98 | Args: 99 | mean (sequence): Sequence of means for R, G, B channels respecitvely. 100 | std (sequence): Sequence of standard deviations for R, G, B channels 101 | respecitvely. 102 | """ 103 | 104 | def __init__(self, mean, std): 105 | self.mean = mean 106 | self.std = std 107 | 108 | def __call__(self, tensor): 109 | """ 110 | Args: 111 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 112 | Returns: 113 | Tensor: Normalized image. 114 | """ 115 | # TODO: make efficient 116 | for t, m, s in zip(tensor, self.mean, self.std): 117 | t.sub_(m).div_(s) 118 | return tensor 119 | 120 | def randomize_parameters(self): 121 | pass 122 | 123 | 124 | class Scale_shorterside(object): 125 | """Rescale the input PIL.Image to the given size. 126 | Args: 127 | size (sequence or int): Desired output size. If size is a sequence like 128 | (w, h), output size will be matched to this. If size is an int, 129 | smaller edge of the image will be matched to this number. 130 | i.e, if height > width, then image will be rescaled to 131 | (size * height / width, size) 132 | interpolation (int, optional): Desired interpolation. Default is 133 | ``PIL.Image.BILINEAR`` 134 | """ 135 | 136 | def __init__(self, size, interpolation=Image.BILINEAR): 137 | assert isinstance(size, 138 | int) or (isinstance(size, collections.Iterable) and 139 | len(size) == 2) 140 | self.size = size 141 | self.interpolation = interpolation 142 | 143 | def __call__(self, img): 144 | """ 145 | Args: 146 | img (PIL.Image): Image to be scaled. 147 | Returns: 148 | PIL.Image: Rescaled image. 149 | """ 150 | if isinstance(self.size, int): 151 | w, h = img.size 152 | if (w <= h and w == self.size) or (h <= w and h == self.size): 153 | return img 154 | if w < h: 155 | ow = self.size 156 | oh = int(self.size * h / w) 157 | return img.resize((ow, oh), self.interpolation) 158 | else: 159 | oh = self.size 160 | ow = int(self.size * w / h) 161 | return img.resize((ow, oh), self.interpolation) 162 | else: 163 | return img.resize(self.size, self.interpolation) 164 | 165 | def randomize_parameters(self): 166 | pass 167 | 168 | class Scale_longerside(object): 169 | 170 | def __init__(self, size, interpolation=Image.BILINEAR): 171 | assert isinstance(size, 172 | int) or (isinstance(size, collections.Iterable) and 173 | len(size) == 2) 174 | self.size = size 175 | self.interpolation = interpolation 176 | 177 | def _pad(self, img): 178 | h, w = img.size[: 2] 179 | pad_h = max(self.size - h, 0) 180 | pad_w = max(self.size - w, 0) 181 | 182 | nparrary = np.pad(img, ((int(pad_w/2), int(pad_w-pad_w/2)), 183 | (int(pad_h/2), int(pad_h-pad_h/2)), (0, 0)), 'constant') 184 | img = Image.fromarray(nparrary) 185 | 186 | return img 187 | 188 | def __call__(self, img): 189 | 190 | if isinstance(self.size, int): 191 | w, h = img.size 192 | if (w >= h and w == self.size) or (h >= w and h == self.size): 193 | return self._pad(img) 194 | 195 | if w > h: 196 | ow = self.size 197 | oh = int(self.size * h / w) 198 | else: 199 | oh = self.size 200 | ow = int(self.size * w / h) 201 | img = img.resize((ow, oh), self.interpolation) 202 | 203 | return self._pad(img) 204 | 205 | else: 206 | return img.resize(self.size, self.interpolation) 207 | 208 | def randomize_parameters(self): 209 | pass 210 | 211 | 212 | class CenterCrop(object): 213 | """Crops the given PIL.Image at the center. 214 | Args: 215 | size (sequence or int): Desired output size of the crop. If size is an 216 | int instead of sequence like (h, w), a square crop (size, size) is 217 | made. 218 | """ 219 | 220 | def __init__(self, size): 221 | if isinstance(size, numbers.Number): 222 | self.size = (int(size), int(size)) 223 | else: 224 | self.size = size 225 | 226 | def __call__(self, img): 227 | """ 228 | Args: 229 | img (PIL.Image): Image to be cropped. 230 | Returns: 231 | PIL.Image: Cropped image. 232 | """ 233 | w, h = img.size 234 | th, tw = self.size 235 | x1 = int(round((w - tw) / 2.)) 236 | y1 = int(round((h - th) / 2.)) 237 | return img.crop((x1, y1, x1 + tw, y1 + th)) 238 | 239 | def randomize_parameters(self): 240 | pass 241 | 242 | 243 | class CornerCrop(object): 244 | 245 | def __init__(self, size, crop_position=None): 246 | self.size = size 247 | if crop_position is None: 248 | self.randomize = True 249 | else: 250 | self.randomize = False 251 | self.crop_position = crop_position 252 | self.crop_positions = ['c', 'tl', 'tr', 'bl', 'br'] 253 | 254 | def __call__(self, img): 255 | image_width = img.size[0] 256 | image_height = img.size[1] 257 | 258 | if self.crop_position == 'c': 259 | th, tw = (self.size, self.size) 260 | x1 = int(round((image_width - tw) / 2.)) 261 | y1 = int(round((image_height - th) / 2.)) 262 | x2 = x1 + tw 263 | y2 = y1 + th 264 | elif self.crop_position == 'tl': 265 | x1 = 0 266 | y1 = 0 267 | x2 = self.size 268 | y2 = self.size 269 | elif self.crop_position == 'tr': 270 | x1 = image_width - self.size 271 | y1 = 0 272 | x2 = image_width 273 | y2 = self.size 274 | elif self.crop_position == 'bl': 275 | x1 = 0 276 | y1 = image_height - self.size 277 | x2 = self.size 278 | y2 = image_height 279 | elif self.crop_position == 'br': 280 | x1 = image_width - self.size 281 | y1 = image_height - self.size 282 | x2 = image_width 283 | y2 = image_height 284 | 285 | img = img.crop((x1, y1, x2, y2)) 286 | 287 | return img 288 | 289 | def randomize_parameters(self): 290 | if self.randomize: 291 | self.crop_position = self.crop_positions[random.randint( 292 | 0, 293 | len(self.crop_positions) - 1)] 294 | 295 | 296 | class RandomHorizontalFlip(object): 297 | """Horizontally flip the given PIL.Image randomly with a probability of 0.5.""" 298 | 299 | def __call__(self, img): 300 | """ 301 | Args: 302 | img (PIL.Image): Image to be flipped. 303 | Returns: 304 | PIL.Image: Randomly flipped image. 305 | """ 306 | if self.p < 0.5: 307 | return img.transpose(Image.FLIP_LEFT_RIGHT) 308 | return img 309 | 310 | def randomize_parameters(self): 311 | self.p = random.random() 312 | 313 | 314 | class MultiScaleCornerCrop(object): 315 | """Crop the given PIL.Image to randomly selected size. 316 | A crop of size is selected from scales of the original size. 317 | A position of cropping is randomly selected from 4 corners and 1 center. 318 | This crop is finally resized to given size. 319 | Args: 320 | scales: cropping scales of the original size 321 | size: size of the smaller edge 322 | interpolation: Default: PIL.Image.BILINEAR 323 | """ 324 | 325 | def __init__(self, 326 | scales, 327 | size, 328 | interpolation=Image.BILINEAR, 329 | crop_positions=['c', 'c', 'tl', 'tr', 'bl', 'br']): 330 | self.scales = scales 331 | self.size = size 332 | self.interpolation = interpolation 333 | 334 | self.crop_positions = crop_positions 335 | 336 | def __call__(self, img): 337 | min_length = min(img.size[0], img.size[1]) 338 | crop_size = int(min_length * self.scale) 339 | 340 | image_width = img.size[0] 341 | image_height = img.size[1] 342 | 343 | if self.crop_position == 'c': 344 | center_x = image_width // 2 345 | center_y = image_height // 2 346 | box_half = crop_size // 2 347 | x1 = center_x - box_half 348 | y1 = center_y - box_half 349 | x2 = center_x + box_half 350 | y2 = center_y + box_half 351 | elif self.crop_position == 'tl': 352 | x1 = 0 353 | y1 = 0 354 | x2 = crop_size 355 | y2 = crop_size 356 | elif self.crop_position == 'tr': 357 | x1 = image_width - crop_size 358 | y1 = 0 359 | x2 = image_width 360 | y2 = crop_size 361 | elif self.crop_position == 'bl': 362 | x1 = 0 363 | y1 = image_height - crop_size 364 | x2 = crop_size 365 | y2 = image_height 366 | elif self.crop_position == 'br': 367 | x1 = image_width - crop_size 368 | y1 = image_height - crop_size 369 | x2 = image_width 370 | y2 = image_height 371 | 372 | img = img.crop((x1, y1, x2, y2)) 373 | 374 | return img.resize((self.size, self.size), self.interpolation) 375 | 376 | def randomize_parameters(self): 377 | self.scale = self.scales[random.randint(0, len(self.scales) - 1)] 378 | self.crop_position = self.crop_positions[random.randint( 379 | 0, 380 | len(self.crop_positions) - 1)] 381 | 382 | 383 | class MultiScaleRandomCrop(object): 384 | 385 | def __init__(self, scales, size, interpolation=Image.BILINEAR): 386 | self.scales = scales 387 | self.size = size 388 | self.interpolation = interpolation 389 | 390 | def __call__(self, img): 391 | min_length = min(img.size[0], img.size[1]) 392 | crop_size = int(min_length * self.scale) 393 | 394 | image_width = img.size[0] 395 | image_height = img.size[1] 396 | 397 | x1 = self.tl_x * (image_width - crop_size) 398 | y1 = self.tl_y * (image_height - crop_size) 399 | x2 = x1 + crop_size 400 | y2 = y1 + crop_size 401 | 402 | img = img.crop((x1, y1, x2, y2)) 403 | 404 | return img.resize((self.size, self.size), self.interpolation) 405 | 406 | def randomize_parameters(self): 407 | self.scale = self.scales[random.randint(0, len(self.scales) - 1)] 408 | self.tl_x = random.random() 409 | self.tl_y = random.random() 410 | -------------------------------------------------------------------------------- /val_2stream.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import time 6 | import os 7 | import sys 8 | import numpy as np 9 | import random 10 | import math 11 | 12 | from utils.utils import AverageMeter 13 | from utils.myutils import update_inputs_2stream 14 | 15 | val_opt = {} 16 | val_opt['iter_terminal_num'] = 1e7 17 | 18 | val_opt['merge_level'] = 5 19 | val_opt['merge_w'] = 0.5 20 | 21 | val_opt['min_scale'] = 0.03 22 | val_opt['max_scale'] = 0.35 23 | val_opt['init_scale_num'] = 30 24 | val_opt['abandon_second_box'] = False 25 | 26 | 27 | def action_step(state, action_1, action_2, step, sample_len, opt, dataset): 28 | lp, mp, rp = state 29 | 30 | seg_len_1 = (mp - lp + 1) * action_1 31 | seg_len_2 = (rp - mp) * action_2 32 | 33 | seg_len_1 = min(max(4, seg_len_1), sample_len/val_opt['min_cycles']) 34 | seg_len_2 = min(max(4, seg_len_2), sample_len/val_opt['min_cycles']) 35 | 36 | mp = int(mp + step) 37 | lp = int(mp - seg_len_1 + 1) 38 | rp = int(mp + seg_len_2) 39 | 40 | state = (lp, mp, rp) 41 | 42 | done_flag = mp >= sample_len 43 | fail_flag = (mp - lp + 1) < 4 or (rp - mp) < 4 44 | 45 | return state, done_flag, fail_flag 46 | 47 | 48 | def val_epoch(epoch, data_loader, model, opt, epoch_logger, val_dataset): 49 | 50 | print('eval at epoch {}'.format(epoch)) 51 | 52 | if val_dataset=='ucf_aug': 53 | val_opt['min_cycles']=2 54 | else: 55 | val_opt['min_cycles']=4 56 | 57 | if val_dataset=='yt_seg': 58 | val_opt['merge_w']=0.1 59 | 60 | model.eval() 61 | 62 | batch_time = AverageMeter() 63 | data_time = AverageMeter() 64 | maes = AverageMeter() 65 | maeps = AverageMeter() 66 | maens = AverageMeter() 67 | oboas = AverageMeter() 68 | 69 | end_time = time.time() 70 | counts_oboa = [] 71 | counts_all = [] 72 | maes_all = [] 73 | oboas_all = [] 74 | 75 | cycle_length_dataset = np.zeros([150, pow(2, val_opt['merge_level'])], dtype=np.float) 76 | cycle_length_dataset_ptr = 0 77 | 78 | 79 | for i, (sample_inputs, _, _, label_counts, sample_len) in enumerate(data_loader): 80 | 81 | if val_opt['iter_terminal_num'] != -1 and i > val_opt['iter_terminal_num']: 82 | break 83 | 84 | 85 | 86 | data_time.update(time.time() - end_time) 87 | end_time = time.time() 88 | 89 | batch_size = sample_inputs.size(0) 90 | 91 | # targets init 92 | label_counts = label_counts.numpy() 93 | sample_len = sample_len.numpy() 94 | level_pow = pow(2, val_opt['merge_level']) 95 | 96 | # track state init 97 | mp = np.zeros([batch_size, val_opt['merge_level'], level_pow], dtype=np.int) 98 | lp_l = np.zeros([batch_size, val_opt['merge_level'], level_pow], dtype=np.int) 99 | lp_r = np.zeros([batch_size, val_opt['merge_level'], level_pow], dtype=np.int) 100 | rp_l = np.zeros([batch_size, val_opt['merge_level'], level_pow], dtype=np.int) 101 | rp_r = np.zeros([batch_size, val_opt['merge_level'], level_pow], dtype=np.int) 102 | 103 | load_lp = np.zeros(batch_size, dtype=np.int) 104 | load_mp = np.zeros(batch_size, dtype=np.int) 105 | load_rp = np.zeros(batch_size, dtype=np.int) 106 | save_lp = np.zeros(batch_size, dtype=np.int) 107 | save_mp = np.zeros(batch_size, dtype=np.int) 108 | save_rp = np.zeros(batch_size, dtype=np.int) 109 | 110 | load_ls = np.zeros(batch_size, dtype=np.float) 111 | load_rs = np.zeros(batch_size, dtype=np.float) 112 | save_ls = np.zeros(batch_size, dtype=np.float) 113 | save_rs = np.zeros(batch_size, dtype=np.float) 114 | 115 | counts = np.zeros(batch_size, dtype=np.float) 116 | 117 | 118 | # get the first estimation 119 | max_mp = np.zeros(batch_size, dtype=np.int) 120 | max_score = np.zeros(batch_size, dtype=np.float) 121 | for j in range(0, batch_size): 122 | max_score[j] = -1e6 123 | 124 | for k in range(0, val_opt['init_scale_num']): 125 | powers_level = (val_opt['max_scale'] / val_opt['min_scale']) ** (float(k)/(val_opt['init_scale_num']-1)) 126 | inputs = torch.zeros([batch_size, 3, opt.basic_duration, opt.sample_size, opt.sample_size], dtype=torch.float).cuda() 127 | 128 | for j in range(0, batch_size): 129 | mp_k = sample_len[j] * val_opt['min_scale'] * powers_level 130 | mid_pt = sample_len[j]/2 131 | inputs[j], _ = update_inputs_2stream(sample_inputs[j], [mid_pt-mp_k, mid_pt, mid_pt+mp_k+1], sample_len[j], opt) 132 | 133 | pred_cls, pred_box, _, _ = model(inputs) 134 | pred_box = torch.clamp(pred_box, min=-0.5, max=0.5) 135 | 136 | for j in range(0, batch_size): 137 | 138 | for p in range(3, 4): 139 | box_exp = math.exp(pred_box[j][p]) 140 | pred_seg = box_exp * opt.anchors[p] 141 | penalty = 1 142 | score = F.softmax(pred_cls, dim=1)[j][1][p] * penalty 143 | mp_k = sample_len[j] * val_opt['min_scale'] * powers_level * pred_seg 144 | if score > max_score[j] and mp_k >= 4 and mp_k < sample_len[j]/val_opt['min_cycles']: 145 | max_score[j], max_mp[j] = score, mp_k 146 | 147 | for k in range(0, 4): 148 | inputs = torch.zeros([batch_size, 3, opt.basic_duration, opt.sample_size, opt.sample_size], dtype=torch.float).cuda() 149 | for j in range(0, batch_size): 150 | mp_k = max_mp[j] 151 | mid_pt = sample_len[j]/2 152 | inputs[j], _ = update_inputs_2stream(sample_inputs[j], [mid_pt-mp_k, mid_pt, mid_pt+mp_k+1], sample_len[j], opt) 153 | 154 | pred_cls, pred_box, _, _ = model(inputs) 155 | pred_box = torch.clamp(pred_box, min=-0.5, max=0.5) 156 | 157 | for j in range(0, batch_size): 158 | max_score[j] = -1e6 159 | tmp = max_mp[j] 160 | for p in range(3, 4): 161 | box_exp = math.exp(pred_box[j][p]) 162 | pred_seg = box_exp * opt.anchors[p] 163 | penalty = 1 164 | score = F.softmax(pred_cls, dim=1)[j][1][p] * penalty 165 | mp_k = tmp * pred_seg 166 | if score > max_score[j] and mp_k >= 4 and mp_k < sample_len[j]/val_opt['min_cycles']: 167 | max_score[j], max_mp[j] = score, round(float(max_mp[j]*(1-val_opt['merge_w']))+float(mp_k*val_opt['merge_w'])) 168 | 169 | 170 | for j in range(0, batch_size): 171 | for l2 in range(0, level_pow): 172 | mp[j,0,l2] = int(float(sample_len[j]) / float(level_pow+1) * (l2+0.5)) 173 | lp_l[j,0,l2] = mp[j,0,l2] - max_mp[j] 174 | rp_l[j,0,l2] = mp[j,0,l2] + max_mp[j] + 1 175 | lp_r[j,0,l2] = lp_l[j,0,l2] 176 | rp_r[j,0,l2] = rp_l[j,0,l2] 177 | 178 | 179 | 180 | total_steps = 0 181 | for l1 in range(1, val_opt['merge_level']): 182 | 183 | steps = pow(2, val_opt['merge_level']-l1-1) 184 | pos = -steps 185 | for l2 in range(0, pow(2,l1)): 186 | pos = pos + 2*steps 187 | 188 | if l1==1: 189 | iters = 4 190 | elif l1==2: 191 | iters = 2 192 | else: 193 | iters = 1 194 | 195 | for l3 in range(0, iters): 196 | 197 | total_steps = total_steps + 1 198 | 199 | inputs = torch.zeros([batch_size, 3, opt.basic_duration, opt.sample_size, opt.sample_size], dtype=torch.float).cuda() 200 | # network input initilization 201 | for j in range(0, batch_size): 202 | 203 | if l3 == 0: 204 | load_mp[j] = mp[j,l1-1,pos] 205 | load_lp[j] = round(float(lp_l[j,l1-1,pos]+lp_r[j,l1-1,pos])/2) 206 | load_rp[j] = round(float(rp_l[j,l1-1,pos]+rp_r[j,l1-1,pos])/2) 207 | else: 208 | load_mp[j] = save_mp[j] 209 | load_lp[j] = round(float(save_lp[j]) * val_opt['merge_w'] + float(load_lp[j]) * (1.0-val_opt['merge_w'])) 210 | load_rp[j] = round(float(save_rp[j]) * val_opt['merge_w'] + float(load_rp[j]) * (1.0-val_opt['merge_w'])) 211 | 212 | inputs[j], _ = update_inputs_2stream(sample_inputs[j], [load_lp[j], load_mp[j], load_rp[j]], sample_len[j], opt) 213 | 214 | # do the forward 215 | inputs = Variable(inputs) 216 | pred_cls_1, pred_box_1, pred_cls_2, pred_box_2 = model(inputs) 217 | 218 | 219 | pred_box_1 = torch.clamp(pred_box_1, min=-0.5, max=0.5) 220 | pred_box_2 = torch.clamp(pred_box_2, min=-0.5, max=0.5) 221 | 222 | # track state update 223 | for j in range(0, batch_size): 224 | 225 | max_score, action_1 = -1e6, -1 226 | for k in range(0, opt.n_classes): 227 | box_exp = math.exp(pred_box_1[j][k]) 228 | pred_seg = box_exp * opt.anchors[k] 229 | penalty = 1 230 | score = F.softmax(pred_cls_1, dim=1)[j][1][k] * penalty 231 | if score > max_score: 232 | max_score, action_1 = score, pred_seg 233 | save_ls[j] = score 234 | 235 | max_score, action_2 = -1e6, -1 236 | for k in range(0, opt.n_classes): 237 | box_exp = math.exp(pred_box_2[j][k]) 238 | pred_seg = box_exp * opt.anchors[k] 239 | penalty = 1 240 | score = F.softmax(pred_cls_2, dim=1)[j][1][k] * penalty 241 | if score > max_score: 242 | max_score, action_2 = score, pred_seg 243 | save_rs[j] = score 244 | 245 | if val_opt['abandon_second_box'] == True: 246 | action_2 = action_1 247 | save_rs[j] = save_ls[j] 248 | 249 | 250 | new_state, done_flag, fail_flag = action_step([load_lp[j], load_mp[j], load_rp[j]], action_1, action_2, 0, sample_len[j], opt, val_dataset) 251 | save_lp[j], save_mp[j], save_rp[j] = new_state 252 | 253 | 254 | if fail_flag: 255 | save_lp[j] = load_lp[j] 256 | save_rp[j] = load_rp[j] 257 | 258 | for j in range(0, batch_size): 259 | 260 | 261 | l_segments = float(save_lp[j]) * val_opt['merge_w'] + float(load_lp[j]) * (1.0-val_opt['merge_w']) 262 | r_segments = float(save_rp[j]) * val_opt['merge_w'] + float(load_rp[j]) * (1.0-val_opt['merge_w']) 263 | 264 | for s in range(-steps, 0): 265 | mp[j,l1,pos+s] = mp[j,l1-1,pos+s] 266 | lp_r[j,l1,pos+s] = mp[j,l1-1,pos+s] + (l_segments-mp[j,l1-1,pos]) 267 | rp_r[j,l1,pos+s] = mp[j,l1-1,pos+s] + (r_segments-mp[j,l1-1,pos]) 268 | 269 | 270 | if l1 <= 2 or l1 == val_opt['merge_level']-1 or l2 == 0: 271 | lp_l[j,l1,pos+s] = lp_r[j,l1,pos+s] 272 | rp_l[j,l1,pos+s] = rp_r[j,l1,pos+s] 273 | else: 274 | lp_l[j,l1,pos+s] = lp_l[j,l1-1,pos+s] 275 | rp_l[j,l1,pos+s] = rp_l[j,l1-1,pos+s] 276 | 277 | 278 | for s in range(0, steps): 279 | mp[j,l1,pos+s] = mp[j,l1-1,pos+s] 280 | lp_l[j,l1,pos+s] = mp[j,l1-1,pos+s] + (l_segments-mp[j,l1-1,pos]) 281 | rp_l[j,l1,pos+s] = mp[j,l1-1,pos+s] + (r_segments-mp[j,l1-1,pos]) 282 | if l1 <= 2 or l1 == val_opt['merge_level']-1 or l2 == pow(2,l1)-1: 283 | lp_r[j,l1,pos+s] = lp_l[j,l1,pos+s] 284 | rp_r[j,l1,pos+s] = rp_l[j,l1,pos+s] 285 | else: 286 | lp_r[j,l1,pos+s] = lp_r[j,l1-1,pos+s] 287 | rp_r[j,l1,pos+s] = rp_r[j,l1-1,pos+s] 288 | 289 | 290 | for j in range(0, batch_size): 291 | left_avg = AverageMeter() 292 | right_avg = AverageMeter() 293 | for k in range(0, level_pow): 294 | last = val_opt['merge_level'] -1 295 | 296 | lp_avg = round(float(lp_l[j,last,k]+lp_r[j,last,k])/2) 297 | rp_avg = round(float(rp_l[j,last,k]+rp_r[j,last,k])/2) 298 | 299 | pos1 = int(lp_avg - (mp[j,last,k] - lp_avg + 1) * opt.l_context_ratio) 300 | pos2 = int(rp_avg + (rp_avg - mp[j,last,k] + 0) * (opt.r_context_ratio - 1)) 301 | if pos1 >= 0 and pos2 < sample_len[j]: 302 | 303 | if val_dataset == 'quva' or val_dataset == 'yt_seg' or val_dataset == 'ucf_aug': 304 | left_avg.update(1.0/float(mp[j,last,k]-lp_avg+1)) 305 | right_avg.update(1.0/float(rp_avg - mp[j,last,k])) 306 | 307 | else: 308 | left_avg.update(float(mp[j,last,k] - lp_avg+1)) 309 | right_avg.update(float(rp_avg - mp[j,last,k])) 310 | 311 | cycle_length_dataset[cycle_length_dataset_ptr+j, k] = 1.0/float(mp[j,last,k]-lp_avg+1)+1.0/float(rp_avg - mp[j,last,k]) 312 | 313 | 314 | if left_avg.avg == 0 or right_avg.avg == 0: 315 | counts[j] = float(sample_len[j]) / float(max_mp[j]+1) 316 | else: 317 | if val_dataset == 'quva' or val_dataset == 'yt_seg' or val_dataset == 'ucf_aug': 318 | counts[j] = float(sample_len[j]) * float(left_avg.sum*0.5+right_avg.sum*0.5) /float(left_avg.count) 319 | 320 | else: 321 | counts[j] = float(sample_len[j]+1e-6) / float(left_avg.avg*0.5+right_avg.avg*0.5) 322 | 323 | counts[j] = float(round(counts[j])) 324 | # print(sample_inputs.size(), sample_len[j], label_counts[j], counts[j], float(sample_len[j]) / float(max_mp[j]+1)) 325 | 326 | counts_all.append(counts[j]) 327 | 328 | mae = float(abs(counts[j] - label_counts[j]))/ float(label_counts[j]) 329 | if mae > 0.33: 330 | counts_oboa.append(i) 331 | 332 | if abs(counts[j] - label_counts[j]) > 1: 333 | oboa = 0.0 334 | else: 335 | oboa = 1.0 336 | 337 | maes_all.append(mae) 338 | oboas_all.append(oboa) 339 | 340 | maes.update(mae) 341 | if counts[j] > label_counts[j]: 342 | maeps.update(mae) 343 | elif counts[j] < label_counts[j]: 344 | maens.update(mae) 345 | oboas.update(oboa) 346 | 347 | 348 | 349 | batch_time.update(time.time() - end_time) 350 | cycle_length_dataset_ptr = cycle_length_dataset_ptr + batch_size 351 | 352 | 353 | print('Epoch: [{0}][{1}/{2}]\t' 354 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 355 | 'OBOA {oboa.val:.4f} ({oboa.avg:.4f})\t' 356 | 'MAE {maes.val:.4f} ({maes.avg:.4f})\t' 357 | 'MAEstd {maestd:.4f}\t' 358 | 'MAEP {maeps.val:.4f} ({maeps.avg:.4f})\t' 359 | 'MAEN {maens.val:.4f} ({maens.avg:.4f})\t' 360 | 'total_steps {total_steps: d}\n'.format( 361 | epoch, 362 | i + 1, 363 | len(data_loader), 364 | batch_time=batch_time, 365 | oboa=oboas, 366 | maes=maes, 367 | maestd=maes.std(), 368 | maeps=maeps, 369 | maens=maens, 370 | total_steps=total_steps)) 371 | 372 | 373 | # np.save(val_dataset, cycle_length_dataset) 374 | 375 | 376 | epoch_logger.log({ 377 | 'epoch': epoch, 378 | 'OBOA': oboas.avg, 379 | 'MAE': maes.avg, 380 | 'MAE_std': maes.std(), 381 | 'MAEP': maeps.avg, 382 | 'MAEN': maens.avg, 383 | }) 384 | return maes.avg 385 | 386 | 387 | --------------------------------------------------------------------------------