├── .gitignore ├── README.md ├── constant.py ├── data ├── annotations │ ├── avenue.json │ └── shanghaitech_semantic_annotation.json └── hyper_params │ ├── normal.ini │ ├── temporal.ini │ └── tune_video.ini ├── doc ├── only_normal_data.md ├── temporal_annotation.md └── video_annotation.md ├── evaluate.py ├── inference.py ├── models ├── __init__.py ├── networks.py └── pix2pix.py ├── requirements.txt ├── train_scripts ├── train_normal_annotation.py ├── train_temporal_annotation.py └── train_tune_video_annotation.py └── utils ├── __init__.py ├── dataloaders ├── __init__.py ├── only_normal_loader.py ├── temporal_triplet_loader.py ├── test_loader.py └── tune_video_loader.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | data/pretrains 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Margin Learning Embedded Prediction for Video Anomaly Detection with A Few Anomalies, IJCAI 2019 2 | Wen Liu*, Weixin Luo*, Zhengxin Li, Peilin Zhao, Shenghua Gao. 3 | 4 | ## 1. Installation (Anaconda with python3.6 installation is recommended) 5 | ```shell 6 | 7 | pip install -r requirements.txt 8 | ``` 9 | 10 | ## 2. Download datasets 11 | Please manually download all datasets from [avenue.tar.gz and shanghaitech.tar.gz](https://onedrive.live.com/?authkey=%21AMqh2fTSemfrokE&id=3705E349C336415F%215109&cid=3705E349C336415F) 12 | and tar each tar.gz file, and move them in to **data** folder. 13 | 14 | You can also download data from BaiduYun(https://pan.baidu.com/s/1j0TEt-2Dw3kcfdX-LCF0YQ) i9b3 15 | 16 | ## 3. Inference the pretrain model 17 | Download the pre-trained models firstly, `pretrains` folder 18 | * from [OneDrive](https://1drv.ms/u/s!AjjUqiJZsj8whLdx8Bw8NyAQ3NlGVw?e=odcUe3); 19 | * from [BaiduPan](https://pan.baidu.com/s/1K5mE07ygCoP9Mw97RlGrSA), ioov 20 | 21 | and then, move the `pretrains` folder into `data`,`mv pretrains data`. 22 | 23 | #### 3.1 Inference with Only-Normal-Data Pretrained model 24 | ``` 25 | python inference.py --dataset avenue \ 26 | --prednet cyclegan_convlstm \ 27 | --num_his 4 \ 28 | --label_level normal \ 29 | --gpu 0 \ 30 | --interpolation --snapshot_dir ./data/pretrains/avenue/normal/checkpoints/model.ckpt-74000 31 | ``` 32 | 33 | #### 3.2 Inference with Video-Annotated Pretrained model 34 | 35 | ``` 36 | python inference.py --dataset avenue \ 37 | --prednet cyclegan_convlstm \ 38 | --num_his 4 \ 39 | --label_level tune_video \ 40 | --gpu 0 \ 41 | --interpolation --snapshot_dir ./data/pretrains/avenue/tune_video/prednet_cyclegan_convlstm_folds_10_kth_1_/MARGIN_1.0_LAMBDA_1.0/model.ckpt-76000 42 | ``` 43 | 44 | 45 | #### 3.3 Inference with Temporal-Annotated Pretrained model 46 | ``` 47 | python inference.py --dataset avenue \ 48 | --prednet cyclegan_convlstm \ 49 | --num_his 4 \ 50 | --label_level normal \ 51 | --gpu 0 \ 52 | --interpolation --snapshot_dir ./data/pretrains/avenue/temporal/prednet_cyclegan_convlstm_folds_10_kth_1_/MARGIN_1.0_LAMBDA_1.0/model.ckpt-77000 53 | ``` 54 | 55 | ## 4. Training model with different settings from scratch 56 | See more details in 57 | 58 | 4.1 [only_normal_data](./doc/only_normal_data.md); 59 | 60 | 4.2 [video_annotation](./doc/video_annotation.md); 61 | 62 | 4.3 [temporal_annotation](./doc/temporal_annotation.md). 63 | 64 | 65 | ### Citation 66 | ``` 67 | @inproceedings{melp_2019, 68 | author = {Wen Liu and 69 | Weixin Luo and 70 | Zhengxin Li and 71 | Peilin Zhao and 72 | Shenghua Gao}, 73 | title = {Margin Learning Embedded Prediction for Video Anomaly Detection with 74 | {A} Few Anomalies}, 75 | booktitle = {Proceedings of the Twenty-Eighth International Joint Conference on 76 | Artificial Intelligence, {IJCAI} 2019, Macao, China, August 10-16, 77 | 2019}, 78 | pages = {3023--3030}, 79 | publisher = {ijcai.org}, 80 | year = {2019} 81 | } 82 | ``` -------------------------------------------------------------------------------- /constant.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import configparser 4 | 5 | import ipdb 6 | 7 | 8 | def get_dir(directory): 9 | """ 10 | get the directory, if no such directory, then make it. 11 | 12 | @param directory: The new directory. 13 | """ 14 | 15 | if not os.path.exists(directory): 16 | os.makedirs(directory) 17 | 18 | return directory 19 | 20 | 21 | def parser_args(): 22 | parser = argparse.ArgumentParser(description='Options to run the network.') 23 | parser.add_argument('-g', '--gpu', type=str, nargs='*', choices=['0', '1', '2', '3', 24 | '4', '5', '6', '7', '8', '9'], required=True, 25 | help='the device id of gpu.') 26 | parser.add_argument('-i', '--iters', type=int, default=1, 27 | help='set the number of iterations, default is 1') 28 | parser.add_argument('-b', '--batch', type=int, default=4, 29 | help='set the batch size, default is 4.') 30 | 31 | parser.add_argument('-d', '--dataset', type=str, 32 | help='the name of dataset.') 33 | 34 | parser.add_argument('-o', '--output_dir', type=str, default="./data/pretrains", 35 | help='the path of the output directory') 36 | 37 | parser.add_argument('--num_his', type=int, default=4, 38 | help='set the time steps, default is 4.') 39 | 40 | parser.add_argument('--prednet', type=str, default='cyclegan_convlstm', 41 | choices=['resnet_convlstm', 'cyclegan_convlstm', 'cyclegan_conv2d', 42 | 'resnet_conv3d', 'unet_conv2d', 'conv2d_deconv2d', 'MCNet', 43 | 'two_cyclegan_convlstm_classifier', 44 | 'unet_conv2d_instance_norm', 'cyclegan_convlstm_deconv1', 45 | 'two_cyclegan_convlstm_focal_loss', 46 | 'MLE_2_NN', 'MLE_2_SVM', 'MLE_1_SVM', 'Pred_1_SVM', 'TRI_1_SVM'], 47 | help='set the name of prediction network, default is cyclegan_convlstm') 48 | 49 | parser.add_argument('--label_level', type=str, default='temporal', choices=['normal', 'video', 'tune_video', 50 | 'temporal', 'tune_temporal', 51 | 'tune_video_temporal', 52 | 'temporal_mle_nn', 53 | 'temporal_mle_svm', 54 | 'pixel'], 55 | help='set the label level.') 56 | 57 | parser.add_argument('--k_folds', type=int, default=5, 58 | help='set the number of folds.') 59 | parser.add_argument('--kth', type=int, default=1, 60 | help='choose the kth fold.') 61 | parser.add_argument('--margin', type=float, default=1.0, help='value of margin.') 62 | 63 | parser.add_argument('--pretrain', type=str, default='', 64 | help='pretrained MLE-FFP, only using for feature extraction and training MLE-2NN,' 65 | 'MLE-2-SVM and MLE-1-SVM') 66 | parser.add_argument('--snapshot_dir', type=str, default='', 67 | help='if it is folder, then it is the directory to save models, ' 68 | 'if it is a specific model.ckpt-xxx, then the system will load it for testing.') 69 | parser.add_argument('--summary_dir', type=str, default='', help='the directory to save summaries.') 70 | parser.add_argument('--psnr_dir', type=str, default='', help='the directory to save psnrs results in testing.') 71 | 72 | parser.add_argument('--evaluate', type=str, default='compute_auc', 73 | help='the evaluation metric, default is compute_auc') 74 | 75 | parser.add_argument('--interpolation', action='store_true', help='use interpolation to increase fps or not.') 76 | parser.add_argument('--multi', action='store_true', help='use multi scale and crop or not') 77 | 78 | return parser.parse_args() 79 | 80 | 81 | class Const(object): 82 | class ConstError(TypeError): 83 | pass 84 | 85 | class ConstCaseError(ConstError): 86 | pass 87 | 88 | def __setattr__(self, name, value): 89 | if name in self.__dict__: 90 | raise self.ConstError("Can't change const.{}".format(name)) 91 | if not name.isupper(): 92 | raise self.ConstCaseError('const name {} is not all uppercase'.format(name)) 93 | 94 | self.__dict__[name] = value 95 | 96 | def __str__(self): 97 | _str = '<================ Constants information ================>\n' 98 | for name, value in self.__dict__.items(): 99 | _str += '\t{}\t{}\n'.format(name, value) 100 | 101 | return _str 102 | 103 | def set_margin(self, margin): 104 | self.__dict__['MARGIN'] = margin 105 | 106 | 107 | args = parser_args() 108 | const = Const() 109 | 110 | # inputs constants 111 | const.OUTPUT_DIR = args.output_dir 112 | const.DATASET = args.dataset 113 | const.K_FOLDS = args.k_folds 114 | const.KTH = args.kth 115 | const.LABEL_LEVEL = args.label_level 116 | 117 | const.GPUS = args.gpu 118 | 119 | const.BATCH_SIZE = args.batch 120 | const.NUM_HIS = args.num_his 121 | const.ITERATIONS = args.iters 122 | const.PREDNET = args.prednet 123 | const.EVALUATE = args.evaluate 124 | const.INTERPOLATION = args.interpolation 125 | const.MULTI = args.multi 126 | 127 | 128 | # set training hyper-parameters of different datasets 129 | config = configparser.ConfigParser() 130 | assert config.read(os.path.join('./data/hyper_params', '{}.ini'.format(const.LABEL_LEVEL))) 131 | 132 | const.NORMALIZE = config.getboolean(const.DATASET, 'NORMALIZE') 133 | const.HEIGHT = config.getint(const.DATASET, 'HEIGHT') 134 | const.WIDTH = config.getint(const.DATASET, 'WIDTH') 135 | const.TRAIN_FOLDER = config.get(const.DATASET, 'TRAIN_FOLDER') 136 | const.TEST_FOLDER = config.get(const.DATASET, 'TEST_FOLDER') 137 | const.FRAME_MASK = config.get(const.DATASET, 'FRAME_MASK') 138 | const.PIXEL_MASK = config.get(const.DATASET, 'PIXEL_MASK') 139 | 140 | if args.pretrain: 141 | const.PRETRAIN_MODEL = args.pretrain 142 | else: 143 | const.PRETRAIN_MODEL = config.get(const.DATASET, 'PRETRAIN_MODEL') 144 | 145 | const.PSNR_FILE = config.get(const.DATASET, 'PSNR_FILE') 146 | 147 | 148 | # const.MARGIN = config.getfloat(const.DATASET, 'MARGIN') 149 | const.MARGIN = args.margin 150 | const.LAMBDA = config.getfloat(const.DATASET, 'LAMBDA') 151 | 152 | const.LRATE_G = eval(config.get(const.DATASET, 'LRATE_G')) 153 | const.LRATE_G_BOUNDARIES = eval(config.get(const.DATASET, 'LRATE_G_BOUNDARIES')) 154 | 155 | const.INTERVAL = config.getint(const.DATASET, 'INTERVAL') 156 | const.MULTI_INTERVAL = config.getboolean(const.DATASET, 'MULTI_INTERVAL') 157 | 158 | const.MODEL_SAVE_FREQ = config.getint(const.DATASET, 'MODEL_SAVE_FREQ') 159 | 160 | if const.LABEL_LEVEL == 'normal': 161 | const.SAVE_DIR = '{label_level}/{dataset}/prednet_{PREDNET}'.format( 162 | label_level=const.LABEL_LEVEL, dataset=const.DATASET, PREDNET=const.PREDNET 163 | ) 164 | else: 165 | const.SAVE_DIR = '{label_level}/{dataset}/prednet_{PREDNET}_folds_{K_FOLDS}_kth_{KTH}_/MARGIN_{MARGIN}_' \ 166 | 'LAMBDA_{LAMBDA}'.format(label_level=const.LABEL_LEVEL, 167 | dataset=const.DATASET, PREDNET=const.PREDNET, 168 | MARGIN=const.MARGIN, LAMBDA=const.LAMBDA, 169 | K_FOLDS=const.K_FOLDS, KTH=const.KTH) 170 | 171 | if args.snapshot_dir: 172 | # if the snapshot_dir is model.ckpt-xxx, which means it is the single model for testing. 173 | if os.path.exists(args.snapshot_dir + '.meta') or os.path.exists(args.snapshot_dir + '.data-00000-of-00001') or \ 174 | os.path.exists(args.snapshot_dir + '.index'): 175 | const.SNAPSHOT_DIR = args.snapshot_dir 176 | else: 177 | const.SNAPSHOT_DIR = get_dir(args.snapshot_dir) 178 | else: 179 | const.SNAPSHOT_DIR = get_dir(os.path.join(const.OUTPUT_DIR, 'checkpoints', const.SAVE_DIR)) 180 | 181 | if args.summary_dir: 182 | const.SUMMARY_DIR = get_dir(args.summary_dir) 183 | else: 184 | const.SUMMARY_DIR = get_dir(os.path.join(const.OUTPUT_DIR, 'summary', const.SAVE_DIR)) 185 | 186 | if args.psnr_dir: 187 | const.PSNR_DIR = get_dir(args.psnr_dir) 188 | else: 189 | if const.INTERPOLATION: 190 | const.PSNR_DIR = get_dir(os.path.join(const.OUTPUT_DIR, 'psnrs', const.SAVE_DIR + '_interpolation')) 191 | else: 192 | const.PSNR_DIR = get_dir(os.path.join(const.OUTPUT_DIR, 'psnrs', const.SAVE_DIR)) 193 | 194 | 195 | -------------------------------------------------------------------------------- /data/annotations/avenue.json: -------------------------------------------------------------------------------- 1 | { 2 | "01": { 3 | "anomalies": [ 4 | {"running": [[78, 120], [392, 422], [868, 910], [1363, 1430]]}, 5 | {"intrucing into camera": [[503, 666], [932, 1101]]} 6 | ], 7 | "length": 1439 8 | }, 9 | 10 | "02": { 11 | "anomalies": [ 12 | {"running": [[273, 320], [724, 764]]} 13 | ], 14 | "length": 1211 15 | }, 16 | 17 | "03": { 18 | "anomalies": [ 19 | {"running": [[295, 340], [582, 622]]} 20 | ], 21 | "length": 923 22 | }, 23 | 24 | "04": { 25 | "anomalies": [ 26 | {"running": [[380, 428], [649, 692]]} 27 | ], 28 | "length": 947 29 | }, 30 | 31 | "05": { 32 | "anomalies": [ 33 | {"throwing object": [[469, 786]]} 34 | ], 35 | "length": 1007 36 | }, 37 | 38 | "06": { 39 | "anomalies": [ 40 | {"intrucing into camera": [[345, 625]]}, 41 | {"throwing object": [[815, 1007]]} 42 | ], 43 | "length": 1283 44 | }, 45 | 46 | "07": { 47 | "anomalies": [ 48 | {"jumping": [[300, 307], [423, 494], [563, 605]]} 49 | ], 50 | "length": 605 51 | }, 52 | 53 | "08": { 54 | "anomalies": [ 55 | {"jumping": [[1, 36]]} 56 | ], 57 | "length": 36 58 | }, 59 | 60 | "09": { 61 | "anomalies": [ 62 | {"intrucing into camera": [[136, 183], [741, 755]]}, 63 | {"jumping": [[496, 566]]}, 64 | {"throwing object": [[824, 1175]]} 65 | ], 66 | "length": 1175 67 | }, 68 | 69 | "10": { 70 | "anomalies": [ 71 | {"throwing object": [[550, 841]]} 72 | ], 73 | "length": 841 74 | }, 75 | 76 | "11": { 77 | "anomalies": [ 78 | {"throwing object": [[1, 220]]}, 79 | {"running": [[46, 90]]}, 80 | {"intrucing into camera": [[308, 346]]} 81 | ], 82 | "length": 472 83 | }, 84 | 85 | "12": { 86 | "anomalies": [ 87 | {"throwing object": [[490, 930]]} 88 | ], 89 | "length": 1271 90 | }, 91 | 92 | "13": { 93 | "anomalies": [ 94 | {"throwing object": [[228, 289], [395, 549]]} 95 | ], 96 | "length": 549 97 | }, 98 | 99 | "14": { 100 | "anomalies": [ 101 | {"throwing object": [[363, 507]]} 102 | ], 103 | "length": 507 104 | }, 105 | 106 | "15": { 107 | "anomalies": [ 108 | {"intrucing into camera": [[498, 587]]} 109 | ], 110 | "length": 1001 111 | }, 112 | 113 | "16": { 114 | "anomalies": [ 115 | {"bycicle": [[632, 740]]} 116 | ], 117 | "length": 740 118 | }, 119 | 120 | "17": { 121 | "anomalies": [ 122 | {"dancing": [[1, 56], [99, 426]]} 123 | ], 124 | "length": 426 125 | }, 126 | 127 | "18": { 128 | "anomalies": [ 129 | {"dancing": [[1, 294]]} 130 | ], 131 | "length": 294 132 | }, 133 | 134 | "19": { 135 | "anomalies": [ 136 | {"intrucing into camera": [[109, 248]]} 137 | ], 138 | "length": 248 139 | }, 140 | 141 | "20": { 142 | "anomalies": [ 143 | {"throwing object": [[65, 144], [168, 273]]} 144 | ], 145 | "length": 273 146 | }, 147 | 148 | "21":{ 149 | "anomalies": [ 150 | {"jumping": [[14, 66]]} 151 | ], 152 | "length": 76 153 | } 154 | } -------------------------------------------------------------------------------- /data/hyper_params/normal.ini: -------------------------------------------------------------------------------- 1 | [avenue] 2 | LAMBDA = 1.0 3 | LRATE_G = [0.0002, 0.00002] 4 | LRATE_G_BOUNDARIES = [80000] 5 | 6 | HEIGHT = 224 7 | WIDTH = 224 8 | NORMALIZE = False 9 | 10 | TRAIN_FOLDER = ./data/avenue/training/frames 11 | TEST_FOLDER = ./data/avenue/testing/frames 12 | FRAME_MASK = ./data/annotations/avenue.json 13 | PIXEL_MASK = 14 | 15 | # interval clip 16 | INTERVAL = 1 17 | # use multi interval or not 18 | MULTI_INTERVAL = True 19 | 20 | MODEL_SAVE_FREQ = 1000 21 | 22 | PRETRAIN_MODEL = 23 | PSNR_FILE = 24 | 25 | [shanghaitech] 26 | LAMBDA = 1.0 27 | LRATE_G = [0.0002, 0.00002] 28 | LRATE_G_BOUNDARIES = [50000] 29 | 30 | HEIGHT = 224 31 | WIDTH = 224 32 | NORMALIZE = True 33 | 34 | 35 | TRAIN_FOLDER = ./data/shanghaitech/training/frames 36 | TEST_FOLDER = ./data/shanghaitech/testing_scenes/frames 37 | FRAME_MASK = ./data/shanghaitech/testing_scenes/test_frame_mask 38 | PIXEL_MASK = 39 | 40 | # interval clip 41 | # INTERVAL = 2 42 | INTERVAL = 1 43 | # use multi interval or not 44 | MULTI_INTERVAL = False 45 | 46 | MODEL_SAVE_FREQ = 1000 47 | 48 | PRETRAIN_MODEL = 49 | PSNR_FILE = -------------------------------------------------------------------------------- /data/hyper_params/temporal.ini: -------------------------------------------------------------------------------- 1 | [avenue] 2 | LAMBDA = 1.0 3 | LRATE_G = [0.0002, 0.00002] 4 | 5 | # Focal loss 6 | ;LRATE_G = [0.00002, 0.000002] 7 | LRATE_G_BOUNDARIES = [80000] 8 | 9 | HEIGHT = 224 10 | WIDTH = 224 11 | NORMALIZE = False 12 | 13 | TRAIN_FOLDER = ./data/avenue/training/frames 14 | TEST_FOLDER = ./data/avenue/testing/frames 15 | FRAME_MASK = ./data/annotations/avenue.json 16 | PIXEL_MASK = 17 | 18 | # interval clip 19 | INTERVAL = 1 20 | # use multi interval frames data augmentation or not, such as [t, t + 1, t + 2, ...] [t, t + k, t + 2 * k, ...] 21 | MULTI_INTERVAL = True 22 | 23 | MODEL_SAVE_FREQ = 1000 24 | 25 | PRETRAIN_MODEL = 26 | PSNR_FILE = 27 | 28 | [shanghaitech] 29 | LAMBDA = 1.0 30 | LRATE_G = [0.0002, 0.00002] 31 | LRATE_G_BOUNDARIES = [80000] 32 | 33 | HEIGHT = 224 34 | WIDTH = 224 35 | NORMALIZE = True 36 | 37 | TRAIN_FOLDER = ./data/shanghaitech/training/frames 38 | TEST_FOLDER = ./data/shanghaitech/testing_scenes/frames 39 | FRAME_MASK = ./data/shanghaitech/testing_scenes/frame_masks 40 | PIXEL_MASK = ./data/shanghaitech/testing_scenes/pixel_masks 41 | 42 | # interval clip 43 | INTERVAL = 1 44 | # use multi interval frames data augmentation or not, such as [t, t + 1, t + 2, ...] [t, t + k, t + 2 * k, ...] 45 | MULTI_INTERVAL = False 46 | 47 | MODEL_SAVE_FREQ = 1000 48 | 49 | PRETRAIN_MODEL = 50 | PSNR_FILE = 51 | -------------------------------------------------------------------------------- /data/hyper_params/tune_video.ini: -------------------------------------------------------------------------------- 1 | [avenue] 2 | LAMBDA = 1.0 3 | LRATE_G = [0.0002, 0.00002] 4 | LRATE_G_BOUNDARIES = [160000] 5 | 6 | HEIGHT = 224 7 | WIDTH = 224 8 | NORMALIZE = False 9 | 10 | TRAIN_FOLDER = ./data/avenue/training/frames 11 | TEST_FOLDER = ./data/avenue/testing/frames 12 | FRAME_MASK = ./data/annotations/avenue.json 13 | PIXEL_MASK = 14 | 15 | # interval clip 16 | INTERVAL = 1 17 | # use multi interval or not 18 | MULTI_INTERVAL = True 19 | 20 | MODEL_SAVE_FREQ = 1000 21 | 22 | PRETRAIN_MODEL = ./data/pretrains/avenue/normal/checkpoints/model.ckpt-74000 23 | PSNR_FILE = ./data/pretrains/avenue/normal/psnrs/model.ckpt-74000 24 | 25 | [shanghaitech] 26 | LAMBDA = 1.0 27 | LRATE_G = [0.00002] 28 | LRATE_G_BOUNDARIES = [200000] 29 | 30 | HEIGHT = 224 31 | WIDTH = 224 32 | NORMALIZE = True 33 | 34 | TRAIN_FOLDER = ./data/shanghaitech/training/frames 35 | TEST_FOLDER = ./data/shanghaitech/testing_scenes/frames 36 | FRAME_MASK = ./data/shanghaitech/testing_scenes/frame_masks 37 | PIXEL_MASK = 38 | 39 | # interval clip 40 | INTERVAL = 1 41 | # use multi interval or not 42 | MULTI_INTERVAL = True 43 | 44 | MODEL_SAVE_FREQ = 1000 45 | 46 | PRETRAIN_MODEL = ./data/pretrains/shanghaitech/normal/checkpoints/model.ckpt-100000 47 | PSNR_FILE = ./data/pretrains/shanghaitech/normal/psnrs/model.ckpt-100000 48 | 49 | -------------------------------------------------------------------------------- /doc/only_normal_data.md: -------------------------------------------------------------------------------- 1 | 2 | # Training the model with only normal data. 3 | In this setting, we only have the normal videos. 4 | 5 | ```shell script 6 | python train_scripts/train_normal_annotation.py --dataset avenue \ 7 | --prednet cyclegan_convlstm \ 8 | --batch 2 \ 9 | --num_his 4 \ 10 | --label_level normal \ 11 | --gpu 0 \ 12 | --iters 80000 --output_dir ./outputs 13 | 14 | ``` 15 | 16 | # Inference and evaluation. 17 | After we train the model, we run the inference and evaluate all the checkpoints. 18 | If there a more than 2 GPUs, you can immediately run the inference scripts after run the training scripts, 19 | because the inference script is always listening the directory of the checkpoints, once there is a new 20 | checkpoint, it will evaluate it immediately. Here we use `gpu 0` for training, and `gpu 1` for testing. 21 | 22 | ```shell script 23 | python inference.py --dataset avenue \ 24 | --prednet cyclegan_convlstm \ 25 | --num_his 4 \ 26 | --label_level normal \ 27 | --gpu 1 \ 28 | --interpolation --output_dir ./outputs 29 | ``` -------------------------------------------------------------------------------- /doc/temporal_annotation.md: -------------------------------------------------------------------------------- 1 | 2 | # Training the model with temporal annotation 3 | In this setting, we have a large number of normal data, and a few of abnormal data with temporal annotation, where we 4 | know which frame is normal or abnormal. 5 | 6 | We perform `k-fold` cross validation on `avenue` and `shanghaitech` dataset. 7 | In `avenue` dataset, we **re-annotate** the labels of each frames, and the re-annotated file is [avenue.json]("../data/annotations/avenue.json"). 8 | We set `k=10`, in avenue dataset and in `shanghaiTech` dataset, we set `k=5`. 9 | The following script is an example to train the model on avenue dataset with the `kth = 1` folder. 10 | We change `kth` to other folders. 11 | 12 | ```shell script 13 | python train_scripts/train_temporal_annotation.py --dataset avenue \ 14 | --prednet cyclegan_convlstm \ 15 | --batch 2 \ 16 | --num_his 4 \ 17 | --label_level temporal \ 18 | --k_folds 10 \ 19 | --kth 1 \ 20 | --gpu 0 \ 21 | --iters 80000 --output_dir ./outputs 22 | ``` 23 | 24 | # Inference and evaluation. 25 | After we train the model, we run the inference and evaluate all the checkpoints. 26 | If there a more than 2 GPUs, you can immediately run the inference scripts after run the training scripts, 27 | because the inference script is always listening the directory of the checkpoints, once there is a new 28 | checkpoint, it will evaluate it immediately. Here we use `gpu 0` for training, and `gpu 1` for testing. 29 | 30 | ```shell script 31 | python inference.py --dataset avenue \ 32 | --prednet cyclegan_convlstm \ 33 | --num_his 4 \ 34 | --label_level temporal \ 35 | --k_folds 10 \ 36 | --kth 1 \ 37 | --gpu 1 \ 38 | --interpolation --output_dir ./outputs 39 | ``` -------------------------------------------------------------------------------- /doc/video_annotation.md: -------------------------------------------------------------------------------- 1 | 2 | # Training the model with temporal annotation 3 | In this setting, we have a large number of normal data, and a few of abnormal data with **video** annotation \-\- we 4 | know which video is normal or abnornal, but we do not know the anomalies happen in which frame. 5 | 6 | We perform `k-fold` cross validation on `avenue` and `shanghaitech` dataset. 7 | In `avenue` dataset, we **re-annotate** the labels of each frames, and the re-annotated file is [avenue.json]("../data/annotations/avenue.json"). 8 | We set `k=10`, in avenue dataset and in `shanghaiTech` dataset, we set `k=5`. 9 | The following script is an example to train the model on avenue dataset with the `kth = 1` folder. 10 | We change `kth` to other folders. 11 | 12 | ```shell script 13 | python train_scripts/train_tune_video_annotation.py --dataset avenue \ 14 | --prednet cyclegan_convlstm \ 15 | --batch 2 \ 16 | --num_his 4 \ 17 | --label_level tune_video \ 18 | --k_folds 10 \ 19 | --kth 1 \ 20 | --gpu 0 \ 21 | --iters 80000 --output_dir ./outputs 22 | ``` 23 | 24 | # Inference and evaluation. 25 | After we train the model, we run the inference and evaluate all the checkpoints. 26 | If there a more than 2 GPUs, you can immediately run the inference scripts after run the training scripts, 27 | because the inference script is always listening the directory of the checkpoints, once there is a new 28 | checkpoint, it will evaluate it immediately. Here we use `gpu 0` for training, and `gpu 1` for testing. 29 | 30 | ```shell script 31 | python inference.py --dataset avenue \ 32 | --prednet cyclegan_convlstm \ 33 | --num_his 4 \ 34 | --label_level tune_video \ 35 | --k_folds 10 \ 36 | --kth 1 \ 37 | --gpu 1 \ 38 | --interpolation --output_dir ./outputs 39 | ``` -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import argparse 4 | import pickle 5 | from sklearn import metrics 6 | from scipy import interpolate 7 | from math import factorial 8 | import json 9 | import re 10 | 11 | from constant import const 12 | 13 | 14 | NORMALIZE = const.NORMALIZE 15 | DECIDABLE_IDX = 0 16 | THRESHOLD = 200 17 | 18 | 19 | class RecordResult(object): 20 | def __init__(self, fpr=None, tpr=None, auc=-np.inf, dataset=None, loss_file=None): 21 | self.fpr = fpr 22 | self.tpr = tpr 23 | self.auc = auc 24 | self.dataset = dataset 25 | self.loss_file = loss_file 26 | 27 | def __lt__(self, other): 28 | return self.auc < other.auc 29 | 30 | def __gt__(self, other): 31 | return self.auc > other.auc 32 | 33 | def __str__(self): 34 | return 'dataset = {}, loss file = {}, auc = {}'.format(self.dataset, self.loss_file, self.auc) 35 | 36 | 37 | def temporal_annotation_to_label(annotations, length): 38 | label = np.zeros((length,), dtype=np.int8) 39 | 40 | for start, end in annotations: 41 | label[start - 1:end] = 1 42 | return label 43 | 44 | 45 | def load_video_scene(video_scene_file, k_folds, kth): 46 | scenes_videos_list = [] 47 | with open(video_scene_file, 'r') as reader: 48 | video_scenes_list = list(reader.readlines()) 49 | num_videos = len(video_scenes_list) 50 | 51 | if k_folds > 0: 52 | k_folds_ids = np.array_split(np.arange(num_videos), k_folds) 53 | test_ids = [] 54 | for k in range(k_folds): 55 | if k != (kth - 1): 56 | test_ids += k_folds_ids[k].tolist() 57 | else: 58 | test_ids = np.arange(num_videos) 59 | 60 | for ids in test_ids: 61 | video_scene = video_scenes_list[ids] 62 | video_scene = video_scene.rstrip() 63 | video, scene = video_scene.split() 64 | scenes_videos_list.append((video, scene)) 65 | 66 | return scenes_videos_list 67 | 68 | 69 | def load_gt_by_json(json_file, k_folds, kth): 70 | gts = [] 71 | with open(json_file, 'r') as file: 72 | data = json.load(file) 73 | video_names = list(sorted(data.keys())) 74 | 75 | num_videos = len(video_names) 76 | 77 | if k_folds > 0: 78 | k_folds_ids = np.array_split(np.arange(num_videos), k_folds) 79 | test_ids = [] 80 | for k in range(k_folds): 81 | if k != (kth - 1): 82 | test_ids += k_folds_ids[k].tolist() 83 | else: 84 | test_ids = np.arange(num_videos) 85 | 86 | for ids in test_ids: 87 | info = data[video_names[ids]] 88 | anomalies = info['anomalies'] 89 | length = info['length'] 90 | label = np.zeros((length,), dtype=np.int8) 91 | 92 | for event in anomalies: 93 | for name, annotation in event.items(): 94 | for start, end in annotation: 95 | label[start-1: end] = 1 96 | gts.append(label) 97 | 98 | return gts 99 | 100 | 101 | def load_semantic_frame_mask(json_file, k_folds, kth): 102 | semantic_gts = [] 103 | anomalies_names = [] 104 | anomalies_names_set = set() 105 | 106 | with open(json_file, 'r') as file: 107 | data = json.load(file) 108 | video_names = list(sorted(data.keys())) 109 | 110 | for video_name in video_names: 111 | info = data[video_name] 112 | anomalies = info['anomalies'] 113 | length = info['length'] 114 | semantic_label = ['normal'] * length 115 | name_set = set() 116 | 117 | for event in anomalies: 118 | for name, annotation in event.items(): 119 | for start, end in annotation: 120 | semantic_label[start:end] = [name] * (end - start) 121 | 122 | name_set.add(name) 123 | 124 | semantic_gts.append(semantic_label) 125 | anomalies_names.append(name_set) 126 | anomalies_names_set |= name_set 127 | 128 | print('gt json file = {}'.format(json_file)) 129 | 130 | num_videos = len(data.keys()) 131 | if k_folds != 0: 132 | k_folds_ids = np.array_split(np.arange(num_videos), k_folds) 133 | val_ids = k_folds_ids[kth - 1].tolist() 134 | test_ids = [] 135 | for k in range(k_folds): 136 | if k != (kth - 1): 137 | test_ids += k_folds_ids[k].tolist() 138 | else: 139 | val_ids = [] 140 | test_ids = np.arange(num_videos) 141 | 142 | # seen anomalies 143 | seen_anomalies = set() 144 | for v in val_ids: 145 | seen_anomalies |= anomalies_names[v] 146 | 147 | gts = [] 148 | for v in test_ids: 149 | gt = np.zeros(len(semantic_gts[v])) 150 | for t, label in enumerate(semantic_gts[v]): 151 | if label == 'normal': 152 | continue 153 | 154 | if label in seen_anomalies: 155 | gt[t] = 1 156 | else: 157 | gt[t] = 2 158 | gts.append(gt) 159 | 160 | return gts 161 | 162 | 163 | def load_psnr_gt(loss_file): 164 | with open(loss_file, 'rb') as reader: 165 | # results { 166 | # 'dataset': the name of dataset 167 | # 'psnr': the psnr of each testing videos, 168 | # } 169 | 170 | # psnr_records['psnr'] is np.array, shape(#videos) 171 | # psnr_records[0] is np.array ------> 01.avi 172 | # psnr_records[1] is np.array ------> 02.avi 173 | # ...... 174 | # psnr_records[n] is np.array ------> xx.avi 175 | 176 | results = pickle.load(reader) 177 | 178 | dataset = results['dataset'] 179 | psnr_records = results['psnr'] 180 | gts = results['frame_mask'] 181 | 182 | return dataset, psnr_records, gts 183 | 184 | 185 | def load_features_gt(loss_file): 186 | with open(loss_file, 'rb') as reader: 187 | # results { 188 | # 'dataset': the name of dataset 189 | # 'psnr': the psnr of each testing videos, 190 | # } 191 | 192 | # psnr_records['psnr'] is np.array, shape(#videos) 193 | # psnr_records[0] is np.array ------> 01.avi 194 | # psnr_records[1] is np.array ------> 02.avi 195 | # ...... 196 | # psnr_records[n] is np.array ------> xx.avi 197 | 198 | results = pickle.load(reader) 199 | 200 | dataset = results['dataset'] 201 | gts = results['frame_mask'] 202 | features = results['visualize']['feature'] 203 | 204 | f_means = [] 205 | for i, f in enumerate(features): 206 | length = len(gts[i]) 207 | sub_len = len(f) 208 | vol_size = int(np.ceil(length / sub_len)) 209 | f_m = np.mean(f, axis=(1, 2, 3)) 210 | # interpretation 211 | x_ids = np.arange(0, length, vol_size) 212 | x_ids[-1] = length - 1 213 | print(len(x_ids), sub_len) 214 | inter_func = interpolate.interp1d(x_ids, f_m) 215 | ids = np.arange(0, length) 216 | f_means.append(inter_func(ids)) 217 | 218 | return dataset, f_means, gts 219 | 220 | 221 | def load_psnr(loss_file): 222 | """ 223 | load image psnr or optical flow psnr. 224 | :param loss_file: loss file path 225 | :return: 226 | """ 227 | with open(loss_file, 'rb') as reader: 228 | # results { 229 | # 'dataset': the name of dataset 230 | # 'psnr': the psnr of each testing videos, 231 | # } 232 | 233 | # psnr_records['psnr'] is np.array, shape(#videos) 234 | # psnr_records[0] is np.array ------> 01.avi 235 | # psnr_records[1] is np.array ------> 02.avi 236 | # ...... 237 | # psnr_records[n] is np.array ------> xx.avi 238 | 239 | results = pickle.load(reader) 240 | psnrs = results['psnr'] 241 | return psnrs 242 | 243 | 244 | def get_scores_labels(loss_file): 245 | # the name of dataset, loss, and ground truth 246 | dataset, psnr_records, gt = load_psnr_gt(loss_file=loss_file) 247 | 248 | # the number of videos 249 | num_videos = len(psnr_records) 250 | 251 | scores = np.array([], dtype=np.float32) 252 | labels = np.array([], dtype=np.int8) 253 | # video normalization 254 | for i in range(num_videos): 255 | distance = psnr_records[i] 256 | 257 | if NORMALIZE: 258 | distance -= distance.min() # distances = (distance - min) / (max - min) 259 | distance /= distance.max() 260 | # distance = 1 - distance 261 | 262 | scores = np.concatenate((scores[:], distance[DECIDABLE_IDX:]), axis=0) 263 | labels = np.concatenate((labels[:], gt[i][DECIDABLE_IDX:]), axis=0) 264 | return dataset, scores, labels 265 | 266 | 267 | def precision_recall_auc(loss_file, *args): 268 | if not os.path.isdir(loss_file): 269 | loss_file_list = [loss_file] 270 | else: 271 | loss_file_list = os.listdir(loss_file) 272 | loss_file_list = [os.path.join(loss_file, sub_loss_file) for sub_loss_file in loss_file_list] 273 | 274 | optimal_results = RecordResult() 275 | for sub_loss_file in loss_file_list: 276 | dataset, scores, labels = get_scores_labels(sub_loss_file) 277 | precision, recall, thresholds = metrics.precision_recall_curve(labels, scores, pos_label=0) 278 | auc = metrics.auc(recall, precision) 279 | 280 | results = RecordResult(recall, precision, auc, dataset, sub_loss_file) 281 | 282 | if optimal_results < results: 283 | optimal_results = results 284 | 285 | if os.path.isdir(loss_file): 286 | print(results) 287 | print('##### optimal result and model = {}'.format(optimal_results)) 288 | return optimal_results 289 | 290 | 291 | def cal_eer(fpr, tpr): 292 | # makes fpr + tpr = 1 293 | eer = fpr[np.nanargmin(np.absolute((fpr + tpr - 1)))] 294 | return eer 295 | 296 | 297 | def compute_eer(loss_file, *args): 298 | if not os.path.isdir(loss_file): 299 | loss_file_list = [loss_file] 300 | else: 301 | loss_file_list = os.listdir(loss_file) 302 | loss_file_list = [os.path.join(loss_file, sub_loss_file) for sub_loss_file in loss_file_list] 303 | 304 | optimal_results = RecordResult(auc=np.inf) 305 | for sub_loss_file in loss_file_list: 306 | dataset, scores, labels = get_scores_labels(sub_loss_file) 307 | fpr, tpr, thresholds = metrics.roc_curve(labels, scores, pos_label=0) 308 | eer = cal_eer(fpr, tpr) 309 | 310 | results = RecordResult(fpr, tpr, eer, dataset, sub_loss_file) 311 | 312 | if optimal_results > results: 313 | optimal_results = results 314 | 315 | if os.path.isdir(loss_file): 316 | print(results) 317 | print('##### optimal result and model = {}'.format(optimal_results)) 318 | return optimal_results 319 | 320 | 321 | def compute_auc(loss_file, *args): 322 | if not os.path.isdir(loss_file): 323 | loss_file_list = [loss_file] 324 | else: 325 | loss_file_list = os.listdir(loss_file) 326 | loss_file_list = [os.path.join(loss_file, sub_loss_file) for sub_loss_file in loss_file_list] 327 | 328 | optimal_results = RecordResult() 329 | for sub_loss_file in loss_file_list: 330 | # the name of dataset, loss, and ground truth 331 | dataset, psnr_records, gt = load_psnr_gt(loss_file=sub_loss_file) 332 | 333 | # the number of videos 334 | num_videos = len(psnr_records) 335 | 336 | scores = np.array([], dtype=np.float32) 337 | labels = np.array([], dtype=np.int8) 338 | # video normalization 339 | for i in range(num_videos): 340 | distance = psnr_records[i] 341 | 342 | if NORMALIZE: 343 | distance -= distance.min() # distances = (distance - min) / (max - min) 344 | distance /= distance.max() 345 | # distance -= np.mean(distance) 346 | # distance /= np.std(distance) 347 | # print(distance.max(), distance.min()) 348 | 349 | scores = np.concatenate((scores, distance[DECIDABLE_IDX:]), axis=0) 350 | labels = np.concatenate((labels, gt[i][DECIDABLE_IDX:]), axis=0) 351 | 352 | fpr, tpr, thresholds = metrics.roc_curve(labels, scores, pos_label=0) 353 | auc = metrics.auc(fpr, tpr) 354 | 355 | results = RecordResult(fpr, tpr, auc, dataset, sub_loss_file) 356 | 357 | if optimal_results < results: 358 | optimal_results = results 359 | 360 | # if os.path.isdir(loss_file): 361 | # print(results) 362 | print('##### optimal result and model = {}'.format(optimal_results)) 363 | return optimal_results 364 | 365 | 366 | def calculate_auc(psnrs_records, gts): 367 | scores = np.array([], dtype=np.float32) 368 | labels = np.array([], dtype=np.int8) 369 | for psnrs, gt in zip(psnrs_records, gts): 370 | # invalid_index = np.logical_or(np.isnan(psnrs), np.isinf(psnrs)) 371 | # psnrs[invalid_index] = THRESHOLD + 1 372 | # 373 | # too_big_index = np.logical_or(invalid_index, psnrs > THRESHOLD) 374 | # not_too_big_index = np.logical_not(too_big_index) 375 | # 376 | # psnr_min = np.min(psnrs[not_too_big_index]) 377 | # psnr_max = np.max(psnrs[not_too_big_index]) 378 | # 379 | # psnrs[too_big_index] = psnr_max 380 | 381 | score = (psnrs - psnrs.min()) / (psnrs.max() - psnrs.min()) 382 | scores = np.concatenate((scores, score)) 383 | labels = np.concatenate((labels, gt)) 384 | 385 | fpr, tpr, thresholds = metrics.roc_curve(labels, scores, pos_label=0) 386 | auc = metrics.auc(fpr, tpr) 387 | # print('auc = {}'.format(auc)) 388 | 389 | 390 | def compute_valid_auc(loss_file, *args): 391 | if not os.path.isdir(loss_file): 392 | loss_file_list = [loss_file] 393 | else: 394 | loss_file_list = os.listdir(loss_file) 395 | loss_file_list = [os.path.join(loss_file, sub_loss_file) for sub_loss_file in loss_file_list] 396 | 397 | optimal_results = RecordResult() 398 | for sub_loss_file in loss_file_list: 399 | # the name of dataset, loss, and ground truth 400 | dataset, psnr_records, gt = load_psnr_gt(loss_file=sub_loss_file) 401 | 402 | # the number of videos 403 | num_videos = len(psnr_records) 404 | 405 | scores = np.array([], dtype=np.float32) 406 | labels = np.array([], dtype=np.int8) 407 | # video normalization 408 | for i in range(num_videos): 409 | psnrs = psnr_records[i] 410 | invalid_index = np.logical_or(np.isnan(psnrs), np.isinf(psnrs)) 411 | psnrs[invalid_index] = THRESHOLD + 1 412 | 413 | too_big_index = np.logical_or(invalid_index, psnrs > THRESHOLD) 414 | not_too_big_index = np.logical_not(too_big_index) 415 | 416 | psnr_min = np.min(psnrs[not_too_big_index]) 417 | psnr_max = np.max(psnrs[not_too_big_index]) 418 | 419 | psnrs[too_big_index] = psnr_max 420 | psnrs = filter_psnrs(psnrs) 421 | 422 | if NORMALIZE: 423 | psnrs = (psnrs - psnr_min) / (psnr_max - psnr_min) # distances = (distance - min) / (max - min) 424 | # distance -= np.mean(distance) 425 | # distance /= np.std(distance) 426 | # print(distance.max(), distance.min()) 427 | 428 | scores = np.concatenate((scores, psnrs[DECIDABLE_IDX:]), axis=0) 429 | labels = np.concatenate((labels, gt[i][DECIDABLE_IDX:]), axis=0) 430 | 431 | fpr, tpr, thresholds = metrics.roc_curve(labels, scores, pos_label=0) 432 | auc = metrics.auc(fpr, tpr) 433 | 434 | results = RecordResult(fpr, tpr, auc, dataset, sub_loss_file) 435 | 436 | if optimal_results < results: 437 | optimal_results = results 438 | 439 | if os.path.isdir(loss_file): 440 | print(results) 441 | print('##### optimal result and model = {}'.format(optimal_results)) 442 | return optimal_results 443 | 444 | 445 | def filter_psnrs(x): 446 | length = len(x) 447 | p = np.zeros(shape=(length,), dtype=np.float32) 448 | 449 | x_mean = np.max(x) 450 | mid_idx = length // 2 451 | p[mid_idx] = x[mid_idx] 452 | delta = length - mid_idx 453 | for i in range(mid_idx + 1, length): 454 | alpha = (i - mid_idx) / delta 455 | p[i] = alpha * x_mean + (1 - alpha) * x[i] 456 | 457 | for i in range(mid_idx - 1, -1, -1): 458 | alpha = (mid_idx - i) / delta 459 | p[i] = alpha * x_mean + (1 - alpha) * x[i] 460 | 461 | return p 462 | 463 | 464 | def filter_window(x, length): 465 | mid_idx = length // 2 466 | max_psnr = np.max(x) 467 | # max_psnr = np.mean(x) 468 | 469 | p = np.zeros(shape=x.shape, dtype=np.float32) 470 | p[mid_idx] = x[mid_idx] 471 | 472 | delta = length - mid_idx 473 | for i in range(mid_idx + 1, length): 474 | alpha = (i - mid_idx) / delta 475 | p[i] = alpha * max_psnr + (1 - alpha) * x[i] 476 | 477 | for i in range(mid_idx - 1, -1, -1): 478 | alpha = (mid_idx - i) / delta 479 | p[i] = alpha * max_psnr + (1 - alpha) * x[i] 480 | 481 | return p 482 | 483 | 484 | def filter_psnrs_2(x, window_size=128): 485 | length = len(x) 486 | 487 | window_size = min(length, window_size) 488 | p = np.empty(shape=(length,), dtype=np.float32) 489 | w_num = int(np.ceil(length / window_size)) 490 | 491 | for w in range(w_num): 492 | start = w * window_size 493 | end = min((w + 1) * window_size, length) 494 | p[start:end] = filter_window(x[start:end], length=end-start) 495 | 496 | return p 497 | 498 | 499 | def filter_psnrs_3(x, window_size=128): 500 | length = len(x) 501 | 502 | window_size = min(length, window_size) 503 | p = np.empty(shape=(length,), dtype=np.float32) 504 | w_num = int(np.ceil(length / window_size)) 505 | 506 | for w in range(w_num): 507 | start = w * window_size 508 | end = min((w + 1) * window_size, length) 509 | p[start:end] = filter_window(x[start:end], length=end-start) 510 | p[start:end] = (p[start:end] - p[start:end].min()) / (p[start:end].max() - p[start:end].min()) 511 | 512 | return p 513 | 514 | 515 | def savitzky_golay(y, window_size, order, deriv=0, rate=1): 516 | try: 517 | window_size = np.abs(np.int(window_size)) 518 | order = np.abs(np.int(order)) 519 | except ValueError: 520 | print("window_size and order have to be of type int") 521 | 522 | if window_size % 2 != 1 or window_size < 1: 523 | raise TypeError("window_size size must be a positive odd number") 524 | if window_size < order + 2: 525 | raise TypeError("window_size is too small for the polynomials order") 526 | order_range = range(order+1) 527 | half_window = (window_size -1) // 2 528 | # precompute coefficients 529 | b = np.mat([[k**i for i in order_range] for k in range(-half_window, half_window+1)]) 530 | m = np.linalg.pinv(b).A[deriv] * rate**deriv * factorial(deriv) 531 | # pad the signal at the extremes with 532 | # values taken from the signal itself 533 | firstvals = y[0] - np.abs(y[1:half_window+1][::-1] - y[0]) 534 | lastvals = y[-1] + np.abs(y[-half_window-1:-1][::-1] - y[-1]) 535 | y = np.concatenate((firstvals, y, lastvals)) 536 | return np.convolve(m[::-1], y, mode='valid') 537 | 538 | 539 | def compute_filter_auc(loss_file, *args): 540 | if not os.path.isdir(loss_file): 541 | loss_file_list = [loss_file] 542 | else: 543 | loss_file_list = os.listdir(loss_file) 544 | loss_file_list = [os.path.join(loss_file, sub_loss_file) for sub_loss_file in loss_file_list] 545 | 546 | optimal_results = RecordResult() 547 | for sub_loss_file in loss_file_list: 548 | # the name of dataset, loss, and ground truth 549 | dataset, psnr_records, gt = load_psnr_gt(loss_file=sub_loss_file) 550 | 551 | # the number of videos 552 | num_videos = len(psnr_records) 553 | 554 | scores = np.array([], dtype=np.float32) 555 | labels = np.array([], dtype=np.int8) 556 | # video normalization 557 | for i in range(num_videos): 558 | psnrs = psnr_records[i] 559 | invalid_index = np.logical_or(np.isnan(psnrs), np.isinf(psnrs)) 560 | psnrs[invalid_index] = THRESHOLD + 1 561 | 562 | too_big_index = np.logical_or(invalid_index, psnrs > THRESHOLD) 563 | not_too_big_index = np.logical_not(too_big_index) 564 | 565 | psnr_max = np.max(psnrs[not_too_big_index]) 566 | 567 | psnrs[too_big_index] = psnr_max 568 | psnrs = filter_psnrs(psnrs) 569 | 570 | # avenue 571 | # psnrs = filter_psnrs_3(psnrs, window_size=25) 572 | 573 | # shanghaitech 574 | # psnrs = filter_psnrs_2(psnrs, window_size=500) 575 | 576 | psnr_min = np.min(psnrs) 577 | psnr_max = np.max(psnrs) 578 | 579 | if NORMALIZE: 580 | psnrs = (psnrs - psnr_min) / (psnr_max - psnr_min) # distances = (distance - min) / (max - min) 581 | 582 | scores = np.concatenate((scores, psnrs[DECIDABLE_IDX:]), axis=0) 583 | labels = np.concatenate((labels, gt[i][DECIDABLE_IDX:]), axis=0) 584 | 585 | fpr, tpr, thresholds = metrics.roc_curve(labels, scores, pos_label=0) 586 | auc = metrics.auc(fpr, tpr) 587 | 588 | results = RecordResult(fpr, tpr, auc, dataset, sub_loss_file) 589 | 590 | if optimal_results < results: 591 | optimal_results = results 592 | 593 | if os.path.isdir(loss_file): 594 | print(results) 595 | print('##### optimal result and model = {}'.format(optimal_results)) 596 | return optimal_results 597 | 598 | 599 | def smooth_psnrs(x): 600 | length = x.shape[0] 601 | for i in range(1, length): 602 | x[i] = 0.99 * x[i-1] + 0.01 * x[i] 603 | return x 604 | 605 | 606 | def compute_scene_auc(loss_file, *args): 607 | if not os.path.isdir(loss_file): 608 | loss_file_list = [loss_file] 609 | else: 610 | loss_file_list = os.listdir(loss_file) 611 | loss_file_list = [os.path.join(loss_file, sub_loss_file) for sub_loss_file in loss_file_list] 612 | 613 | video_scene_file = args[0] 614 | pattern = re.compile('folds_([0-9]+)_kth_([0-9]+)') 615 | folds, kth = pattern.findall(loss_file)[0] 616 | folds, kth = int(folds), int(kth) 617 | 618 | video_scene_list = load_video_scene(video_scene_file, folds, kth) 619 | 620 | optimal_results = RecordResult() 621 | for sub_loss_file in loss_file_list: 622 | # the name of dataset, loss, and ground truth 623 | dataset, psnr_records, gt = load_psnr_gt(loss_file=sub_loss_file) 624 | 625 | # the number of videos 626 | num_videos = len(psnr_records) 627 | 628 | scene_psnrs_dict = {} 629 | scene_labels_dict = {} 630 | # video normalization 631 | for i in range(num_videos): 632 | psnrs = psnr_records[i] 633 | 634 | video, scene = video_scene_list[i] 635 | # print(video, scene) 636 | if scene not in scene_psnrs_dict: 637 | scene_psnrs_dict[scene] = [] 638 | scene_labels_dict[scene] = [] 639 | 640 | psnrs = filter_psnrs(psnrs) 641 | scene_psnrs_dict[scene].append(psnrs) 642 | scene_labels_dict[scene].append(gt[i]) 643 | 644 | # print(len(scene_psnrs_dict), len(scene_labels_dict)) 645 | 646 | if NORMALIZE: 647 | scores_list = [] 648 | labels_list = [] 649 | for scene in scene_psnrs_dict: 650 | psnrs = np.concatenate(scene_psnrs_dict[scene], axis=0) 651 | labels = np.concatenate(scene_labels_dict[scene], axis=0) 652 | # psnrs = filter_psnrs_2(psnrs, window_size=500) 653 | # psnrs = filter_psnrs(psnrs) 654 | # psnrs = savitzky_golay(psnrs, window_size=51, order=3) 655 | 656 | scores = (psnrs - psnrs.min()) / (psnrs.max() - psnrs.min()) 657 | scores_list.append(scores) 658 | labels_list.append(labels) 659 | 660 | scores = np.concatenate(scores_list, axis=0) 661 | labels = np.concatenate(labels_list, axis=0) 662 | fpr, tpr, thresholds = metrics.roc_curve(labels, scores, pos_label=0) 663 | auc = metrics.auc(fpr, tpr) 664 | results = RecordResult(fpr, tpr, auc, dataset, sub_loss_file) 665 | else: 666 | auc_list = [] 667 | for scene in scene_psnrs_dict: 668 | psnrs = np.concatenate(scene_psnrs_dict[scene], axis=0) 669 | labels = np.concatenate(scene_labels_dict[scene], axis=0) 670 | # psnrs = filter_psnrs_2(psnrs, window_size=51) 671 | # psnrs = filter_psnrs(psnrs) 672 | # psnrs = savitzky_golay(psnrs, window_size=51, order=3) 673 | # psnrs = smooth_psnrs(psnrs) 674 | scores = psnrs 675 | 676 | fpr, tpr, thresholds = metrics.roc_curve(labels, scores, pos_label=0) 677 | auc = metrics.auc(fpr, tpr) 678 | auc_list.append(auc) 679 | 680 | auc = np.mean(auc_list) 681 | 682 | results = RecordResult([], [], auc, dataset, sub_loss_file) 683 | 684 | if optimal_results < results: 685 | optimal_results = results 686 | 687 | if os.path.isdir(loss_file): 688 | print(results) 689 | print('##### optimal result and model = {}'.format(optimal_results)) 690 | return optimal_results 691 | 692 | 693 | def compute_seen_unseen_auc(loss_file, *args): 694 | # ipdb.set_trace() 695 | 696 | if not os.path.isdir(loss_file): 697 | loss_file_list = [loss_file] 698 | else: 699 | loss_file_list = os.listdir(loss_file) 700 | loss_file_list = [os.path.join(loss_file, sub_loss_file) for sub_loss_file in loss_file_list] 701 | 702 | gt_file = args[0] 703 | pattern = re.compile('folds_([0-9]+)_kth_([0-9]+)') 704 | folds, kth = pattern.findall(loss_file)[0] 705 | folds, kth = int(folds), int(kth) 706 | gt = load_semantic_frame_mask(gt_file, folds, kth) 707 | 708 | seen_optimal_results = RecordResult() 709 | unseen_optimal_results = RecordResult() 710 | for sub_loss_file in loss_file_list: 711 | # the name of dataset, loss, and ground truth 712 | dataset, psnr_records, _ = load_psnr_gt(loss_file=sub_loss_file) 713 | 714 | # the number of videos 715 | num_videos = len(psnr_records) 716 | 717 | seen_scores = np.array([], dtype=np.float32) 718 | seen_labels = np.array([], dtype=np.int8) 719 | 720 | unseen_scores = np.array([], dtype=np.float32) 721 | unseen_labels = np.array([], dtype=np.int8) 722 | 723 | # video normalization 724 | for i in range(num_videos): 725 | distance = psnr_records[i] 726 | 727 | if NORMALIZE: 728 | distance -= distance.min() # distances = (distance - min) / (max - min) 729 | distance /= distance.max() 730 | 731 | # seen idx 732 | seen_idx = gt[i] != 2 733 | # unseen idx 734 | unseen_idx = gt[i] != 1 735 | # seen_idx[0:DECIDABLE_IDX] = False 736 | # unseen_idx[0:DECIDABLE_IDX] = False 737 | 738 | seen_scores = np.concatenate((seen_scores, distance[seen_idx]), axis=0) 739 | seen_labels = np.concatenate((seen_labels, gt[i][seen_idx]), axis=0) 740 | 741 | unseen_scores = np.concatenate((unseen_scores, distance[unseen_idx]), axis=0) 742 | unseen_labels = np.concatenate((unseen_labels, gt[i][unseen_idx]), axis=0) 743 | 744 | seen_fpr, seen_tpr, _ = metrics.roc_curve(seen_labels, seen_scores, pos_label=0) 745 | seen_auc = metrics.auc(seen_fpr, seen_tpr) 746 | 747 | unseen_fpr, unseen_tpr, _ = metrics.roc_curve(unseen_labels, unseen_scores, pos_label=0) 748 | unseen_auc = metrics.auc(unseen_fpr, unseen_tpr) 749 | 750 | seen_results = RecordResult(seen_fpr, seen_tpr, seen_auc, dataset, sub_loss_file) 751 | unseen_results = RecordResult(unseen_fpr, unseen_tpr, unseen_auc, dataset, sub_loss_file) 752 | 753 | if seen_optimal_results < seen_results: 754 | seen_optimal_results = seen_results 755 | 756 | if unseen_optimal_results < unseen_results: 757 | unseen_optimal_results = unseen_results 758 | 759 | if os.path.isdir(loss_file): 760 | print('seen {}'.format(seen_results)) 761 | print('unseen {}'.format(unseen_results)) 762 | 763 | print('##### seen optimal result and model = {}'.format(seen_optimal_results)) 764 | print('##### unseen optimal result and model = {}'.format(unseen_optimal_results)) 765 | return seen_optimal_results 766 | 767 | 768 | def compute_auc_with_gt_file(loss_file, *args): 769 | if not os.path.isdir(loss_file): 770 | loss_file_list = [loss_file] 771 | else: 772 | loss_file_list = os.listdir(loss_file) 773 | loss_file_list = [os.path.join(loss_file, sub_loss_file) for sub_loss_file in loss_file_list] 774 | 775 | gt_file = args[0] 776 | pattern = re.compile('folds_([0-9]+)_kth_([0-9]+)') 777 | folds, kth = pattern.findall(loss_file)[0] 778 | folds, kth = int(folds), int(kth) 779 | gt = load_gt_by_json(gt_file, folds, kth) 780 | 781 | optimal_results = RecordResult() 782 | for sub_loss_file in loss_file_list: 783 | # the name of dataset, loss, and ground truth 784 | dataset, psnr_records, _ = load_psnr_gt(loss_file=sub_loss_file) 785 | 786 | # the number of videos 787 | num_videos = len(psnr_records) 788 | 789 | scores = np.array([], dtype=np.float32) 790 | labels = np.array([], dtype=np.int8) 791 | # video normalization 792 | for i in range(num_videos): 793 | distance = psnr_records[i] 794 | 795 | if NORMALIZE: 796 | distance -= distance.min() # distances = (distance - min) / (max - min) 797 | distance /= distance.max() 798 | 799 | scores = np.concatenate((scores, distance[DECIDABLE_IDX:]), axis=0) 800 | labels = np.concatenate((labels, gt[i][DECIDABLE_IDX:]), axis=0) 801 | 802 | fpr, tpr, thresholds = metrics.roc_curve(labels, scores, pos_label=0) 803 | auc = metrics.auc(fpr, tpr) 804 | 805 | results = RecordResult(fpr, tpr, auc, dataset, sub_loss_file) 806 | 807 | if optimal_results < results: 808 | optimal_results = results 809 | 810 | if os.path.isdir(loss_file): 811 | print(results) 812 | print('##### optimal result and model = {}'.format(optimal_results)) 813 | return optimal_results 814 | 815 | 816 | def compute_auc_with_threshold(loss_file, *args): 817 | if not os.path.isdir(loss_file): 818 | loss_file_list = [loss_file] 819 | else: 820 | loss_file_list = os.listdir(loss_file) 821 | loss_file_list = [os.path.join(loss_file, sub_loss_file) for sub_loss_file in loss_file_list] 822 | 823 | psnr = 30 824 | optimal_results = RecordResult() 825 | for sub_loss_file in loss_file_list: 826 | # the name of dataset, loss, and ground truth 827 | dataset, psnr_records, gt = load_psnr_gt(loss_file=sub_loss_file) 828 | 829 | # the number of videos 830 | num_videos = len(psnr_records) 831 | 832 | scores = np.array([], dtype=np.float32) 833 | labels = np.array([], dtype=np.int8) 834 | # video normalization 835 | for i in range(num_videos): 836 | distance = psnr_records[i] 837 | # less_thresholds = distance < psnr 838 | # distance[less_thresholds] = distance.min() 839 | 840 | if NORMALIZE: 841 | distance -= distance.min() # distances = (distance - min) / (max - min) 842 | distance /= distance.max() 843 | 844 | scores = np.concatenate((scores, distance[DECIDABLE_IDX:]), axis=0) 845 | labels = np.concatenate((labels, gt[i][DECIDABLE_IDX:]), axis=0) 846 | 847 | fpr, tpr, thresholds = metrics.roc_curve(labels, scores, pos_label=0) 848 | auc = metrics.auc(fpr, tpr) 849 | 850 | results = RecordResult(fpr, tpr, auc, dataset, sub_loss_file) 851 | 852 | if optimal_results < results: 853 | optimal_results = results 854 | 855 | if os.path.isdir(loss_file): 856 | print(results) 857 | print('##### optimal result and model = {}'.format(optimal_results)) 858 | return optimal_results 859 | 860 | 861 | def compute_feature_auc(loss_file, *args): 862 | if not os.path.isdir(loss_file): 863 | loss_file_list = [loss_file] 864 | else: 865 | loss_file_list = os.listdir(loss_file) 866 | loss_file_list = [os.path.join(loss_file, sub_loss_file) for sub_loss_file in loss_file_list] 867 | 868 | optimal_results = RecordResult() 869 | for sub_loss_file in loss_file_list: 870 | # the name of dataset, loss, and ground truth 871 | dataset, f_means, gt = load_features_gt(loss_file=sub_loss_file) 872 | 873 | # the number of videos 874 | num_videos = len(f_means) 875 | 876 | scores = np.array([], dtype=np.float32) 877 | labels = np.array([], dtype=np.int8) 878 | # video normalization 879 | for i in range(num_videos): 880 | distance = f_means[i] 881 | 882 | if NORMALIZE: 883 | distance -= distance.min() # distances = (distance - min) / (max - min) 884 | distance /= distance.max() 885 | # distance -= np.mean(distance) 886 | # distance /= np.std(distance) 887 | # print(distance.max(), distance.min()) 888 | 889 | scores = np.concatenate((scores, distance[DECIDABLE_IDX:]), axis=0) 890 | labels = np.concatenate((labels, gt[i][DECIDABLE_IDX:]), axis=0) 891 | 892 | fpr, tpr, thresholds = metrics.roc_curve(labels, scores, pos_label=0) 893 | auc = metrics.auc(fpr, tpr) 894 | 895 | results = RecordResult(fpr, tpr, auc, dataset, sub_loss_file) 896 | 897 | if optimal_results < results: 898 | optimal_results = results 899 | 900 | if os.path.isdir(loss_file): 901 | print(results) 902 | print('##### optimal result and model = {}'.format(optimal_results)) 903 | return optimal_results 904 | 905 | 906 | def average_psnr(loss_file, *args): 907 | if not os.path.isdir(loss_file): 908 | loss_file_list = [loss_file] 909 | else: 910 | loss_file_list = os.listdir(loss_file) 911 | loss_file_list = [os.path.join(loss_file, sub_loss_file) for sub_loss_file in loss_file_list] 912 | 913 | max_avg_psnr = -np.inf 914 | max_file = '' 915 | for file in loss_file_list: 916 | psnr_records = load_psnr(file) 917 | 918 | psnr_records = np.concatenate(psnr_records, axis=0) 919 | avg_psnr = np.mean(psnr_records) 920 | if max_avg_psnr < avg_psnr: 921 | max_avg_psnr = avg_psnr 922 | max_file = file 923 | print('{}, average psnr = {}'.format(file, avg_psnr)) 924 | 925 | print('max average psnr file = {}, psnr = {}'.format(max_file, max_avg_psnr)) 926 | 927 | 928 | def calculate_psnr(loss_file, *args): 929 | optical_result = compute_auc(loss_file) 930 | print('##### optimal result and model = {}'.format(optical_result)) 931 | 932 | mean_psnr = [] 933 | for file in os.listdir(loss_file): 934 | file = os.path.join(loss_file, file) 935 | dataset, psnr_records, gt = load_psnr_gt(file) 936 | 937 | psnr_records = np.concatenate(psnr_records, axis=0) 938 | gt = np.concatenate(gt, axis=0) 939 | 940 | mean_normal_psnr = np.mean(psnr_records[gt == 0]) 941 | mean_abnormal_psnr = np.mean(psnr_records[gt == 1]) 942 | mean = np.mean(psnr_records) 943 | print('mean normal psrn = {}, mean abnormal psrn = {}, mean = {}'.format( 944 | mean_normal_psnr, 945 | mean_abnormal_psnr, 946 | mean) 947 | ) 948 | mean_psnr.append(mean) 949 | print('max mean psnr = {}'.format(np.max(mean_psnr))) 950 | 951 | 952 | def calculate_score(loss_file, *args): 953 | if not os.path.isdir(loss_file): 954 | loss_file_path = loss_file 955 | else: 956 | optical_result = compute_auc(loss_file) 957 | loss_file_path = optical_result.loss_file 958 | print('##### optimal result and model = {}'.format(optical_result)) 959 | dataset, psnr_records, gt = load_psnr_gt(loss_file=loss_file_path) 960 | 961 | # the number of videos 962 | num_videos = len(psnr_records) 963 | 964 | scores = np.array([], dtype=np.float32) 965 | labels = np.array([], dtype=np.int8) 966 | # video normalization 967 | for i in range(num_videos): 968 | distance = psnr_records[i] 969 | 970 | distance = (distance - distance.min()) / (distance.max() - distance.min()) 971 | 972 | scores = np.concatenate((scores, distance[DECIDABLE_IDX:]), axis=0) 973 | labels = np.concatenate((labels, gt[i][DECIDABLE_IDX:]), axis=0) 974 | 975 | mean_normal_scores = np.mean(scores[labels == 0]) 976 | mean_abnormal_scores = np.mean(scores[labels == 1]) 977 | print('mean normal scores = {}, mean abnormal scores = {}, ' 978 | 'delta = {}'.format(mean_normal_scores, mean_abnormal_scores, mean_normal_scores - mean_abnormal_scores)) 979 | 980 | 981 | eval_type_function = { 982 | 'compute_auc': compute_auc, 983 | 'compute_eer': compute_eer, 984 | 'compute_valid_auc': compute_valid_auc, 985 | 'compute_filter_auc': compute_filter_auc, 986 | 'compute_scene_auc': compute_scene_auc, 987 | 'compute_seen_unseen_auc': compute_seen_unseen_auc, 988 | 'precision_recall_auc': precision_recall_auc, 989 | 'calculate_psnr': calculate_psnr, 990 | 'calculate_score': calculate_score, 991 | 'average_psnr': average_psnr, 992 | 'average_psnr_sample': average_psnr, 993 | 'compute_auc_with_gt_file': compute_auc_with_gt_file, 994 | 'compute_feature_auc': compute_feature_auc, 995 | 'compute_auc_with_threshold': compute_auc_with_threshold 996 | } 997 | 998 | 999 | def evaluate(eval_type, save_file, gt_file=''): 1000 | assert eval_type in eval_type_function, 'there is no type of evaluation {}, please check {}' \ 1001 | .format(eval_type, eval_type_function.keys()) 1002 | eval_func = eval_type_function[eval_type] 1003 | optimal_results = eval_func(save_file, gt_file) 1004 | return optimal_results 1005 | 1006 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | import time 4 | import numpy as np 5 | import pickle 6 | from scipy import interpolate 7 | 8 | from constant import const 9 | from models import prediction_networks_dict 10 | from utils.dataloaders.test_loader import DataTemporalGtLoader 11 | from utils.util import psnr_error, load 12 | 13 | import evaluate 14 | 15 | os.environ['CUDA_DEVICES_ORDER'] = "PCI_BUS_ID" 16 | os.environ['CUDA_VISIBLE_DEVICES'] = const.GPUS[0] 17 | 18 | dataset_name = const.DATASET 19 | train_folder = const.TRAIN_FOLDER 20 | test_folder = const.TEST_FOLDER 21 | frame_mask = const.FRAME_MASK 22 | pixel_mask = const.PIXEL_MASK 23 | k_folds = const.K_FOLDS 24 | kth = const.KTH 25 | interval = const.INTERVAL 26 | 27 | batch_size = const.BATCH_SIZE 28 | iterations = const.ITERATIONS 29 | num_his = const.NUM_HIS 30 | height, width = const.HEIGHT, const.WIDTH 31 | 32 | prednet = prediction_networks_dict[const.PREDNET] 33 | evaluate_name = const.EVALUATE 34 | 35 | margin = const.MARGIN 36 | lam = const.LAMBDA 37 | 38 | summary_dir = const.SUMMARY_DIR 39 | snapshot_dir = const.SNAPSHOT_DIR 40 | psnr_dir = const.PSNR_DIR 41 | 42 | print(const) 43 | 44 | # define dataset 45 | # noinspection PyUnboundLocalVariable 46 | with tf.name_scope('dataset'): 47 | video_clips_tensor = tf.placeholder(shape=[1, (num_his + 1), height, width, 3], dtype=tf.float32) 48 | inputs = video_clips_tensor[:, 0:num_his, ...] 49 | frame_gts = video_clips_tensor[:, -1, ...] 50 | 51 | # define training generator function 52 | with tf.variable_scope('generator', reuse=None): 53 | outputs, features, _ = prednet(inputs=inputs, use_decoder=True) 54 | psnr_tensor = psnr_error(outputs, frame_gts) 55 | 56 | config = tf.ConfigProto() 57 | config.gpu_options.allow_growth = True 58 | with tf.Session(config=config) as sess: 59 | # dataset 60 | data_loader = DataTemporalGtLoader(dataset=dataset_name, folder=test_folder, k_folds=k_folds, kth=kth, 61 | frame_mask_file=frame_mask, pixel_mask_file=pixel_mask, 62 | resize_height=height, resize_width=width) 63 | video_info = data_loader.test_videos_info 64 | frame_masks = data_loader.get_frame_mask 65 | num_videos = len(video_info) 66 | 67 | # initialize weights 68 | sess.run(tf.global_variables_initializer()) 69 | print('Init global successfully!') 70 | 71 | restore_var = [v for v in tf.global_variables()] 72 | loader = tf.train.Saver(var_list=restore_var) 73 | 74 | def inference_func(ckpt, dataset_name, evaluate_name): 75 | load(loader, sess, ckpt) 76 | 77 | psnr_records = [] 78 | total = 0 79 | timestamp = time.time() 80 | 81 | if const.INTERPOLATION: 82 | vol_size = num_his + 1 83 | for v_id, (video_name, video) in enumerate(video_info.items()): 84 | length = video['length'] 85 | total += length 86 | gts = frame_masks[v_id] 87 | 88 | x_ids = np.arange(0, length, vol_size) 89 | x_ids[-1] = length - 1 90 | psnrs_ids = np.empty(shape=(len(x_ids),), dtype=np.float32) 91 | 92 | for i, t in enumerate(x_ids): 93 | if t == length - 1: 94 | start = length - vol_size 95 | end = length 96 | else: 97 | start = t 98 | end = t + vol_size 99 | 100 | video_clip = data_loader.get_video_clip(video_name, start, end) 101 | psnr = sess.run(psnr_tensor, feed_dict={video_clips_tensor: video_clip[np.newaxis, ...]}) 102 | psnrs_ids[i] = psnr 103 | 104 | print('video = {} / {}, i = {} / {}, psnr = {:.6f}, gt = {}'.format( 105 | video_name, num_videos, t, length, psnr, gts[end - 1])) 106 | 107 | # interpretation 108 | inter_func = interpolate.interp1d(x_ids, psnrs_ids) 109 | ids = np.arange(0, length) 110 | psnrs = inter_func(ids) 111 | psnr_records.append(psnrs) 112 | 113 | else: 114 | for v_id, (video_name, video) in enumerate(video_info.items()): 115 | length = video['length'] 116 | total += length 117 | psnrs = np.empty(shape=(length,), dtype=np.float32) 118 | gts = frame_masks[v_id] 119 | 120 | for i in range(num_his, length): 121 | video_clip = data_loader.get_video_clip(video_name, i - num_his, i + 1) 122 | psnr = sess.run(psnr_tensor, feed_dict={video_clips_tensor: video_clip[np.newaxis, ...]}) 123 | psnrs[i] = psnr 124 | 125 | print('video = {} / {}, i = {} / {}, psnr = {:.6f}, gt = {}'.format( 126 | video_name, num_videos, i, length, psnr, gts[i])) 127 | 128 | psnrs[0:num_his] = psnrs[num_his] 129 | psnr_records.append(psnrs) 130 | 131 | result_dict = {'dataset': dataset_name, 'psnr': psnr_records, 'diff_mask': [], 'frame_mask': frame_masks} 132 | 133 | used_time = time.time() - timestamp 134 | print('total time = {}, fps = {}'.format(used_time, total / used_time)) 135 | 136 | # TODO specify what's the actual name of ckpt. 137 | pickle_path = os.path.join(psnr_dir, os.path.split(ckpt)[-1]) 138 | with open(pickle_path, 'wb') as writer: 139 | pickle.dump(result_dict, writer, pickle.HIGHEST_PROTOCOL) 140 | 141 | results = evaluate.evaluate(evaluate_name, pickle_path) 142 | print(results) 143 | 144 | 145 | if os.path.isdir(snapshot_dir): 146 | def check_ckpt_valid(ckpt_name): 147 | is_valid = False 148 | ckpt = '' 149 | if ckpt_name.startswith('model.ckpt-'): 150 | ckpt_name_splits = ckpt_name.split('.') 151 | ckpt = str(ckpt_name_splits[0]) + '.' + str(ckpt_name_splits[1]) 152 | ckpt_path = os.path.join(snapshot_dir, ckpt) 153 | if os.path.exists(ckpt_path + '.index') and os.path.exists(ckpt_path + '.meta') and \ 154 | os.path.exists(ckpt_path + '.data-00000-of-00001'): 155 | is_valid = True 156 | 157 | return is_valid, ckpt 158 | 159 | def scan_psnr_folder(): 160 | tested_ckpt_in_psnr_sets = set() 161 | for test_psnr in os.listdir(psnr_dir): 162 | tested_ckpt_in_psnr_sets.add(test_psnr) 163 | return tested_ckpt_in_psnr_sets 164 | 165 | def scan_model_folder(): 166 | saved_models = set() 167 | for ckpt_name in os.listdir(snapshot_dir): 168 | is_valid, ckpt = check_ckpt_valid(ckpt_name) 169 | if is_valid: 170 | saved_models.add(ckpt) 171 | return saved_models 172 | 173 | tested_ckpt_sets = scan_psnr_folder() 174 | while True: 175 | all_model_ckpts = scan_model_folder() 176 | new_model_ckpts = all_model_ckpts - tested_ckpt_sets 177 | 178 | for ckpt_name in new_model_ckpts: 179 | # inference 180 | ckpt = os.path.join(snapshot_dir, ckpt_name) 181 | inference_func(ckpt, dataset_name, evaluate_name) 182 | 183 | tested_ckpt_sets.add(ckpt_name) 184 | 185 | print('waiting for models...') 186 | evaluate.evaluate('compute_auc', psnr_dir) 187 | time.sleep(300) 188 | else: 189 | inference_func(snapshot_dir, dataset_name, evaluate_name) 190 | 191 | 192 | 193 | 194 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .networks import ( 2 | resnet_convlstm, 3 | cyclegan_convlstm, 4 | cyclegan_convlstm_deconv1, 5 | resnet_conv3d, 6 | unet_conv2d, 7 | conv2d_deconv2d, 8 | cyclegan_conv2d, 9 | two_cyclegan_convlstm_classifier, 10 | unet_conv2d_instance_norm, 11 | ) 12 | 13 | prediction_networks_dict = { 14 | 'resnet_convlstm': resnet_convlstm, 15 | 'cyclegan_convlstm': cyclegan_convlstm, 16 | 'cyclegan_convlstm_deconv1': cyclegan_convlstm_deconv1, 17 | 'resnet_conv3d': resnet_conv3d, 18 | 'unet_conv2d': unet_conv2d, 19 | 'conv2d_deconv2d': conv2d_deconv2d, 20 | 'cyclegan_conv2d': cyclegan_conv2d, 21 | 'two_cyclegan_convlstm_classifier': two_cyclegan_convlstm_classifier, 22 | 'two_cyclegan_convlstm_focal_loss': two_cyclegan_convlstm_classifier, 23 | 'unet_conv2d_instance_norm': unet_conv2d_instance_norm 24 | } 25 | -------------------------------------------------------------------------------- /models/networks.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.layers as tf_layers 3 | import numpy as np 4 | 5 | from models import pix2pix 6 | 7 | 8 | def cyclegan_arg_scope(instance_norm_center=True, 9 | instance_norm_scale=True, 10 | instance_norm_epsilon=0.001, 11 | weights_init_stddev=0.02, 12 | weight_decay=0.0): 13 | """Returns a default argument scope for all generators and discriminators. 14 | Args: 15 | instance_norm_center: Whether instance normalization applies centering. 16 | instance_norm_scale: Whether instance normalization applies scaling. 17 | instance_norm_epsilon: Small float added to the variance in the instance 18 | normalization to avoid dividing by zero. 19 | weights_init_stddev: Standard deviation of the random values to initialize 20 | the convolution kernels with. 21 | weight_decay: Magnitude of weight decay applied to all convolution kernel 22 | variables of the generator. 23 | Returns: 24 | An arg-scope. 25 | """ 26 | instance_norm_params = { 27 | 'center': instance_norm_center, 28 | 'scale': instance_norm_scale, 29 | 'epsilon': instance_norm_epsilon, 30 | } 31 | 32 | weights_regularizer = None 33 | if weight_decay and weight_decay > 0.0: 34 | weights_regularizer = tf_layers.l2_regularizer(weight_decay) 35 | 36 | with tf.contrib.framework.arg_scope( 37 | [tf_layers.conv2d, tf_layers.conv3d], 38 | normalizer_fn=tf_layers.instance_norm, 39 | normalizer_params=instance_norm_params, 40 | weights_initializer=tf.random_normal_initializer(0, weights_init_stddev), 41 | weights_regularizer=weights_regularizer) as sc: 42 | return sc 43 | 44 | 45 | def cyclegan_upsample(net, num_outputs, stride, method='conv2d_transpose'): 46 | """Upsamples the given inputs. 47 | Args: 48 | net: A Tensor of size [batch_size, height, width, filters]. 49 | num_outputs: The number of output filters. 50 | stride: A list of 2 scalars or a 1x2 Tensor indicating the scale, 51 | relative to the inputs, of the output dimensions. For example, if kernel 52 | size is [2, 3], then the output height and width will be twice and three 53 | times the input size. 54 | method: The upsampling method: 'nn_upsample_conv', 'bilinear_upsample_conv', 55 | or 'conv2d_transpose'. 56 | Returns: 57 | A Tensor which was upsampled using the specified method. 58 | Raises: 59 | ValueError: if `method` is not recognized. 60 | """ 61 | with tf.variable_scope('upconv'): 62 | net_shape = tf.shape(net) 63 | height = net_shape[1] 64 | width = net_shape[2] 65 | 66 | # Reflection pad by 1 in spatial dimensions (axes 1, 2 = h, w) to make a 3x3 67 | # 'valid' convolution produce an output with the same dimension as the 68 | # input. 69 | spatial_pad_1 = np.array([[0, 0], [1, 1], [1, 1], [0, 0]]) 70 | 71 | if method == 'nn_upsample_conv': 72 | net = tf.image.resize_nearest_neighbor( 73 | net, [stride[0] * height, stride[1] * width]) 74 | net = tf.pad(net, spatial_pad_1, 'REFLECT') 75 | net = tf_layers.conv2d(net, num_outputs, kernel_size=[3, 3], padding='valid') 76 | if method == 'bilinear_upsample_conv': 77 | net = tf.image.resize_bilinear( 78 | net, [stride[0] * height, stride[1] * width]) 79 | net = tf.pad(net, spatial_pad_1, 'REFLECT') 80 | net = tf_layers.conv2d(net, num_outputs, kernel_size=[3, 3], padding='valid') 81 | elif method == 'conv2d_transpose': 82 | net = tf_layers.conv2d_transpose( 83 | net, num_outputs, kernel_size=[3, 3], stride=stride, padding='same') 84 | else: 85 | raise ValueError('Unknown method: [%s]', method) 86 | 87 | return net 88 | 89 | 90 | def tensor_split_times(tensor, batch_size, time_steps): 91 | """ 92 | :param tensor: (N x T) x h x w x c 93 | :return: N x T x h x w x c 94 | """ 95 | return tf.reshape(tensor, shape=[batch_size, time_steps] + tensor.get_shape().as_list()[1:]) 96 | 97 | 98 | def tensor_fuse_times(tensor, batch_size, time_steps): 99 | """ 100 | :param tensor: N x T x h x w x c 101 | :return: (N x T) x h x w x c 102 | """ 103 | return tf.reshape(tensor, shape=[batch_size * time_steps] + tensor.get_shape().as_list()[2:]) 104 | 105 | 106 | def tensor_stack_times(tensor, time_steps): 107 | split_times = [] 108 | for i in range(time_steps): 109 | split_times.append(tensor[:, i, ...]) 110 | 111 | return tf.concat(split_times, axis=3) 112 | 113 | 114 | def recurrent_model(net_split_times, spatial_shape, batch_size, num_outputs): 115 | 116 | conv_lstm_cell = tf.contrib.rnn.ConvLSTMCell(conv_ndims=2, input_shape=spatial_shape, 117 | output_channels=num_outputs, kernel_shape=[3, 3]) 118 | initial_state = conv_lstm_cell.zero_state(batch_size, dtype=tf.float32) 119 | final_outputs, final_states = tf.nn.dynamic_rnn(cell=conv_lstm_cell, inputs=net_split_times, 120 | initial_state=initial_state, time_major=False, 121 | scope='ConvLSTM') 122 | # take the T time step output 123 | hidden_state = final_states[-1] 124 | 125 | return hidden_state 126 | 127 | 128 | def resnet_convlstm(inputs, num_filters=64, 129 | upsample_fn=cyclegan_upsample, 130 | kernel_size=3, 131 | num_outputs=3, 132 | tanh_linear_slope=0.0, 133 | use_decoder=False): 134 | 135 | batch_size, time_steps, height, width, channel = inputs.get_shape().as_list() 136 | 137 | inputs = tensor_fuse_times(inputs, batch_size, time_steps) 138 | 139 | end_points = {} 140 | 141 | if height and height % 4 != 0: 142 | raise ValueError('The input height must be a multiple of 4.') 143 | if width and width % 4 != 0: 144 | raise ValueError('The input width must be a multiple of 4.') 145 | 146 | if not isinstance(kernel_size, (list, tuple)): 147 | kernel_size = [kernel_size, kernel_size] 148 | 149 | kernel_height = kernel_size[0] 150 | kernel_width = kernel_size[1] 151 | pad_top = (kernel_height - 1) // 2 152 | pad_bottom = kernel_height // 2 153 | pad_left = (kernel_width - 1) // 2 154 | pad_right = kernel_width // 2 155 | paddings = np.array( 156 | [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]], 157 | dtype=np.int32) 158 | spatial_pad_3 = np.array([[0, 0], [3, 3], [3, 3], [0, 0]]) 159 | 160 | with tf.contrib.framework.arg_scope(cyclegan_arg_scope()): 161 | 162 | ########### 163 | # Encoder # 164 | ########### 165 | with tf.variable_scope('encoder'): 166 | # 7x7 input stage, 224 x 224 x 64 167 | net = tf.pad(inputs, spatial_pad_3, 'REFLECT') 168 | net = tf_layers.conv2d(net, num_filters, kernel_size=[7, 7], padding='VALID') 169 | end_points['resnet_stage_0'] = net 170 | 171 | # 3x3 state, 112 x 112 x 128 172 | net = tf.pad(net, paddings, 'REFLECT') 173 | net = tf_layers.conv2d(net, num_filters * 2, kernel_size=kernel_size, stride=2, 174 | activation_fn=tf.nn.relu, padding='VALID') 175 | end_points['resnet_stage_1'] = net 176 | 177 | ######################################### 178 | # 3 Residual Blocks with 56 x 56 x 256 # 179 | ######################################### 180 | with tf.variable_scope('residual_blocks'): 181 | with tf.contrib.framework.arg_scope( 182 | [tf_layers.conv2d], 183 | kernel_size=kernel_size, 184 | stride=1, 185 | activation_fn=tf.nn.relu, 186 | padding='VALID'): 187 | net = tf.pad(net, paddings, 'REFLECT') 188 | net = tf_layers.conv2d(net, num_filters * 4, kernel_size=kernel_size, stride=2, 189 | activation_fn=tf.nn.relu, padding='VALID') 190 | end_points['resnet_stage_2'] = net 191 | for block_id in range(3): 192 | with tf.variable_scope('stage_2_block_{}'.format(block_id)): 193 | res_net = tf.pad(net, paddings, 'REFLECT') 194 | res_net = tf_layers.conv2d(res_net, num_filters * 4) 195 | res_net = tf.pad(res_net, paddings, 'REFLECT') 196 | res_net = tf_layers.conv2d(res_net, num_filters * 4, 197 | activation_fn=None) 198 | net += res_net 199 | 200 | end_points['resnet_state_2_block_%d' % block_id] = net 201 | 202 | ######################################### 203 | # 4 Residual Blocks with 28 x 28 x 512 # 204 | ######################################### 205 | net = tf.pad(net, paddings, 'REFLECT') 206 | net = tf_layers.conv2d(net, num_filters * 8, kernel_size=kernel_size, stride=2, 207 | activation_fn=tf.nn.relu, padding='VALID') 208 | end_points['resnet_stage_3'] = net 209 | for block_id in range(3): 210 | with tf.variable_scope('stage_3_block_{}'.format(block_id)): 211 | res_net = tf.pad(net, paddings, 'REFLECT') 212 | res_net = tf_layers.conv2d(res_net, num_filters * 8) 213 | res_net = tf.pad(res_net, paddings, 'REFLECT') 214 | res_net = tf_layers.conv2d(res_net, num_filters * 8, 215 | activation_fn=None) 216 | net += res_net 217 | 218 | end_points['resnet_state_3_block_%d' % block_id] = net 219 | 220 | #################### 221 | # Recurrent module # 222 | #################### 223 | with tf.variable_scope('recurrent'): 224 | # reshape net to N x T x h x w x c 225 | spatial_shape = net.get_shape().as_list()[1:] 226 | net_split_times = tensor_split_times(net, batch_size=batch_size, time_steps=time_steps) 227 | print('Encoder output = {}'.format(net_split_times)) 228 | 229 | hidden_state = recurrent_model(net_split_times, spatial_shape, batch_size, num_filters * 8) 230 | end_points['hidden_state'] = hidden_state 231 | print('ConvLSTM hidden state = {}', hidden_state) 232 | 233 | ########### 234 | # Decoder # 235 | ########### 236 | with tf.variable_scope('decoder'): 237 | 238 | with tf.contrib.framework.arg_scope( 239 | [tf_layers.conv2d], 240 | kernel_size=kernel_size, 241 | stride=1, 242 | activation_fn=tf.nn.relu): 243 | 244 | with tf.variable_scope('decoder1'): 245 | net = upsample_fn(hidden_state, num_outputs=num_filters * 8, stride=[2, 2]) 246 | end_points['decoder1'] = net 247 | 248 | with tf.variable_scope('decoder2'): 249 | net = upsample_fn(net, num_outputs=num_filters * 4, stride=[2, 2]) 250 | end_points['decoder2'] = net 251 | 252 | with tf.variable_scope('decoder3'): 253 | net = upsample_fn(net, num_outputs=num_filters * 2, stride=[2, 2]) 254 | end_points['decoder3'] = net 255 | 256 | with tf.variable_scope('output'): 257 | net = tf.pad(net, spatial_pad_3, 'REFLECT') 258 | logits = tf_layers.conv2d( 259 | net, 260 | num_outputs, [7, 7], 261 | activation_fn=None, 262 | normalizer_fn=None, 263 | padding='valid') 264 | # logits = tf.reshape(logits, _dynamic_or_static_shape(images)) 265 | 266 | end_points['logits'] = logits 267 | end_points['predictions'] = tf.tanh(logits) + logits * tanh_linear_slope 268 | print('Decoder output = {}'.format(logits)) 269 | 270 | return end_points['predictions'], end_points['hidden_state'], end_points 271 | 272 | 273 | def deconv_module(hidden_state, num_filters=64, num_outputs=3, kernel_size=3, 274 | upsample_fn=cyclegan_upsample, tanh_linear_slope=0.0): 275 | 276 | end_points = {} 277 | spatial_pad_3 = np.array([[0, 0], [3, 3], [3, 3], [0, 0]]) 278 | with tf.contrib.framework.arg_scope(cyclegan_arg_scope()): 279 | with tf.variable_scope('decoder'): 280 | with tf.contrib.framework.arg_scope( 281 | [tf_layers.conv2d], 282 | kernel_size=kernel_size, 283 | stride=1, 284 | activation_fn=tf.nn.relu): 285 | with tf.variable_scope('decoder1'): 286 | net = upsample_fn(hidden_state, num_outputs=num_filters * 2, stride=[2, 2]) 287 | end_points['decoder1'] = net 288 | 289 | with tf.variable_scope('decoder2'): 290 | net = upsample_fn(net, num_outputs=num_filters, stride=[2, 2]) 291 | end_points['decoder2'] = net 292 | 293 | with tf.variable_scope('output'): 294 | net = tf.pad(net, spatial_pad_3, 'REFLECT') 295 | logits = tf_layers.conv2d( 296 | net, 297 | num_outputs, [7, 7], 298 | activation_fn=None, 299 | normalizer_fn=None, 300 | padding='valid') 301 | 302 | outputs = tf.tanh(logits) + logits * tanh_linear_slope 303 | # print('Decoder output = {}'.format(logits)) 304 | 305 | return outputs, end_points 306 | 307 | 308 | def cyclegan_convlstm(inputs, num_filters=64, 309 | upsample_fn=cyclegan_upsample, 310 | kernel_size=3, 311 | num_outputs=3, 312 | tanh_linear_slope=0.0, use_decoder=False): 313 | 314 | batch_size, time_steps, height, width, channel = inputs.get_shape().as_list() 315 | 316 | inputs = tensor_fuse_times(inputs, batch_size, time_steps) 317 | 318 | if height and height % 4 != 0: 319 | raise ValueError('The input height must be a multiple of 4.') 320 | if width and width % 4 != 0: 321 | raise ValueError('The input width must be a multiple of 4.') 322 | 323 | if not isinstance(kernel_size, (list, tuple)): 324 | kernel_size = [kernel_size, kernel_size] 325 | 326 | kernel_height = kernel_size[0] 327 | kernel_width = kernel_size[1] 328 | pad_top = (kernel_height - 1) // 2 329 | pad_bottom = kernel_height // 2 330 | pad_left = (kernel_width - 1) // 2 331 | pad_right = kernel_width // 2 332 | paddings = np.array( 333 | [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]], 334 | dtype=np.int32) 335 | spatial_pad_3 = np.array([[0, 0], [3, 3], [3, 3], [0, 0]]) 336 | 337 | with tf.contrib.framework.arg_scope(cyclegan_arg_scope()): 338 | 339 | ########### 340 | # Encoder # 341 | ########### 342 | with tf.variable_scope('input'): 343 | # 7x7 input stage 344 | net = tf.pad(inputs, spatial_pad_3, 'REFLECT') 345 | net = tf_layers.conv2d(net, num_filters, kernel_size=[7, 7], padding='VALID') 346 | 347 | with tf.variable_scope('encoder'): 348 | with tf.contrib.framework.arg_scope( 349 | [tf_layers.conv2d], 350 | kernel_size=kernel_size, 351 | stride=2, 352 | activation_fn=tf.nn.relu, 353 | padding='VALID'): 354 | net = tf.pad(net, paddings, 'REFLECT') 355 | net = tf_layers.conv2d(net, num_filters * 2) 356 | 357 | net = tf.pad(net, paddings, 'REFLECT') 358 | net = tf_layers.conv2d(net, num_filters * 4) 359 | 360 | ################### 361 | # Residual Blocks # 362 | ################### 363 | with tf.variable_scope('residual_blocks'): 364 | with tf.contrib.framework.arg_scope( 365 | [tf_layers.conv2d], 366 | kernel_size=kernel_size, 367 | stride=1, 368 | activation_fn=tf.nn.relu, 369 | padding='VALID'): 370 | for block_id in range(6): 371 | with tf.variable_scope('block_{}'.format(block_id)): 372 | res_net = tf.pad(net, paddings, 'REFLECT') 373 | res_net = tf_layers.conv2d(res_net, num_filters * 4) 374 | res_net = tf.pad(res_net, paddings, 'REFLECT') 375 | res_net = tf_layers.conv2d(res_net, num_filters * 4, activation_fn=None) 376 | net += res_net 377 | 378 | #################### 379 | # Recurrent module # 380 | #################### 381 | with tf.variable_scope('recurrent'): 382 | # reshape net to N x T x h x w x c 383 | spatial_shape = net.get_shape().as_list()[1:] 384 | net_split_times = tensor_split_times(net, batch_size=batch_size, time_steps=time_steps) 385 | # print('Encoder output = {}'.format(net_split_times)) 386 | 387 | hidden_state = recurrent_model(net_split_times, spatial_shape, batch_size, num_filters * 4) 388 | # print('ConvLSTM hidden state = {}', hidden_state) 389 | 390 | ########### 391 | # Decoder # 392 | ########### 393 | if use_decoder: 394 | outputs, end_points = deconv_module(hidden_state, num_filters, num_outputs=num_outputs, 395 | upsample_fn=upsample_fn, tanh_linear_slope=tanh_linear_slope) 396 | else: 397 | outputs, end_points = None, {} 398 | end_points['hidden_state'] = hidden_state 399 | 400 | return outputs, hidden_state, end_points 401 | 402 | 403 | def cyclegan_convlstm_deconv1(inputs, num_filters=64, 404 | upsample_fn=cyclegan_upsample, 405 | kernel_size=3, 406 | num_outputs=3, 407 | tanh_linear_slope=0.0, use_decoder=False): 408 | 409 | _, hidden_state, end_points = cyclegan_convlstm(inputs, use_decoder=False) 410 | 411 | with tf.contrib.framework.arg_scope(cyclegan_arg_scope()): 412 | with tf.contrib.framework.arg_scope( 413 | [tf_layers.conv2d_transpose], 414 | kernel_size=kernel_size, 415 | stride=4, 416 | padding='SAME'): 417 | with tf.variable_scope('decoder1'): 418 | net = tf_layers.conv2d_transpose( 419 | hidden_state, 420 | num_filters, 421 | normalizer_fn=None) 422 | 423 | with tf.variable_scope('output'): 424 | logits = tf_layers.conv2d( 425 | net, 426 | num_outputs, [7, 7], 427 | activation_fn=None, 428 | normalizer_fn=None, 429 | padding='SAME') 430 | 431 | outputs = tf.tanh(logits) + logits * tanh_linear_slope 432 | 433 | return outputs, hidden_state, end_points 434 | 435 | 436 | def two_cyclegan_convlstm_classifier(inputs, is_training=True, keep_prob=0.8, weight_decay=0.004): 437 | _, hidden_state, _ = cyclegan_convlstm(inputs=inputs) 438 | 439 | with tf.contrib.framework.arg_scope([tf_layers.fully_connected], 440 | weights_regularizer=tf_layers.l2_regularizer(weight_decay)): 441 | with tf.contrib.framework.arg_scope([tf_layers.dropout], is_training=is_training, keep_prob=keep_prob): 442 | with tf.variable_scope('fc'): 443 | net = tf_layers.flatten(hidden_state) 444 | net = tf_layers.dropout(net) 445 | net = tf_layers.fully_connected(inputs=net, num_outputs=128, activation_fn=tf.nn.relu) 446 | 447 | with tf.variable_scope('logits'): 448 | logits = tf_layers.fully_connected(inputs=net, num_outputs=2, activation_fn=None) 449 | probabilities = tf.nn.softmax(logits, name='prob') 450 | 451 | return logits, probabilities 452 | 453 | 454 | def cyclegan_conv2d(inputs, num_filters=64, 455 | upsample_fn=cyclegan_upsample, 456 | kernel_size=3, num_outputs=3, tanh_linear_slope=0.0, use_decoder=True): 457 | 458 | batch_size, time_steps, height, width, channel = inputs.get_shape().as_list() 459 | 460 | inputs = tensor_stack_times(inputs, time_steps) 461 | 462 | end_points = {} 463 | 464 | if height and height % 4 != 0: 465 | raise ValueError('The input height must be a multiple of 4.') 466 | if width and width % 4 != 0: 467 | raise ValueError('The input width must be a multiple of 4.') 468 | 469 | if not isinstance(kernel_size, (list, tuple)): 470 | kernel_size = [kernel_size, kernel_size] 471 | 472 | kernel_height = kernel_size[0] 473 | kernel_width = kernel_size[1] 474 | pad_top = (kernel_height - 1) // 2 475 | pad_bottom = kernel_height // 2 476 | pad_left = (kernel_width - 1) // 2 477 | pad_right = kernel_width // 2 478 | paddings = np.array( 479 | [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]], 480 | dtype=np.int32) 481 | spatial_pad_3 = np.array([[0, 0], [3, 3], [3, 3], [0, 0]]) 482 | 483 | with tf.contrib.framework.arg_scope(cyclegan_arg_scope()): 484 | 485 | ########### 486 | # Encoder # 487 | ########### 488 | with tf.variable_scope('input'): 489 | # 7x7 input stage 490 | net = tf.pad(inputs, spatial_pad_3, 'REFLECT') 491 | net = tf_layers.conv2d(net, num_filters, kernel_size=[7, 7], padding='VALID') 492 | end_points['encoder_0'] = net 493 | 494 | with tf.variable_scope('encoder'): 495 | with tf.contrib.framework.arg_scope( 496 | [tf_layers.conv2d], 497 | kernel_size=kernel_size, 498 | stride=2, 499 | activation_fn=tf.nn.relu, 500 | padding='VALID'): 501 | net = tf.pad(net, paddings, 'REFLECT') 502 | net = tf_layers.conv2d(net, num_filters * 2) 503 | end_points['encoder_1'] = net 504 | net = tf.pad(net, paddings, 'REFLECT') 505 | net = tf_layers.conv2d(net, num_filters * 4) 506 | end_points['encoder_2'] = net 507 | 508 | ################### 509 | # Residual Blocks # 510 | ################### 511 | with tf.variable_scope('residual_blocks'): 512 | with tf.contrib.framework.arg_scope( 513 | [tf_layers.conv2d], 514 | kernel_size=kernel_size, 515 | stride=1, 516 | activation_fn=tf.nn.relu, 517 | padding='VALID'): 518 | for block_id in range(6): 519 | with tf.variable_scope('block_{}'.format(block_id)): 520 | res_net = tf.pad(net, paddings, 'REFLECT') 521 | res_net = tf_layers.conv2d(res_net, num_filters * 4) 522 | res_net = tf.pad(res_net, paddings, 'REFLECT') 523 | res_net = tf_layers.conv2d(res_net, num_filters * 4, 524 | activation_fn=None) 525 | net += res_net 526 | 527 | end_points['resnet_block_%d' % block_id] = net 528 | 529 | end_points['hidden_state'] = net 530 | 531 | ########### 532 | # Decoder # 533 | ########### 534 | with tf.variable_scope('decoder'): 535 | with tf.contrib.framework.arg_scope( 536 | [tf_layers.conv2d], 537 | kernel_size=kernel_size, 538 | stride=1, 539 | activation_fn=tf.nn.relu): 540 | with tf.variable_scope('decoder1'): 541 | net = upsample_fn(net, num_outputs=num_filters * 2, stride=[2, 2]) 542 | end_points['decoder1'] = net 543 | 544 | with tf.variable_scope('decoder2'): 545 | net = upsample_fn(net, num_outputs=num_filters, stride=[2, 2]) 546 | end_points['decoder2'] = net 547 | 548 | with tf.variable_scope('output'): 549 | net = tf.pad(net, spatial_pad_3, 'REFLECT') 550 | logits = tf_layers.conv2d( 551 | net, 552 | num_outputs, [7, 7], 553 | activation_fn=None, 554 | normalizer_fn=None, 555 | padding='valid') 556 | 557 | end_points['logits'] = logits 558 | end_points['predictions'] = tf.tanh(logits) + logits * tanh_linear_slope 559 | print('Decoder output = {}'.format(logits)) 560 | 561 | return end_points['predictions'], end_points['hidden_state'], end_points 562 | 563 | 564 | def unet_conv2d(inputs, num_filters=64, num_down_samples=4, use_decoder=True): 565 | _, time_steps, _, _, _ = inputs.get_shape().as_list() 566 | 567 | in_node = tensor_stack_times(inputs, time_steps) 568 | conv = [] 569 | for layer in range(0, num_down_samples): 570 | features = 2**layer*num_filters 571 | 572 | conv1 = tf_layers.conv2d(inputs=in_node, num_outputs=features, kernel_size=3) 573 | # if layer == num_down_samples - 1: 574 | # conv2 = tf_layers.conv2d(inputs=conv1, num_outputs=features, kernel_size=3, activation_fn=tf.nn.tanh) 575 | # else: 576 | # conv2 = tf_layers.conv2d(inputs=conv1, num_outputs=features, kernel_size=3, activation_fn=tf.nn.relu) 577 | conv2 = tf_layers.conv2d(inputs=conv1, num_outputs=features, kernel_size=3, activation_fn=tf.nn.relu) 578 | 579 | conv.append(conv2) 580 | 581 | if layer < num_down_samples - 1: 582 | in_node = tf_layers.max_pool2d(inputs=conv2, kernel_size=2, padding='SAME') 583 | # in_node = conv2d(inputs=conv2, num_outputs=features, kernel_size=filter_size, stride=2) 584 | 585 | in_node = conv[-1] 586 | hidden_state = conv[-1] 587 | 588 | for layer in range(num_down_samples-2, -1, -1): 589 | features = 2**(layer+1)*num_filters 590 | 591 | h_deconv = tf_layers.conv2d_transpose(inputs=in_node, num_outputs=features//2, kernel_size=2, stride=2) 592 | h_deconv_concat = tf.concat([conv[layer], h_deconv], axis=3) 593 | 594 | conv1 = tf_layers.conv2d(inputs=h_deconv_concat, num_outputs=features//2, kernel_size=3) 595 | in_node = tf_layers.conv2d(inputs=conv1, num_outputs=features//2, kernel_size=3) 596 | 597 | output = tf_layers.conv2d(inputs=in_node, num_outputs=3, kernel_size=3, activation_fn=None) 598 | output = tf.tanh(output) 599 | return output, hidden_state, None 600 | 601 | 602 | def unet_conv2d_instance_norm(inputs, num_filters=64, num_down_samples=4, use_decoder=True): 603 | _, time_steps, _, _, _ = inputs.get_shape().as_list() 604 | 605 | in_node = tensor_stack_times(inputs, time_steps) 606 | conv = [] 607 | end_points = {} 608 | with tf.contrib.framework.arg_scope(cyclegan_arg_scope()): 609 | for layer in range(0, num_down_samples): 610 | features = 2**layer*num_filters 611 | 612 | conv1 = tf_layers.conv2d(inputs=in_node, num_outputs=features, kernel_size=3) 613 | conv2 = tf_layers.conv2d(inputs=conv1, num_outputs=features, kernel_size=3) 614 | conv.append(conv2) 615 | 616 | if layer < num_down_samples - 1: 617 | in_node = tf_layers.max_pool2d(inputs=conv2, kernel_size=2, padding='SAME') 618 | # in_node = conv2d(inputs=conv2, num_outputs=features, kernel_size=filter_size, stride=2) 619 | 620 | in_node = conv[-1] 621 | hidden_state = conv[-1] 622 | 623 | if use_decoder: 624 | for i, layer in enumerate(range(num_down_samples-2, -1, -1)): 625 | features = 2**(layer+1)*num_filters 626 | 627 | h_deconv = tf_layers.conv2d_transpose(inputs=in_node, num_outputs=features//2, kernel_size=2, stride=2) 628 | h_deconv_concat = tf.concat([conv[layer], h_deconv], axis=3) 629 | 630 | conv1 = tf_layers.conv2d(inputs=h_deconv_concat, num_outputs=features//2, kernel_size=3) 631 | in_node = tf_layers.conv2d(inputs=conv1, num_outputs=features//2, kernel_size=3) 632 | 633 | end_points['encoder_%d' % i] = in_node 634 | 635 | output = tf_layers.conv2d(inputs=in_node, num_outputs=3, kernel_size=3, activation_fn=None) 636 | output = tf.tanh(output) 637 | else: 638 | output = None 639 | return output, hidden_state, end_points 640 | 641 | 642 | def conv2d_deconv2d(inputs, num_filters=64, num_down_samples=4): 643 | _, time_steps, _, _, _ = inputs.get_shape().as_list() 644 | 645 | in_node = tensor_stack_times(inputs, time_steps) 646 | for layer in range(0, num_down_samples): 647 | features = 2**layer*num_filters 648 | 649 | conv1 = tf_layers.conv2d(inputs=in_node, num_outputs=features, kernel_size=3) 650 | conv2 = tf_layers.conv2d(inputs=conv1, num_outputs=features, kernel_size=3) 651 | 652 | if layer < num_down_samples - 1: 653 | in_node = tf_layers.max_pool2d(inputs=conv2, kernel_size=2, padding='SAME') 654 | # in_node = conv2d(inputs=conv2, num_outputs=features, kernel_size=filter_size, stride=2) 655 | 656 | in_node = conv2 657 | hidden_state = conv2 658 | 659 | for layer in range(num_down_samples-2, -1, -1): 660 | features = 2**(layer+1)*num_filters 661 | 662 | h_deconv = tf_layers.conv2d_transpose(inputs=in_node, num_outputs=features//2, kernel_size=2, stride=2) 663 | 664 | conv1 = tf_layers.conv2d(inputs=h_deconv, num_outputs=features//2, kernel_size=3) 665 | in_node = tf_layers.conv2d(inputs=conv1, num_outputs=features//2, kernel_size=3) 666 | 667 | output = tf_layers.conv2d(inputs=in_node, num_outputs=3, kernel_size=3, activation_fn=None) 668 | output = tf.tanh(output) 669 | return output, hidden_state 670 | 671 | 672 | def resnet_conv3d(inputs, num_filters=64, 673 | upsample_fn=cyclegan_upsample, 674 | kernel_size=3, 675 | num_outputs=3, 676 | tanh_linear_slope=0.0, use_decoder=True): 677 | 678 | end_points = {} 679 | 680 | with tf.contrib.framework.arg_scope(cyclegan_arg_scope()): 681 | 682 | ########### 683 | # Encoder # 684 | ########### 685 | with tf.variable_scope('input'): 686 | # 7x7 input stage 687 | net = tf_layers.conv3d(inputs, num_filters, kernel_size=[1, 7, 7], stride=1, padding='SAME') 688 | end_points['encoder_0'] = net 689 | 690 | with tf.variable_scope('encoder'): 691 | with tf.contrib.framework.arg_scope( 692 | [tf_layers.conv3d], 693 | kernel_size=kernel_size, 694 | stride=2, 695 | activation_fn=tf.nn.relu, 696 | padding='SAME'): 697 | net = tf_layers.conv3d(net, num_filters * 2) 698 | end_points['encoder_1'] = net 699 | net = tf_layers.conv3d(net, num_filters * 4) 700 | end_points['encoder_2'] = net 701 | 702 | ################### 703 | # Residual Blocks # 704 | ################### 705 | with tf.variable_scope('residual_blocks'): 706 | with tf.contrib.framework.arg_scope( 707 | [tf_layers.conv3d], 708 | kernel_size=kernel_size, 709 | stride=1, 710 | activation_fn=tf.nn.relu, 711 | padding='SAME'): 712 | for block_id in range(6): 713 | with tf.variable_scope('block_{}'.format(block_id)): 714 | res_net = tf_layers.conv3d(net, num_filters * 4) 715 | res_net = tf_layers.conv3d(res_net, num_filters * 4, 716 | activation_fn=None) 717 | net += res_net 718 | 719 | end_points['resnet_block_%d' % block_id] = net 720 | 721 | hidden_state = tf.nn.tanh(net) 722 | hidden_state = tf.squeeze(hidden_state, axis=1) 723 | end_points['hidden_state'] = hidden_state 724 | 725 | ########### 726 | # Decoder # 727 | ########### 728 | with tf.variable_scope('decoder'): 729 | with tf.contrib.framework.arg_scope( 730 | [tf_layers.conv2d], 731 | kernel_size=kernel_size, 732 | stride=1, 733 | activation_fn=tf.nn.relu): 734 | with tf.variable_scope('decoder1'): 735 | net = upsample_fn(hidden_state, num_outputs=num_filters * 2, stride=[2, 2]) 736 | end_points['decoder1'] = net 737 | 738 | with tf.variable_scope('decoder2'): 739 | net = upsample_fn(net, num_outputs=num_filters, stride=[2, 2]) 740 | end_points['decoder2'] = net 741 | 742 | with tf.variable_scope('output'): 743 | logits = tf_layers.conv2d( 744 | net, 745 | num_outputs, [7, 7], 746 | activation_fn=None, 747 | normalizer_fn=None, 748 | padding='SAME') 749 | 750 | end_points['logits'] = logits 751 | end_points['predictions'] = tf.tanh(logits) + logits * tanh_linear_slope 752 | print('Decoder output = {}'.format(logits)) 753 | 754 | return end_points['predictions'], end_points['hidden_state'], end_points 755 | 756 | 757 | def discriminator(inputs, num_filers=(128, 256, 512, 512)): 758 | logits, end_points = pix2pix.pix2pix_discriminator(inputs, num_filers) 759 | return logits, end_points['predictions'] 760 | 761 | 762 | if __name__ == '__main__': 763 | input_tensor = tf.placeholder(shape=[10, 8, 224, 224, 3], dtype=tf.float32) 764 | # logits, end_points = cyclegan_convlstm(inputs=input_tensor, num_outputs=3, num_filters=64) 765 | resnet_conv3d(inputs=input_tensor, num_outputs=3, num_filters=64) 766 | -------------------------------------------------------------------------------- /models/pix2pix.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================= 15 | """Implementation of the Image-to-Image Translation model. 16 | This network represents a port of the following work: 17 | Image-to-Image Translation with Conditional Adversarial Networks 18 | Phillip Isola, Jun-Yan Zhu, Tinghui Zhou and Alexei A. Efros 19 | Arxiv, 2017 20 | https://phillipi.github.io/pix2pix/ 21 | A reference implementation written in Lua can be found at: 22 | https://github.com/phillipi/pix2pix/blob/master/models.lua 23 | """ 24 | import collections 25 | import functools 26 | 27 | import tensorflow as tf 28 | 29 | layers = tf.contrib.layers 30 | 31 | 32 | def pix2pix_arg_scope(): 33 | """Returns a default argument scope for isola_net. 34 | Returns: 35 | An arg scope. 36 | """ 37 | # These parameters come from the online port, which don't necessarily match 38 | # those in the paper. 39 | # TODO(nsilberman): confirm these values with Philip. 40 | instance_norm_params = { 41 | 'center': True, 42 | 'scale': True, 43 | 'epsilon': 0.00001, 44 | } 45 | 46 | with tf.contrib.framework.arg_scope( 47 | [layers.conv2d, layers.conv2d_transpose], 48 | normalizer_fn=layers.instance_norm, 49 | normalizer_params=instance_norm_params, 50 | weights_initializer=tf.random_normal_initializer(0, 0.02)) as sc: 51 | return sc 52 | 53 | 54 | def upsample(net, num_outputs, kernel_size, method='nn_upsample_conv'): 55 | """Upsamples the given inputs. 56 | Args: 57 | net: A `Tensor` of size [batch_size, height, width, filters]. 58 | num_outputs: The number of output filters. 59 | kernel_size: A list of 2 scalars or a 1x2 `Tensor` indicating the scale, 60 | relative to the inputs, of the output dimensions. For example, if kernel 61 | size is [2, 3], then the output height and width will be twice and three 62 | times the input size. 63 | method: The upsampling method. 64 | Returns: 65 | An `Tensor` which was upsampled using the specified method. 66 | Raises: 67 | ValueError: if `method` is not recognized. 68 | """ 69 | net_shape = tf.shape(net) 70 | height = net_shape[1] 71 | width = net_shape[2] 72 | 73 | if method == 'nn_upsample_conv': 74 | net = tf.image.resize_nearest_neighbor( 75 | net, [kernel_size[0] * height, kernel_size[1] * width]) 76 | net = layers.conv2d(net, num_outputs, [4, 4], activation_fn=None) 77 | elif method == 'conv2d_transpose': 78 | net = layers.conv2d_transpose( 79 | net, num_outputs, [4, 4], stride=kernel_size, activation_fn=None) 80 | else: 81 | raise ValueError('Unknown method: [%s]', method) 82 | 83 | return net 84 | 85 | 86 | class Block( 87 | collections.namedtuple('Block', ['num_filters', 'decoder_keep_prob'])): 88 | """Represents a single block of encoder and decoder processing. 89 | The Image-to-Image translation paper works a bit differently than the original 90 | U-Net model. In particular, each block represents a single operation in the 91 | encoder which is concatenated with the corresponding decoder representation. 92 | A dropout layer follows the concatenation and convolution of the concatenated 93 | features. 94 | """ 95 | pass 96 | 97 | 98 | def _default_generator_blocks(): 99 | """Returns the default generator block definitions. 100 | Returns: 101 | A list of generator blocks. 102 | """ 103 | return [ 104 | Block(64, 0.5), 105 | Block(128, 0.5), 106 | Block(256, 0.5), 107 | Block(512, 0), 108 | Block(512, 0), 109 | Block(512, 0), 110 | Block(512, 0), 111 | ] 112 | 113 | 114 | def pix2pix_generator(net, 115 | num_outputs, 116 | blocks=None, 117 | upsample_method='nn_upsample_conv', 118 | is_training=False): # pylint: disable=unused-argument 119 | """Defines the network architecture. 120 | Args: 121 | net: A `Tensor` of size [batch, height, width, channels]. Note that the 122 | generator currently requires square inputs (e.g. height=width). 123 | num_outputs: The number of (per-pixel) outputs. 124 | blocks: A list of generator blocks or `None` to use the default generator 125 | definition. 126 | upsample_method: The method of upsampling images, one of 'nn_upsample_conv' 127 | or 'conv2d_transpose' 128 | is_training: Whether or not we're in training or testing mode. 129 | Returns: 130 | A `Tensor` representing the model output and a dictionary of model end 131 | points. 132 | Raises: 133 | ValueError: if the input heights do not match their widths. 134 | """ 135 | end_points = {} 136 | 137 | blocks = blocks or _default_generator_blocks() 138 | 139 | input_size = net.get_shape().as_list() 140 | height, width = input_size[1], input_size[2] 141 | if height != width: 142 | raise ValueError('The input height must match the input width.') 143 | 144 | input_size[3] = num_outputs 145 | 146 | upsample_fn = functools.partial(upsample, method=upsample_method) 147 | 148 | encoder_activations = [] 149 | 150 | ########### 151 | # Encoder # 152 | ########### 153 | with tf.variable_scope('encoder'): 154 | with tf.contrib.framework.arg_scope( 155 | [layers.conv2d], 156 | kernel_size=[4, 4], 157 | stride=2, 158 | activation_fn=tf.nn.leaky_relu): 159 | 160 | for block_id, block in enumerate(blocks): 161 | # No normalizer for the first encoder layers as per 'Image-to-Image', 162 | # Section 5.1.1 163 | if block_id == 0: 164 | # First layer doesn't use normalizer_fn 165 | net = layers.conv2d(net, block.num_filters, normalizer_fn=None) 166 | elif block_id < len(blocks) - 1: 167 | net = layers.conv2d(net, block.num_filters) 168 | else: 169 | # Last layer doesn't use activation_fn nor normalizer_fn 170 | net = layers.conv2d( 171 | net, block.num_filters, activation_fn=None, normalizer_fn=None) 172 | 173 | encoder_activations.append(net) 174 | end_points['encoder%d' % block_id] = net 175 | 176 | ########### 177 | # Decoder # 178 | ########### 179 | reversed_blocks = list(blocks) 180 | reversed_blocks.reverse() 181 | 182 | with tf.variable_scope('decoder'): 183 | # Dropout is used at both train and test time as per 'Image-to-Image', 184 | # Section 2.1 (last paragraph). 185 | with tf.contrib.framework.arg_scope([layers.dropout], is_training=is_training): 186 | 187 | for block_id, block in enumerate(reversed_blocks): 188 | if block_id > 0: 189 | net = tf.concat([net, encoder_activations[-block_id - 1]], axis=3) 190 | 191 | # The Relu comes BEFORE the upsample op: 192 | net = tf.nn.relu(net) 193 | net = upsample_fn(net, block.num_filters, [2, 2]) 194 | if block.decoder_keep_prob > 0: 195 | net = layers.dropout(net, keep_prob=block.decoder_keep_prob) 196 | end_points['decoder%d' % block_id] = net 197 | 198 | with tf.variable_scope('output'): 199 | logits = layers.conv2d(net, num_outputs, [4, 4], activation_fn=None) 200 | # print(logits) 201 | # logits = tf.reshape(logits, input_size) 202 | 203 | end_points['logits'] = logits 204 | end_points['predictions'] = tf.tanh(logits) 205 | 206 | return logits, end_points 207 | 208 | 209 | def pix2pix_discriminator(net, num_filters, padding=2, is_training=False): 210 | """Creates the Image2Image Translation Discriminator. 211 | Args: 212 | net: A `Tensor` of size [batch_size, height, width, channels] representing 213 | the input. 214 | num_filters: A list of the filters in the discriminator. The length of the 215 | list determines the number of layers in the discriminator. 216 | padding: Amount of reflection padding applied before each convolution. 217 | is_training: Whether or not the model is training or testing. 218 | Returns: 219 | A logits `Tensor` of size [batch_size, N, N, 1] where N is the number of 220 | 'patches' we're attempting to discriminate and a dictionary of model end 221 | points. 222 | """ 223 | del is_training 224 | end_points = {} 225 | 226 | num_layers = len(num_filters) 227 | 228 | def padded(net, scope): 229 | if padding: 230 | with tf.variable_scope(scope): 231 | spatial_pad = tf.constant( 232 | [[0, 0], [padding, padding], [padding, padding], [0, 0]], 233 | dtype=tf.int32) 234 | return tf.pad(net, spatial_pad, 'REFLECT') 235 | else: 236 | return net 237 | 238 | with tf.contrib.framework.arg_scope( 239 | [layers.conv2d], 240 | kernel_size=[4, 4], 241 | stride=2, 242 | padding='valid', 243 | activation_fn=tf.nn.leaky_relu): 244 | 245 | # No normalization on the input layer. 246 | net = layers.conv2d( 247 | padded(net, 'conv0'), num_filters[0], normalizer_fn=None, scope='conv0') 248 | 249 | end_points['conv0'] = net 250 | 251 | for i in range(1, num_layers - 1): 252 | net = layers.conv2d( 253 | padded(net, 'conv%d' % i), num_filters[i], scope='conv%d' % i) 254 | end_points['conv%d' % i] = net 255 | 256 | # Stride 1 on the last layer. 257 | net = layers.conv2d( 258 | padded(net, 'conv%d' % (num_layers - 1)), 259 | num_filters[-1], 260 | stride=1, 261 | scope='conv%d' % (num_layers - 1)) 262 | end_points['conv%d' % (num_layers - 1)] = net 263 | 264 | # 1-dim logits, stride 1, no activation, no normalization. 265 | logits = layers.conv2d( 266 | padded(net, 'conv%d' % num_layers), 267 | 1, 268 | stride=1, 269 | activation_fn=None, 270 | normalizer_fn=None, 271 | scope='conv%d' % num_layers) 272 | end_points['logits'] = logits 273 | end_points['predictions'] = tf.sigmoid(logits) 274 | return logits, end_points 275 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow==1.14.0 2 | tensorflow-gpu==1.14.0 3 | imageio 4 | scipy 5 | numpy 6 | scikit_learn 7 | opencv-python 8 | progressbar2 9 | -------------------------------------------------------------------------------- /train_scripts/train_normal_annotation.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | 4 | from models import prediction_networks_dict 5 | from utils.dataloaders.only_normal_loader import NormalDataLoader 6 | from utils.util import load, save, psnr_error 7 | from constant import const 8 | 9 | 10 | os.environ['CUDA_DEVICES_ORDER'] = "PCI_BUS_ID" 11 | os.environ['CUDA_VISIBLE_DEVICES'] = const.GPUS[0] 12 | 13 | dataset_name = const.DATASET 14 | train_folder = const.TRAIN_FOLDER 15 | test_folder = const.TEST_FOLDER 16 | frame_mask = const.FRAME_MASK 17 | pixel_mask = const.PIXEL_MASK 18 | k_folds = const.K_FOLDS 19 | kth = const.KTH 20 | interval = const.INTERVAL 21 | multi_interval = const.MULTI_INTERVAL 22 | 23 | batch_size = const.BATCH_SIZE 24 | iterations = const.ITERATIONS 25 | num_his = const.NUM_HIS 26 | height, width = const.HEIGHT, const.WIDTH 27 | 28 | prednet = prediction_networks_dict[const.PREDNET] 29 | 30 | margin = const.MARGIN 31 | lam = const.LAMBDA 32 | 33 | model_save_freq = const.MODEL_SAVE_FREQ 34 | summary_dir = const.SUMMARY_DIR 35 | snapshot_dir = const.SNAPSHOT_DIR 36 | 37 | print(const) 38 | 39 | # define dataset 40 | # noinspection PyUnboundLocalVariable 41 | with tf.name_scope('dataset'): 42 | tf_dataset = NormalDataLoader(dataset_name, train_folder, height, width) 43 | 44 | train_dataset = tf_dataset(batch_size, time_steps=(num_his + 1), interval=interval) 45 | 46 | train_it = train_dataset.make_one_shot_iterator() 47 | train_tensor = train_it.get_next() 48 | train_tensor.set_shape([batch_size, (num_his + 1), height, width, 3]) 49 | 50 | train_positive = train_tensor[:, 0:num_his, ...] 51 | train_positive_gt = train_tensor[:, -1, ...] 52 | 53 | 54 | # define training generator function 55 | with tf.variable_scope('generator', reuse=None): 56 | train_positive_output, train_positive_feature, _ = prednet(train_positive, use_decoder=True) 57 | train_positive_psnr = psnr_error(train_positive_output, train_positive_gt) 58 | 59 | with tf.name_scope('training'): 60 | g_loss = tf.reduce_mean(tf.abs(train_positive_output - train_positive_gt)) 61 | 62 | g_step = tf.Variable(0, dtype=tf.int32, trainable=False, name='g_step') 63 | g_lrate = tf.train.piecewise_constant(g_step, boundaries=const.LRATE_G_BOUNDARIES, values=const.LRATE_G) 64 | g_optimizer = tf.train.AdamOptimizer(learning_rate=g_lrate, name='g_optimizer') 65 | g_vars = tf.get_collection(key=tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator') 66 | 67 | g_train_op = g_optimizer.minimize(g_loss, global_step=g_step, var_list=g_vars, name='g_train_op') 68 | 69 | # add all to summaries 70 | tf.summary.scalar(tensor=g_loss, name='g_loss') 71 | tf.summary.image(tensor=train_positive_output, name='positive_output') 72 | tf.summary.image(tensor=train_positive_gt, name='positive_gt') 73 | tf.summary.scalar(tensor=train_positive_psnr, name='positive_psnr') 74 | summary_op = tf.summary.merge_all() 75 | 76 | config = tf.ConfigProto() 77 | config.gpu_options.allow_growth = True 78 | with tf.Session(config=config) as sess: 79 | # summaries 80 | summary_writer = tf.summary.FileWriter(summary_dir, graph=sess.graph) 81 | 82 | # initialize weights 83 | sess.run(tf.global_variables_initializer()) 84 | print('Init successfully!') 85 | 86 | # tf saver 87 | saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=None) 88 | restore_var = [v for v in tf.global_variables()] 89 | loader = tf.train.Saver(var_list=restore_var) 90 | if os.path.isdir(snapshot_dir): 91 | ckpt = tf.train.get_checkpoint_state(snapshot_dir) 92 | if ckpt and ckpt.model_checkpoint_path: 93 | load(loader, sess, ckpt.model_checkpoint_path) 94 | else: 95 | print('No checkpoint file found.') 96 | else: 97 | load(loader, sess, snapshot_dir) 98 | 99 | print('Start training ...') 100 | 101 | _step, _loss, _summaries = 0, None, None 102 | while _step < iterations: 103 | try: 104 | _, _step, _g_loss, _p_psnr, _summaries = \ 105 | sess.run([g_train_op, g_step, g_loss, train_positive_psnr, summary_op]) 106 | 107 | if _step % 10 == 0: 108 | print('Iteration = {}, global loss = {:.6f}, positive psnr = {:.6f}'.format(_step, _g_loss, _p_psnr)) 109 | 110 | if _step % 100 == 0: 111 | summary_writer.add_summary(_summaries, global_step=_step) 112 | print('Save summaries...') 113 | 114 | if _step % model_save_freq == 0: 115 | save(saver, sess, snapshot_dir, _step) 116 | 117 | except tf.errors.OutOfRangeError: 118 | print('Finish successfully!') 119 | save(saver, sess, snapshot_dir, _step) 120 | break 121 | -------------------------------------------------------------------------------- /train_scripts/train_temporal_annotation.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | 4 | from constant import const 5 | from models import prediction_networks_dict 6 | from utils.dataloaders.temporal_triplet_loader import DataTemporalTripletLoader 7 | from utils.util import load, save, psnr_error 8 | 9 | 10 | os.environ['CUDA_DEVICES_ORDER'] = "PCI_BUS_ID" 11 | os.environ['CUDA_VISIBLE_DEVICES'] = const.GPUS[0] 12 | 13 | dataset_name = const.DATASET 14 | train_folder = const.TRAIN_FOLDER 15 | test_folder = const.TEST_FOLDER 16 | frame_mask = const.FRAME_MASK 17 | pixel_mask = const.PIXEL_MASK 18 | k_folds = const.K_FOLDS 19 | kth = const.KTH 20 | interval = const.INTERVAL 21 | multi_interval = const.MULTI_INTERVAL 22 | 23 | batch_size = const.BATCH_SIZE 24 | iterations = const.ITERATIONS 25 | num_his = const.NUM_HIS 26 | height, width = const.HEIGHT, const.WIDTH 27 | 28 | prednet = prediction_networks_dict[const.PREDNET] 29 | 30 | margin = const.MARGIN 31 | lam = const.LAMBDA 32 | 33 | model_save_freq = const.MODEL_SAVE_FREQ 34 | summary_dir = const.SUMMARY_DIR 35 | snapshot_dir = const.SNAPSHOT_DIR 36 | 37 | print(const) 38 | 39 | # define dataset 40 | # noinspection PyUnboundLocalVariable 41 | with tf.name_scope('dataset'): 42 | triplet_loader = DataTemporalTripletLoader(dataset=dataset_name, train_folder=train_folder, 43 | test_folder=test_folder, k_folds=k_folds, kth=kth, 44 | frame_mask_file=frame_mask, pixel_mask_file=pixel_mask, 45 | resize_height=height, resize_width=width) 46 | 47 | train_dataset = triplet_loader(batch_size, time_steps=(num_his + 1), interval=interval) 48 | 49 | train_it = train_dataset.make_one_shot_iterator() 50 | train_tensor = train_it.get_next() 51 | train_tensor.set_shape([batch_size, 3, (num_his + 1), height, width, 3]) 52 | 53 | train_anchor = train_tensor[:, 0, 0:num_his, ...] 54 | train_anchor_gt = train_tensor[:, 0, -1, ...] 55 | 56 | train_positive = train_tensor[:, 1, 0:num_his, ...] 57 | train_positive_gt = train_tensor[:, 1, -1, ...] 58 | 59 | train_negative = train_tensor[:, 2, 0:num_his, ...] 60 | train_negative_gt = train_tensor[:, 2, -1, ...] 61 | 62 | 63 | # define training generator function 64 | with tf.variable_scope('generator', reuse=None): 65 | train_anchor_output, train_anchor_feature, _ = prednet(train_anchor, use_decoder=True) 66 | with tf.variable_scope('generator', reuse=True): 67 | train_positive_output, train_positive_feature, _ = prednet(train_positive, use_decoder=True) 68 | train_positive_psnr = psnr_error(train_positive_output, train_positive_gt) 69 | with tf.variable_scope('generator', reuse=True): 70 | train_negative_output, train_negative_feature, _ = prednet(train_negative, use_decoder=True) 71 | train_negative_psnr = psnr_error(train_negative_output, train_negative_gt) 72 | 73 | 74 | with tf.name_scope('training'): 75 | pred_loss = tf.reduce_mean(tf.abs(train_anchor_output - train_anchor_gt)) + \ 76 | tf.reduce_mean(tf.abs(train_positive_output - train_positive_gt)) 77 | 78 | # inter-class and intra-class distance 79 | intra_dis = tf.reduce_mean((train_anchor_feature - train_positive_feature) ** 2) 80 | inter_dis = tf.reduce_mean((train_anchor_feature - train_negative_feature) ** 2) 81 | 82 | # metric learning, triplet loss. 83 | margin_loss = tf.maximum(0.0, margin + intra_dis - inter_dis) 84 | 85 | # reconstruction + triplet 86 | g_loss = pred_loss + lam * margin_loss 87 | g_step = tf.Variable(0, dtype=tf.int32, trainable=False, name='g_step') 88 | g_lrate = tf.train.piecewise_constant(g_step, boundaries=const.LRATE_G_BOUNDARIES, values=const.LRATE_G) 89 | g_optimizer = tf.train.AdamOptimizer(learning_rate=g_lrate, name='g_optimizer') 90 | g_vars = tf.get_collection(key=tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator') 91 | 92 | g_train_op = g_optimizer.minimize(g_loss, global_step=g_step, var_list=g_vars, name='g_train_op') 93 | 94 | # add all to summaries 95 | tf.summary.scalar(tensor=g_loss, name='g_loss') 96 | tf.summary.scalar(tensor=pred_loss, name='pred_loss') 97 | tf.summary.scalar(tensor=margin_loss, name='margin_loss') 98 | tf.summary.scalar(tensor=intra_dis, name='intra_dis') 99 | tf.summary.scalar(tensor=inter_dis, name='inter_dis') 100 | tf.summary.scalar(tensor=train_positive_psnr, name='positive_psnr') 101 | tf.summary.scalar(tensor=train_negative_psnr, name='negative_psnr') 102 | tf.summary.image(tensor=train_positive_output, name='positive_output') 103 | tf.summary.image(tensor=train_positive_gt, name='positive_gt') 104 | tf.summary.image(tensor=train_negative_output, name='negative_output') 105 | tf.summary.image(tensor=train_negative_gt, name='negative_gt') 106 | summary_op = tf.summary.merge_all() 107 | 108 | config = tf.ConfigProto() 109 | config.gpu_options.allow_growth = True 110 | with tf.Session(config=config) as sess: 111 | # summaries 112 | summary_writer = tf.summary.FileWriter(summary_dir, graph=sess.graph) 113 | 114 | # initialize weights 115 | sess.run(tf.global_variables_initializer()) 116 | print('Init successfully!') 117 | 118 | # tf saver 119 | saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=None) 120 | restore_var = [v for v in tf.global_variables()] 121 | loader = tf.train.Saver(var_list=restore_var) 122 | if os.path.isdir(snapshot_dir): 123 | ckpt = tf.train.get_checkpoint_state(snapshot_dir) 124 | if ckpt and ckpt.model_checkpoint_path: 125 | load(loader, sess, ckpt.model_checkpoint_path) 126 | else: 127 | print('No checkpoint file found.') 128 | else: 129 | load(loader, sess, snapshot_dir) 130 | 131 | _step, _loss, _summaries = 0, None, None 132 | while _step < iterations: 133 | try: 134 | 135 | _, _step, _inter_dis, _intra_dis, _pred_loss, _margin_loss, _g_loss, _p_psnr, _n_psnr, _summaries = \ 136 | sess.run([g_train_op, g_step, inter_dis, intra_dis, pred_loss, margin_loss, g_loss, 137 | train_positive_psnr, train_negative_psnr, summary_op]) 138 | 139 | print('Training, pred loss = {:.6f}, margin loss = {:.6f}'.format(_pred_loss, _margin_loss)) 140 | if _step % 10 == 0: 141 | print('Iteration = {}, global loss = {:.6f}'.format(_step, _g_loss)) 142 | print(' intra dis = {:.6f}'.format(_intra_dis)) 143 | print(' inter dis = {:.6f}'.format(_inter_dis)) 144 | print(' positive psnr = {:.6f}'.format(_p_psnr)) 145 | print(' negative psnr = {:.6f}'.format(_n_psnr)) 146 | 147 | if _step % 100 == 0: 148 | summary_writer.add_summary(_summaries, global_step=_step) 149 | print('Save summaries...') 150 | 151 | if _step % model_save_freq == 0: 152 | save(saver, sess, snapshot_dir, _step) 153 | 154 | except tf.errors.OutOfRangeError: 155 | print('Finish successfully!') 156 | save(saver, sess, snapshot_dir, _step) 157 | break 158 | -------------------------------------------------------------------------------- /train_scripts/train_tune_video_annotation.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | 4 | from models import prediction_networks_dict 5 | from utils.dataloaders.tune_video_loader import DataTuneVideoGtLoaderImage 6 | from utils.util import load, psnr_error, save 7 | from constant import const 8 | 9 | 10 | os.environ['CUDA_DEVICES_ORDER'] = "PCI_BUS_ID" 11 | os.environ['CUDA_VISIBLE_DEVICES'] = const.GPUS[0] 12 | 13 | 14 | dataset_name = const.DATASET 15 | train_folder = const.TRAIN_FOLDER 16 | test_folder = const.TEST_FOLDER 17 | frame_mask = const.FRAME_MASK 18 | pixel_mask = const.PIXEL_MASK 19 | k_folds = const.K_FOLDS 20 | kth = const.KTH 21 | interval = const.INTERVAL 22 | multi_interval = const.MULTI_INTERVAL 23 | 24 | batch_size = const.BATCH_SIZE 25 | iterations = const.ITERATIONS 26 | num_his = const.NUM_HIS 27 | height, width = const.HEIGHT, const.WIDTH 28 | 29 | prednet = prediction_networks_dict[const.PREDNET] 30 | 31 | margin = const.MARGIN 32 | lam = const.LAMBDA 33 | 34 | pretrain_model = const.PRETRAIN_MODEL 35 | psnr_file = const.PSNR_FILE 36 | model_save_freq = const.MODEL_SAVE_FREQ 37 | summary_dir = const.SUMMARY_DIR 38 | snapshot_dir = const.SNAPSHOT_DIR 39 | 40 | 41 | print(const) 42 | 43 | # define dataset 44 | # noinspection PyUnboundLocalVariable 45 | with tf.name_scope('dataset'): 46 | data_loader = DataTuneVideoGtLoaderImage(dataset=dataset_name, train_folder=train_folder, 47 | test_folder=test_folder, 48 | k_folds=k_folds, kth=kth, 49 | frame_mask_file=frame_mask, 50 | psnr_file=psnr_file, 51 | resize_height=height, resize_width=width) 52 | 53 | tf_dataset = data_loader(batch_size, time_steps=(num_his + 1), interval=interval) 54 | train_it = tf_dataset.make_one_shot_iterator() 55 | train_tensor, train_scores = train_it.get_next() 56 | train_tensor.set_shape([batch_size, 3, (num_his + 1), height, width, 3]) 57 | train_scores.set_shape([batch_size, 3, (num_his + 1)]) 58 | 59 | print(train_tensor) 60 | print(train_scores) 61 | 62 | train_anchor = train_tensor[:, 0, 0:num_his, ...] 63 | train_anchor_gt = train_tensor[:, 0, -1, ...] 64 | 65 | train_positive = train_tensor[:, 1, 0:num_his, ...] 66 | train_positive_gt = train_tensor[:, 1, -1, ...] 67 | train_positive_score = train_scores[:, 1, -1] 68 | 69 | train_negative = train_tensor[:, 2, 0:num_his, ...] 70 | train_negative_gt = train_tensor[:, 2, -1, ...] 71 | train_negative_score = train_scores[:, 2, -1] 72 | 73 | # define training generator function 74 | with tf.variable_scope('generator', reuse=None): 75 | train_anchor_output, train_anchor_feature, _ = prednet(train_anchor, use_decoder=True) 76 | train_anchor_psnr = psnr_error(train_anchor_output, train_anchor_gt) 77 | with tf.variable_scope('generator', reuse=True): 78 | train_positive_output, train_positive_feature, _ = prednet(train_positive, use_decoder=True) 79 | train_positive_psnr = psnr_error(train_positive_output, train_positive_gt) 80 | with tf.variable_scope('generator', reuse=True): 81 | train_negative_output, train_negative_feature, _ = prednet(train_negative, use_decoder=True) 82 | train_negative_psnr = psnr_error(train_negative_output, train_negative_gt) 83 | 84 | 85 | with tf.name_scope('training'): 86 | pred_loss = tf.reduce_mean(tf.abs(train_anchor_output - train_anchor_gt)) + \ 87 | tf.reduce_mean(train_positive_score 88 | * tf.reduce_mean(tf.abs(train_positive_output - train_positive_gt), axis=[1, 2, 3])) 89 | # pred_loss = tf.reduce_mean(tf.abs(train_anchor_output - train_anchor_gt)) 90 | 91 | # train_anchor_feature = tf.Print(train_anchor_feature, [tf.reduce_max(train_anchor_feature), tf.reduce_min(train_anchor_feature)]) 92 | # train_positive_feature = tf.Print(train_positive_feature, [tf.reduce_max(train_anchor_feature), tf.reduce_min(train_anchor_feature)]) 93 | # train_negative_feature = tf.Print(train_negative_feature, [tf.reduce_max(train_negative_feature), tf.reduce_min(train_negative_feature)]) 94 | 95 | # inter, between different group 96 | intra_dis = tf.reduce_mean((train_anchor_feature - train_positive_feature) ** 2) 97 | inter_dis = tf.reduce_mean((train_anchor_feature - train_negative_feature) ** 2) 98 | 99 | margin_loss = tf.reduce_mean(tf.abs(train_positive_score - train_negative_score)) * tf.maximum( 100 | 0.0, margin + intra_dis - inter_dis) 101 | 102 | g_loss = pred_loss + lam * margin_loss 103 | g_step = tf.Variable(0, dtype=tf.int32, trainable=False, name='g_step') 104 | g_lrate = tf.train.piecewise_constant(g_step, boundaries=const.LRATE_G_BOUNDARIES, values=const.LRATE_G) 105 | g_optimizer = tf.train.AdamOptimizer(learning_rate=g_lrate, name='g_optimizer') 106 | g_vars = tf.get_collection(key=tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator') 107 | 108 | g_train_op = g_optimizer.minimize(g_loss, global_step=g_step, var_list=g_vars, name='g_train_op') 109 | 110 | # add all to summaries 111 | tf.summary.scalar(tensor=g_loss, name='g_loss') 112 | tf.summary.scalar(tensor=pred_loss, name='pred_loss') 113 | tf.summary.scalar(tensor=margin_loss, name='margin_loss') 114 | tf.summary.scalar(tensor=intra_dis, name='intra_dis') 115 | tf.summary.scalar(tensor=inter_dis, name='inter_dis') 116 | tf.summary.scalar(tensor=train_anchor_psnr, name='anchor_psnr') 117 | tf.summary.scalar(tensor=train_positive_psnr, name='positive_psnr') 118 | tf.summary.scalar(tensor=train_negative_psnr, name='negative_psnr') 119 | tf.summary.image(tensor=train_positive_output, name='positive_output') 120 | tf.summary.image(tensor=train_positive_gt, name='positive_gt') 121 | tf.summary.image(tensor=train_negative_output, name='negative_output') 122 | tf.summary.image(tensor=train_negative_gt, name='negative_gt') 123 | summary_op = tf.summary.merge_all() 124 | 125 | 126 | config = tf.ConfigProto() 127 | config.gpu_options.allow_growth = True 128 | with tf.Session(config=config) as sess: 129 | # summaries 130 | summary_writer = tf.summary.FileWriter(summary_dir, graph=sess.graph) 131 | 132 | # initialize weights 133 | sess.run(tf.global_variables_initializer()) 134 | print('Init successfully!') 135 | 136 | # tf saver 137 | saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=None) 138 | restore_var = [v for v in tf.global_variables()] 139 | loader = tf.train.Saver(var_list=restore_var) 140 | if os.path.isdir(snapshot_dir): 141 | ckpt = tf.train.get_checkpoint_state(snapshot_dir) 142 | if ckpt and ckpt.model_checkpoint_path: 143 | load(loader, sess, ckpt.model_checkpoint_path) 144 | elif pretrain_model: 145 | load(loader, sess, pretrain_model) 146 | print('Pretrain model from {}.'.format(pretrain_model)) 147 | else: 148 | print('No checkpoint file found.') 149 | else: 150 | load(loader, sess, snapshot_dir) 151 | 152 | _step, _loss, _summaries = 0, None, None 153 | while _step < iterations: 154 | try: 155 | _, _step, _inter_dis, _intra_dis, _pred_loss, _margin_loss, _g_loss, _a_psnr, _p_psnr, _n_psnr = \ 156 | sess.run([g_train_op, g_step, inter_dis, intra_dis, pred_loss, margin_loss, g_loss, 157 | train_anchor_psnr, train_positive_psnr, train_negative_psnr]) 158 | 159 | print('Training, pred loss = {:.6f}, margin loss = {:.6f}'.format(_pred_loss, _margin_loss)) 160 | if _step % 10 == 0: 161 | print('Iteration = {}, global loss = {:.6f}'.format(_step, _g_loss)) 162 | print(' intra dis = {:.6f}'.format(_intra_dis)) 163 | print(' inter dis = {:.6f}'.format(_inter_dis)) 164 | print(' anchor psnr = {:.6f}'.format(_a_psnr)) 165 | print(' positive psnr = {:.6f}'.format(_p_psnr)) 166 | print(' negative psnr = {:.6f}'.format(_n_psnr)) 167 | 168 | if _step % 250 == 0: 169 | _summaries = sess.run(summary_op) 170 | summary_writer.add_summary(_summaries, global_step=_step) 171 | print('Save summaries...') 172 | 173 | if _step % model_save_freq == 0: 174 | save(saver, sess, snapshot_dir, _step) 175 | 176 | except tf.errors.OutOfRangeError: 177 | print('Finish successfully!') 178 | save(saver, sess, snapshot_dir, _step) 179 | break 180 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svip-lab/MLEP/22e96634fe43a5b8394413a1aae267b673b3c4ce/utils/__init__.py -------------------------------------------------------------------------------- /utils/dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module define and implements all the dataloaders. 3 | """ 4 | from abc import ABCMeta, abstractmethod 5 | import numpy as np 6 | import scipy.io as scio 7 | from collections import OrderedDict 8 | import os 9 | import glob 10 | import json 11 | import imageio 12 | import pickle 13 | 14 | RNG = np.random.RandomState(2017) 15 | 16 | 17 | class LazyProperty(object): 18 | def __init__(self, func): 19 | self.func = func 20 | 21 | def __get__(self, instance, owner): 22 | if instance is None: 23 | return self 24 | else: 25 | value = self.func(instance) 26 | setattr(instance, self.func.__name__, value) 27 | return value 28 | 29 | 30 | class BaseDataAbstractLoader(object): 31 | __metaclass__ = ABCMeta 32 | 33 | def __init__(self, dataset, folder, resize_height, resize_width): 34 | self.dataset = dataset 35 | self.folder = folder 36 | self.resize_height = resize_height 37 | self.resize_width = resize_width 38 | 39 | @abstractmethod 40 | def get_video_clip(self, video, start, end, interval=1): 41 | pass 42 | 43 | @abstractmethod 44 | def get_video_names(self): 45 | pass 46 | 47 | @staticmethod 48 | def sample_normal_clip_list(videos_info, time_steps, interval=1): 49 | video_clips_list = [] 50 | 51 | for video_name, video_info in videos_info.items(): 52 | length = video_info['length'] 53 | images_paths = video_info['images'] 54 | 55 | for t in range(1, interval + 1): 56 | inv = t * time_steps 57 | 58 | for start in range(0, length): 59 | end = start + inv 60 | if end > length: 61 | break 62 | 63 | video_clips = images_paths[start:end:t] 64 | video_clips_list.append(video_clips) 65 | 66 | # flip sequence 67 | video_clips_list.append(list(reversed(video_clips))) 68 | 69 | print('sample video {} at time {}.'.format(video_name, t)) 70 | 71 | np.random.shuffle(video_clips_list) 72 | return video_clips_list 73 | 74 | @staticmethod 75 | def sample_normal_abnormal_clips(videos_info, time_steps, interval=1): 76 | normal_clips_list = [] 77 | normal_number_list = [] 78 | abnormal_clips_list = [] 79 | abnormal_number_list = [] 80 | 81 | # for shanghaitech and avenue which has been trained 82 | # at_least_anomalies = 2 83 | # at_least_anomalies = time_steps // 3 + 1 84 | at_least_anomalies = time_steps 85 | 86 | for v_name, v_info in videos_info.items(): 87 | length = v_info['length'] 88 | images_paths = v_info['images'] 89 | frame_mask = v_info['frame_mask'] 90 | 91 | video_normal_clips = [] 92 | video_abnormal_clips = [] 93 | 94 | for t in range(1, interval + 1): 95 | inv = t * time_steps 96 | 97 | for start in range(0, length - inv): 98 | end = start + inv 99 | video_clips = images_paths[start:end:t] 100 | reversed_video_clips = list(reversed(video_clips)) 101 | 102 | # check is normal or abnormal 103 | if len(frame_mask) != 0 and np.count_nonzero(frame_mask[start:end:t]) >= at_least_anomalies: 104 | video_abnormal_clips.append(video_clips) 105 | video_abnormal_clips.append(reversed_video_clips) 106 | else: 107 | video_normal_clips.append(video_clips) 108 | video_normal_clips.append(reversed_video_clips) 109 | 110 | print('sample video {} at time {}.'.format(v_name, t)) 111 | 112 | if len(video_normal_clips) != 0: 113 | normal_clips_list.append(video_normal_clips) 114 | normal_number_list.append(len(video_normal_clips)) 115 | if len(video_abnormal_clips) != 0: 116 | abnormal_clips_list.append(video_abnormal_clips) 117 | abnormal_number_list.append(len(video_abnormal_clips)) 118 | 119 | video_clips_dict = { 120 | 'normal': normal_clips_list, 121 | 'normal_numbers': normal_number_list, 122 | 'abnormal': abnormal_clips_list, 123 | 'abnormal_numbers': abnormal_number_list 124 | } 125 | return video_clips_dict 126 | 127 | @staticmethod 128 | def sample_normal_abnormal_clips_scores(videos_info, time_steps, interval=1, max_scores=0.5, min_scores=0.5): 129 | normal_clips_list = [] 130 | normal_scores_list = [] 131 | normal_number_list = [] 132 | 133 | abnormal_clips_list = [] 134 | abnormal_scores_list = [] 135 | abnormal_number_list = [] 136 | 137 | at_least_anomalies = min_scores * time_steps 138 | at_most_anomalies = max_scores * time_steps 139 | 140 | for v_name, v_info in videos_info.items(): 141 | length = v_info['length'] 142 | images_paths = v_info['images'] 143 | 144 | if 'scores' in v_info: 145 | scores = v_info['scores'] 146 | else: 147 | scores = np.ones(shape=(length,), dtype=np.float32) 148 | 149 | video_normal_clips = [] 150 | video_normal_scores = [] 151 | video_abnormal_clips = [] 152 | video_abnormal_scores = [] 153 | 154 | for t in range(1, interval + 1): 155 | inv = t * time_steps 156 | 157 | for start in range(0, length - inv): 158 | end = start + inv 159 | video_clips = images_paths[start:end:t] 160 | reversed_video_clips = list(reversed(video_clips)) 161 | 162 | video_scores = scores[start:end:t] 163 | reversed_video_scores = list(reversed(video_scores)) 164 | 165 | # check is normal or abnormal 166 | total_scores = np.sum(scores[start:end:t]) 167 | if total_scores < at_least_anomalies: 168 | video_abnormal_clips.append(video_clips) 169 | video_abnormal_clips.append(reversed_video_clips) 170 | 171 | video_abnormal_scores.append(video_scores) 172 | video_abnormal_scores.append(reversed_video_scores) 173 | 174 | elif total_scores >= at_most_anomalies: 175 | video_normal_clips.append(video_clips) 176 | video_normal_clips.append(reversed_video_clips) 177 | 178 | video_normal_scores.append(video_scores) 179 | video_normal_scores.append(reversed_video_scores) 180 | 181 | print('sample video {} at time {}.'.format(v_name, t)) 182 | 183 | if len(video_normal_clips) != 0: 184 | normal_clips_list.append(video_normal_clips) 185 | normal_number_list.append(len(video_normal_clips)) 186 | normal_scores_list.append(video_normal_scores) 187 | if len(video_abnormal_clips) != 0: 188 | abnormal_clips_list.append(video_abnormal_clips) 189 | abnormal_scores_list.append(video_abnormal_scores) 190 | abnormal_number_list.append(len(video_abnormal_clips)) 191 | 192 | video_clips_dict = { 193 | 'normal': normal_clips_list, 194 | 'normal_scores': normal_scores_list, 195 | 'normal_numbers': normal_number_list, 196 | 'abnormal': abnormal_clips_list, 197 | 'abnormal_scores': abnormal_scores_list, 198 | 'abnormal_numbers': abnormal_number_list 199 | } 200 | return video_clips_dict 201 | 202 | @staticmethod 203 | def sample_normal_abnormal_clips_masks(videos_info, time_steps, interval=1): 204 | normal_clips_list = [] 205 | normal_pixels_list = [] 206 | normal_number_list = [] 207 | 208 | abnormal_clips_list = [] 209 | abnormal_pixels_list = [] 210 | abnormal_number_list = [] 211 | 212 | at_least_anomalies = time_steps 213 | 214 | for v_name, v_info in videos_info.items(): 215 | length = v_info['length'] 216 | images_paths = v_info['images'] 217 | frame_mask = v_info['frame_mask'] 218 | pixel_mask = v_info['pixel_mask'] 219 | 220 | has_abnormal = len(frame_mask) > 0 221 | 222 | video_normal_clips = [] 223 | video_normal_pixels = [] 224 | video_abnormal_clips = [] 225 | video_abnormal_pixels = [] 226 | 227 | for t in range(1, interval + 1): 228 | inv = t * time_steps 229 | 230 | for start in range(0, length - inv): 231 | end = start + inv 232 | video_clips = images_paths[start:end:t] 233 | reversed_video_clips = list(reversed(video_clips)) 234 | 235 | video_pixels = pixel_mask[start:end:t] 236 | reversed_video_pixels = list(reversed(video_pixels)) 237 | 238 | # check is normal or abnormal 239 | if has_abnormal and np.sum(frame_mask[start:end:t]) >= at_least_anomalies: 240 | video_abnormal_clips.append(video_clips) 241 | video_abnormal_clips.append(reversed_video_clips) 242 | 243 | video_abnormal_pixels.append(video_pixels) 244 | video_abnormal_pixels.append(reversed_video_pixels) 245 | 246 | else: 247 | video_normal_clips.append(video_clips) 248 | video_normal_clips.append(reversed_video_clips) 249 | 250 | video_normal_pixels.append(video_pixels) 251 | video_normal_pixels.append(reversed_video_pixels) 252 | 253 | print('sample video {} at time {}.'.format(v_name, t)) 254 | 255 | if len(video_normal_clips) != 0: 256 | normal_clips_list.append(video_normal_clips) 257 | normal_pixels_list.append(video_normal_pixels) 258 | normal_number_list.append(len(video_normal_clips)) 259 | if len(video_abnormal_clips) != 0: 260 | abnormal_clips_list.append(video_abnormal_clips) 261 | abnormal_pixels_list.append(video_abnormal_pixels) 262 | abnormal_number_list.append(len(video_abnormal_clips)) 263 | 264 | video_clips_dict = { 265 | 'normal': normal_clips_list, 266 | 'normal_pixels': normal_pixels_list, 267 | 'normal_numbers': normal_number_list, 268 | 'abnormal': abnormal_clips_list, 269 | 'abnormal_pixels': abnormal_pixels_list, 270 | 'abnormal_numbers': abnormal_number_list 271 | } 272 | return video_clips_dict 273 | 274 | @staticmethod 275 | def sample_fragments_seq_num(videos_info, time_steps, interval=1, fragments=20): 276 | fragment_clip_list = [] 277 | video_path_list = [] 278 | 279 | for v_info in videos_info: 280 | length = v_info['length'] 281 | video_path = v_info['path'] 282 | 283 | fragment_clip = [] 284 | fragment_size = int(np.ceil(length / fragments)) 285 | for frag_start in range(0, length, fragment_size): 286 | frag_end = min(frag_start + fragment_size, length) 287 | 288 | itv_fragment = [] 289 | for t in range(1, interval + 1): 290 | inv = t * time_steps 291 | if inv > frag_end - frag_start: 292 | break 293 | start = RNG.randint(frag_start, frag_end - inv) 294 | 295 | clip = list(range(start, start + inv, t)) 296 | reversed_clip = list(reversed(clip)) 297 | 298 | itv_fragment.append(clip) 299 | itv_fragment.append(reversed_clip) 300 | if len(itv_fragment) > 0: 301 | fragment_clip.append(itv_fragment) 302 | 303 | if len(fragment_clip) == fragments: 304 | fragment_clip_list.append(fragment_clip) 305 | video_path_list.append(video_path) 306 | 307 | return fragment_clip_list, video_path_list 308 | 309 | @staticmethod 310 | def sample_fragments_seq_path(videos_info, time_steps, interval=1, multi_interval=False, fragments=20): 311 | if multi_interval: 312 | interval_list = range(1, interval + 1) 313 | else: 314 | interval_list = range(interval, interval + 1) 315 | 316 | fragment_clip_list = [] 317 | 318 | for v_info in videos_info: 319 | length = v_info['length'] 320 | images_paths = v_info['images'] 321 | 322 | fragment_clip = [] 323 | fragment_size = int(np.ceil(length / fragments)) 324 | for frag_start in range(0, length, fragment_size): 325 | frag_end = min(frag_start + fragment_size, length) 326 | 327 | itv_fragment = [] 328 | 329 | for t in interval_list: 330 | inv = t * time_steps 331 | if inv > frag_end - frag_start: 332 | break 333 | for start in range(frag_start, frag_end - inv): 334 | clip = images_paths[start: start + inv: t] 335 | reversed_clip = list(reversed(clip)) 336 | 337 | itv_fragment.append(clip) 338 | itv_fragment.append(reversed_clip) 339 | if len(itv_fragment) > 0: 340 | fragment_clip.append(itv_fragment) 341 | 342 | if len(fragment_clip) == fragments: 343 | fragment_clip_list.append(fragment_clip) 344 | 345 | return fragment_clip_list 346 | 347 | @staticmethod 348 | def parse_videos_folder(folder): 349 | print('parsing video folder = {}'.format(folder)) 350 | 351 | videos_info = OrderedDict() 352 | for video_name in sorted(os.listdir(folder)): 353 | images_paths = glob.glob(os.path.join(folder, video_name, '*')) 354 | images_paths.sort() 355 | length = len(images_paths) 356 | 357 | videos_info[video_name] = { 358 | 'length': length, 359 | 'images': images_paths, 360 | 'frame_mask': [], 361 | 'pixel_mask': [] 362 | } 363 | 364 | print('parsing video successfully...') 365 | return videos_info 366 | 367 | @staticmethod 368 | def parser_videos_images_txt_split_classes(folder, txt_file, frame_mask_folder=''): 369 | print('parsing video txt = {}'.format(txt_file)) 370 | 371 | videos_info = OrderedDict() 372 | with open(txt_file, 'r') as reader: 373 | def add_to_videos_info(video, video_class, video_path): 374 | # video_folder/Abuse/Abuse001_x264 375 | assert os.path.exists(video_path), 'video = {} dose not exist!'.format(video_path) 376 | 377 | # check and load temporal annotation 378 | if frame_mask_folder: 379 | # frame_folder/Abuse001_x264 380 | frame_mat = os.path.join(frame_mask_folder, os.path.split(video_path)[-1] + '.mat') 381 | load_frame = scio.loadmat(frame_mat) 382 | frame_mask = load_frame['Annotation_file']['Anno'][0][0] 383 | else: 384 | frame_mask = [] 385 | 386 | images_paths = glob.glob(os.path.join(video_path, '*')) 387 | images_paths.sort() 388 | length = len(images_paths) 389 | 390 | info = { 391 | 'length': length, 392 | 'images': images_paths, 393 | 'frame_mask': frame_mask 394 | } 395 | 396 | if 'normal' in video_class.lower(): 397 | class_label = 'Normal' 398 | else: 399 | class_label = video_class 400 | 401 | if class_label in videos_info: 402 | videos_info[class_label][video] = info 403 | else: 404 | videos_info[class_label] = {video: info} 405 | 406 | for line in reader: 407 | line = line.strip() 408 | 409 | # Abuse/Abuse001_x264.mp4 410 | splits = line.split('/') 411 | video_class = str(splits[0]) 412 | 413 | # video_folder/Abuse/Abuse001_x264 414 | video_path = os.path.join(folder, video_class, str(splits[-1].split('.')[0])) 415 | add_to_videos_info(line, video_class, video_path) 416 | print(txt_file, line) 417 | 418 | return videos_info 419 | 420 | @staticmethod 421 | def parser_videos_images_txt(folder, txt_file, frame_mask_folder=''): 422 | print('parsing video txt = {}'.format(txt_file)) 423 | 424 | videos_info = OrderedDict() 425 | with open(txt_file, 'r') as reader: 426 | def add_to_videos_info(video, video_class, video_path): 427 | # video_folder/Abuse/Abuse001_x264 428 | assert os.path.exists(video_path), 'video = {} dose not exist!'.format(video_path) 429 | 430 | # check and load temporal annotation 431 | if frame_mask_folder: 432 | 433 | # frame_folder/Abuse001_x264 434 | frame_mat = os.path.join(frame_mask_folder, os.path.split(video_path)[-1] + '.mat') 435 | load_frame = scio.loadmat(frame_mat) 436 | frame_mask = load_frame['Annotation_file']['Anno'][0][0] 437 | else: 438 | frame_mask = [] 439 | 440 | images_paths = glob.glob(os.path.join(video_path, '*')) 441 | images_paths.sort() 442 | length = len(images_paths) 443 | 444 | info = { 445 | 'length': length, 446 | 'images': images_paths, 447 | 'frame_mask': frame_mask 448 | } 449 | videos_info[video] = info 450 | 451 | for line in reader: 452 | line = line.strip() 453 | 454 | # Abuse/Abuse001_x264.mp4 455 | splits = line.split('/') 456 | video_class = str(splits[0]) 457 | 458 | # video_folder/Abuse/Abuse001_x264 459 | video_path = os.path.join(folder, video_class, str(splits[-1].split('.')[0])) 460 | add_to_videos_info(line, video_class, video_path) 461 | print(txt_file, line) 462 | 463 | return videos_info 464 | 465 | @staticmethod 466 | def parser_videos_images_json(folder, frame_mask_file=''): 467 | print('parsing video json = {}'.format(frame_mask_file)) 468 | 469 | videos_info = OrderedDict() 470 | with open(frame_mask_file, 'r') as file: 471 | data = json.load(file) 472 | 473 | for video_name in sorted(os.listdir(folder)): 474 | images_paths = glob.glob(os.path.join(folder, video_name, '*')) 475 | images_paths.sort() 476 | length = len(images_paths) 477 | 478 | assert length == data[video_name]['length'] 479 | anomalies = data[video_name]['anomalies'] 480 | 481 | frame_mask = [] 482 | for event in anomalies: 483 | for name, annotation in event.items(): 484 | frame_mask.append(annotation) 485 | 486 | videos_info[video_name] = { 487 | 'length': length, 488 | 'images': images_paths, 489 | 'frame_mask': frame_mask, 490 | 'pixel_mask': [] 491 | } 492 | 493 | print('parsing video successfully...') 494 | return videos_info 495 | 496 | @staticmethod 497 | def parser_videos_paths_txt(folder, txt_file, frame_mask_file=''): 498 | print('parsing video txt = {}'.format(txt_file)) 499 | 500 | videos_info = OrderedDict() 501 | with open(txt_file, 'r') as reader: 502 | def add_to_videos_info(video_class, video_path): 503 | assert os.path.exists(video_path), 'video = {} dose not exist!'.format(video_path) 504 | 505 | # get length 506 | vid = imageio.get_reader(video_path, 'ffmpeg') 507 | video_length = vid.get_length() 508 | 509 | # check and load temporal annotation 510 | if frame_mask_file: 511 | frame_mat = os.path.join(frame_mask_file, os.path.split(video_path)[-1].split('.')[0] + '.mat') 512 | load_frame = scio.loadmat(frame_mat) 513 | frame_mask = load_frame['Annotation_file']['Anno'][0][0] 514 | else: 515 | frame_mask = [] 516 | 517 | info = { 518 | 'length': video_length, 519 | 'path': video_path, 520 | 'frame_mask': frame_mask 521 | } 522 | 523 | if 'normal' in video_class.lower(): 524 | class_label = 'Normal' 525 | else: 526 | class_label = video_class 527 | 528 | if class_label in videos_info: 529 | videos_info[class_label].append(info) 530 | else: 531 | videos_info[class_label] = [info] 532 | vid.close() 533 | 534 | for line in reader: 535 | line = line.strip() 536 | 537 | # Abuse/Abuse001_x264.mp4 538 | video_class = str(line.split('/')[0]) 539 | video_path = os.path.join(folder, line) 540 | 541 | add_to_videos_info(video_class, video_path) 542 | print(txt_file, line) 543 | 544 | return videos_info 545 | 546 | @staticmethod 547 | def parser_paths_txt(folder, txt_file): 548 | print('parsing video txt = {}'.format(txt_file)) 549 | 550 | videos_info = {} 551 | with open(txt_file, 'r') as reader: 552 | def add_to_videos_info(video_class, video_path): 553 | assert os.path.exists(video_path), 'video = {} dose not exist!'.format(video_path) 554 | if 'normal' in video_class.lower(): 555 | class_label = 'Normal' 556 | else: 557 | class_label = video_class 558 | 559 | if class_label in videos_info: 560 | videos_info[class_label].append(video_path) 561 | else: 562 | videos_info[class_label] = [video_path] 563 | 564 | for line in reader: 565 | line = line.strip() 566 | 567 | # Abuse/Abuse001_x264.mp4 568 | video_class = str(line.split('/')[0]) 569 | video_path = os.path.join(folder, line) 570 | 571 | add_to_videos_info(video_class, video_path) 572 | print(txt_file, line) 573 | 574 | return videos_info 575 | 576 | @staticmethod 577 | def load_frame_scores(psnr_file): 578 | scores = [] 579 | with open(psnr_file, 'rb') as reader: 580 | # results { 581 | # 'dataset': the name of dataset 582 | # 'psnr': the psnr of each testing videos, 583 | # } 584 | 585 | # psnr_records['psnr'] is np.array, shape(#videos) 586 | # psnr_records[0] is np.array ------> 01.avi 587 | # psnr_records[1] is np.array ------> 02.avi 588 | # ...... 589 | # psnr_records[n] is np.array ------> xx.avi 590 | 591 | results = pickle.load(reader) 592 | psnr_records = results['psnr'] 593 | 594 | for psnr in psnr_records: 595 | score = (psnr - psnr.min()) / (psnr.max() - psnr.min()) 596 | scores.append(score) 597 | 598 | return scores 599 | 600 | @staticmethod 601 | def load_frame_psnrs(psnr_file, threshold=None): 602 | psnrs_records = [] 603 | with open(psnr_file, 'rb') as reader: 604 | 605 | results = pickle.load(reader) 606 | psnr_records = results['psnr'] 607 | 608 | for i, psnrs in enumerate(psnr_records): 609 | if threshold: 610 | invalid_index = np.logical_or(np.isnan(psnrs), np.isinf(psnrs)) 611 | psnrs[invalid_index] = threshold + 1 612 | 613 | too_big_index = np.logical_or(invalid_index, psnrs > threshold) 614 | not_too_big_index = np.logical_not(too_big_index) 615 | 616 | psnr_max = np.max(psnrs[not_too_big_index]) 617 | psnrs[too_big_index] = psnr_max 618 | 619 | psnrs_records.append(psnrs) 620 | return psnrs_records 621 | 622 | @staticmethod 623 | def filter_caption_psnrs(x): 624 | length = len(x) 625 | # p = np.zeros(shape=(length,), dtype=np.float32) 626 | 627 | x_mean = np.max(x) 628 | mid_idx = length // 2 629 | # p[mid_idx] = x[mid_idx] 630 | delta = length - mid_idx 631 | for i in range(mid_idx + 1, length): 632 | alpha = (i - mid_idx) / delta 633 | x[i] = alpha * x_mean + (1 - alpha) * x[i] 634 | 635 | for i in range(mid_idx - 1, -1, -1): 636 | alpha = (mid_idx - i) / delta 637 | x[i] = alpha * x_mean + (1 - alpha) * x[i] 638 | 639 | return x 640 | 641 | def load_frame_mask(self, videos_info, gt_file_path): 642 | # initialize the load frame mask function 643 | if self.dataset in ['ped1', 'ped2', 'avenue', 'enter', 'exit']: 644 | if gt_file_path.endswith('.json'): 645 | frame_mask = self._load_json_gt_file(gt_file_path) 646 | else: 647 | frame_mask = self._load_ucsd_avenue_subway_gt(videos_info, gt_file_path) 648 | elif self.dataset == 'shanghaitech': 649 | frame_mask = self._load_shanghaitech_gt(gt_file_path) 650 | else: 651 | print('Warning, dataset {} is not in {}, be careful when loading the frame mask and ' 652 | 'here, we use _load_uscd_avenue_subway_gt()'.format(self.dataset, 653 | ['ped1', 'ped2', 'avenue', 'enter', 'exit', 654 | 'shanghaitech'])) 655 | frame_mask = self._load_ucsd_avenue_subway_gt(videos_info, gt_file_path) 656 | 657 | return frame_mask 658 | 659 | @staticmethod 660 | def load_semantic_frame_mask(json_file): 661 | semantic_gts = [] 662 | anomalies_names = [] 663 | anomalies_names_set = set() 664 | 665 | with open(json_file, 'r') as file: 666 | data = json.load(file) 667 | video_names = list(sorted(data.keys())) 668 | 669 | for video_name in video_names: 670 | info = data[video_name] 671 | anomalies = info['anomalies'] 672 | length = info['length'] 673 | semantic_label = ['normal'] * length 674 | name_set = set() 675 | 676 | for event in anomalies: 677 | for name, annotation in event.items(): 678 | for start, end in annotation: 679 | semantic_label[start:end] = [name] * (end - start) 680 | 681 | name_set.add(name) 682 | 683 | semantic_gts.append(semantic_label) 684 | anomalies_names.append(name_set) 685 | anomalies_names_set |= name_set 686 | 687 | print('gt json file = {}'.format(json_file)) 688 | return semantic_gts, anomalies_names, anomalies_names_set 689 | 690 | @staticmethod 691 | def _load_json_gt_file(json_file): 692 | gts = [] 693 | with open(json_file, 'r') as file: 694 | data = json.load(file) 695 | video_names = list(sorted(data.keys())) 696 | 697 | for video_name in video_names: 698 | info = data[video_name] 699 | anomalies = info['anomalies'] 700 | length = info['length'] 701 | label = np.zeros((length,), dtype=np.int8) 702 | 703 | for event in anomalies: 704 | for name, annotation in event.items(): 705 | for start, end in annotation: 706 | label[start - 1: end] = 1 707 | gts.append(label) 708 | 709 | print('gt json file = {}'.format(json_file)) 710 | return gts 711 | 712 | @staticmethod 713 | def _load_ucsd_avenue_subway_gt(videos_info, gt_file_path): 714 | """ 715 | :param videos_info: videos information, parsed by parse_videos_folder 716 | :param gt_file_path: the path of gt file 717 | :type videos_info: dict or OrderedDict 718 | :type gt_file_path: str 719 | :return: 720 | """ 721 | assert os.path.exists(gt_file_path), 'gt file path = {} dose not exits!'.format(gt_file_path) 722 | 723 | abnormal_events = scio.loadmat(gt_file_path, squeeze_me=True)['gt'] 724 | 725 | if abnormal_events.ndim == 2: 726 | abnormal_events = abnormal_events.reshape(-1, abnormal_events.shape[0], abnormal_events.shape[1]) 727 | 728 | num_video = abnormal_events.shape[0] 729 | assert num_video == len(videos_info), 'ground true does not match the number of testing videos. {} != {}' \ 730 | .format(num_video, len(videos_info)) 731 | 732 | # need to test [].append, or np.array().append(), which one is faster 733 | gt = [] 734 | for i, video_info in enumerate(videos_info.values()): 735 | length = video_info['length'] 736 | 737 | sub_video_gt = np.zeros((length,), dtype=np.int8) 738 | sub_abnormal_events = abnormal_events[i] 739 | if sub_abnormal_events.ndim == 1: 740 | sub_abnormal_events = sub_abnormal_events.reshape((sub_abnormal_events.shape[0], -1)) 741 | 742 | _, num_abnormal = sub_abnormal_events.shape 743 | 744 | for j in range(num_abnormal): 745 | # (start - 1, end - 1) 746 | start = sub_abnormal_events[0, j] - 1 747 | end = sub_abnormal_events[1, j] 748 | 749 | sub_video_gt[start: end] = 1 750 | 751 | gt.append(sub_video_gt) 752 | 753 | return gt 754 | 755 | @staticmethod 756 | def _load_shanghaitech_gt(gt_file_folder): 757 | """ 758 | :param gt_file_folder: the folder path of test_frame_mask of ShanghaiTech dataset. 759 | :type gt_file_folder: str 760 | :return: 761 | """ 762 | video_path_list = os.listdir(gt_file_folder) 763 | video_path_list.sort() 764 | 765 | gt = [] 766 | for video in video_path_list: 767 | gt.append(np.load(os.path.join(gt_file_folder, video))) 768 | 769 | return gt 770 | 771 | @staticmethod 772 | def load_pixel_mask_file_list(videos_info, pixel_mask_folder=''): 773 | """ 774 | :param videos_info: videos information, parsed by parse_videos_folder 775 | :param pixel_mask_folder: the path of pixel mask folder 776 | :type videos_info: dict or OrderedDict 777 | :type pixel_mask_folder: str 778 | :return: 779 | """ 780 | 781 | if pixel_mask_folder: 782 | pixel_mask_file_list = os.listdir(pixel_mask_folder) 783 | pixel_mask_file_list.sort() 784 | 785 | num_videos = len(videos_info) 786 | assert num_videos == len(pixel_mask_file_list), \ 787 | 'ground true does not match the number of testing videos. {} != {}'.format( 788 | num_videos, len(pixel_mask_file_list)) 789 | 790 | for video_name, pixel_mask_file in zip(videos_info.keys(), pixel_mask_file_list): 791 | assert video_name + '.npy' == pixel_mask_file, 'video name {} does not have pixel mask {}'.format( 792 | video_name, pixel_mask_file 793 | ) 794 | 795 | for i in range(num_videos): 796 | pixel_mask_file_list[i] = os.path.join(pixel_mask_folder, pixel_mask_file_list[i]) 797 | else: 798 | pixel_mask_file_list = [] 799 | 800 | return pixel_mask_file_list 801 | 802 | @staticmethod 803 | def load_image_mask_file_list(videos_info, pixel_mask_folder=''): 804 | """ 805 | :param videos_info: videos information, parsed by parse_videos_folder 806 | :param pixel_mask_folder: the path of pixel mask folder 807 | :type videos_info: dict or OrderedDict 808 | :type pixel_mask_folder: str 809 | :return: 810 | """ 811 | 812 | if pixel_mask_folder: 813 | pixel_mask_file_list = os.listdir(pixel_mask_folder) 814 | pixel_mask_file_list.sort() 815 | 816 | num_videos = len(videos_info) 817 | assert num_videos == len(pixel_mask_file_list), \ 818 | 'ground true does not match the number of testing videos. {} != {}'.format( 819 | num_videos, len(pixel_mask_file_list)) 820 | 821 | for video_name, pixel_mask_file in zip(videos_info.keys(), pixel_mask_file_list): 822 | assert video_name == pixel_mask_file, 'video name {} does not have pixel mask {}'.format( 823 | video_name, pixel_mask_file 824 | ) 825 | 826 | image_mask_files = [] 827 | for i in range(num_videos): 828 | mask_files = glob.glob(os.path.join(pixel_mask_folder, pixel_mask_file_list[i], '*.jpg')) 829 | mask_files.sort() 830 | image_mask_files.append(mask_files) 831 | else: 832 | image_mask_files = [] 833 | 834 | return image_mask_files 835 | -------------------------------------------------------------------------------- /utils/dataloaders/only_normal_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import tensorflow as tf 4 | from concurrent.futures import ThreadPoolExecutor 5 | 6 | from utils.util import load_frame, multi_scale_crop_load_frame 7 | from utils.dataloaders import BaseDataAbstractLoader, RNG 8 | 9 | 10 | class NormalDataLoader(BaseDataAbstractLoader): 11 | 12 | def __init__(self, dataset, folder, resize_height, resize_width): 13 | super().__init__(dataset=dataset, folder=folder, resize_height=resize_height, resize_width=resize_width) 14 | 15 | self.videos_info = None 16 | self._setup() 17 | 18 | def __call__(self, batch_size, time_steps, interval=1, multi_scale_crop=False): 19 | if multi_scale_crop: 20 | dataset = self.tf_multi_scale_crop_dataset(batch_size, time_steps, interval) 21 | else: 22 | dataset = self.tf_dataset(batch_size, time_steps, interval) 23 | return dataset 24 | 25 | def tf_multi_scale_crop_dataset(self, batch_size, time_steps, interval): 26 | videos_clips_list = self.sample_normal_clip_list(self.videos_info, time_steps, interval) 27 | num_video_clips = len(videos_clips_list) 28 | 29 | crop_size = 224 30 | short_size_list = [224, 256, 384, 480] 31 | 32 | height, width = self.resize_height, self.resize_width 33 | scale = width / height 34 | 35 | def video_clip_generator(): 36 | i = 0 37 | while True: 38 | video_clips_paths = videos_clips_list[i] 39 | video_clips = np.empty(shape=[time_steps, crop_size, crop_size, 3], dtype=np.float32) 40 | 41 | resize_height = RNG.choice(short_size_list) 42 | resize_width = int(scale * resize_height) 43 | start_h = RNG.randint(0, resize_height - crop_size + 1) 44 | start_w = RNG.randint(0, resize_width - crop_size + 1) 45 | crop_bbox = ((start_h, start_h + crop_size), (start_w, start_w + crop_size)) 46 | # print(crop_bbox, resize_height, resize_width) 47 | 48 | with ThreadPoolExecutor(max_workers=time_steps * 5) as pool: 49 | for t_idx, frame in enumerate(pool.map(multi_scale_crop_load_frame, video_clips_paths, 50 | [resize_height] * time_steps, 51 | [resize_width] * time_steps, 52 | [crop_bbox] * time_steps)): 53 | video_clips[t_idx, ...] = frame 54 | 55 | i = (i + 1) % num_video_clips 56 | yield video_clips 57 | 58 | # video clip paths 59 | dataset = tf.data.Dataset.from_generator(generator=video_clip_generator, 60 | output_types=tf.float32, 61 | output_shapes=[time_steps, crop_size, crop_size, 3]) 62 | print('generator dataset, {}'.format(dataset)) 63 | dataset = dataset.prefetch(buffer_size=256) 64 | dataset = dataset.shuffle(buffer_size=256).batch(batch_size) 65 | print('epoch dataset, {}'.format(dataset)) 66 | return dataset 67 | 68 | def tf_dataset(self, batch_size, time_steps, interval): 69 | videos_clips_list = self.sample_normal_clip_list(self.videos_info, time_steps, interval) 70 | num_video_clips = len(videos_clips_list) 71 | 72 | height, width = self.resize_height, self.resize_width 73 | 74 | def video_clip_generator(): 75 | i = 0 76 | while True: 77 | video_clips_paths = videos_clips_list[i] 78 | video_clips = np.empty(shape=[time_steps, height, width, 3], dtype=np.float32) 79 | for t, filename in enumerate(video_clips_paths): 80 | video_clips[t, ...] = load_frame(filename, height, width) 81 | 82 | i = (i + 1) % num_video_clips 83 | yield video_clips 84 | 85 | # video clip paths 86 | dataset = tf.data.Dataset.from_generator(generator=video_clip_generator, 87 | output_types=tf.float32, 88 | output_shapes=[time_steps, height, width, 3]) 89 | print('generator dataset, {}'.format(dataset)) 90 | dataset = dataset.prefetch(buffer_size=256) 91 | dataset = dataset.shuffle(buffer_size=256).batch(batch_size) 92 | print('epoch dataset, {}'.format(dataset)) 93 | return dataset 94 | 95 | def debug(self, batch_size, time_steps, interval=1): 96 | videos_clips_list = self.sample_normal_clip_list(self.videos_info, time_steps, interval) 97 | num_video_clips = len(videos_clips_list) 98 | 99 | crop_size = 224 100 | short_size_list = [224, 256, 384, 480] 101 | 102 | height, width = self.resize_height, self.resize_width 103 | scale = width / height 104 | 105 | def video_clip_generator(): 106 | i = 0 107 | while True: 108 | batch_video_clips = [] 109 | for b_idx in range(batch_size): 110 | video_clips_paths = videos_clips_list[i] 111 | video_clips = np.empty(shape=[time_steps, crop_size, crop_size, 3], dtype=np.float32) 112 | 113 | resize_height = RNG.choice(short_size_list) 114 | resize_width = int(scale * resize_height) 115 | start_h = RNG.randint(0, resize_height - crop_size + 1) 116 | start_w = RNG.randint(0, resize_width - crop_size + 1) 117 | crop_bbox = ((start_h, start_h + crop_size), (start_w, start_w + crop_size)) 118 | # print(crop_bbox, resize_height, resize_width) 119 | 120 | with ThreadPoolExecutor(max_workers=time_steps * 5) as pool: 121 | for t_idx, frame in enumerate(pool.map(multi_scale_crop_load_frame, video_clips_paths, 122 | [resize_height] * time_steps, 123 | [resize_width] * time_steps, 124 | [crop_bbox] * time_steps)): 125 | video_clips[t_idx, ...] = frame 126 | 127 | i = (i + 1) % num_video_clips 128 | 129 | batch_video_clips.append(video_clips) 130 | batch_video_clips = np.stack(batch_video_clips, axis=0) 131 | yield batch_video_clips 132 | 133 | return video_clip_generator 134 | 135 | def _setup(self): 136 | self.videos_info = self.parse_videos_folder(self.folder) 137 | 138 | def get_video_clip(self, video, start, end, interval=1): 139 | # assert video_name in self._videos_info, 'video {} is not in {}!'.format(video_name, self._videos_info.keys()) 140 | # assert (start >= 0) and (start <= end) and (end < self._videos_info[video_name]['length']) 141 | 142 | video_idx = np.arange(start, end, interval) 143 | video_clip = np.empty(shape=[len(video_idx), self.resize_height, self.resize_width, 3], dtype=np.float32) 144 | for idx, v_idx in enumerate(video_idx): 145 | filename = self.videos_info[video]['images'][v_idx] 146 | video_clip[idx, ...] = load_frame(filename, self.resize_height, self.resize_width) 147 | 148 | return video_clip 149 | 150 | def get_video_names(self): 151 | return list(self.videos_info.keys()) 152 | 153 | -------------------------------------------------------------------------------- /utils/dataloaders/temporal_triplet_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | from utils.util import load_frame 7 | from utils.dataloaders import BaseDataAbstractLoader, LazyProperty, RNG 8 | 9 | 10 | class DataTemporalGtLoader(BaseDataAbstractLoader): 11 | 12 | def __init__(self, dataset, folder, resize_height, resize_width, k_folds, kth, 13 | frame_mask_file, pixel_mask_file=''): 14 | super().__init__(dataset=dataset, folder=folder, resize_height=resize_height, resize_width=resize_width) 15 | 16 | self.k_folds = k_folds 17 | self.kth = kth 18 | self.frame_mask_file = frame_mask_file 19 | self.pixel_mask_file = pixel_mask_file 20 | 21 | self.val_videos_info = None 22 | self.test_videos_info = None 23 | 24 | self._setup() 25 | 26 | def __call__(self, batch_size, time_steps, interval=1): 27 | val_clips_list = self.sample_normal_abnormal_clips(self.val_videos_info, time_steps, interval) 28 | test_clips_list = self.sample_normal_abnormal_clips(self.test_videos_info, time_steps, interval) 29 | 30 | val_dataset = self.convert_to_tf_dataset(val_clips_list, batch_size, time_steps=time_steps) 31 | test_dataset = self.convert_to_tf_dataset(test_clips_list, batch_size, time_steps=time_steps) 32 | 33 | return val_dataset, test_dataset 34 | 35 | def convert_to_tf_dataset(self, videos_clips_dict, batch_size, time_steps): 36 | normal_clips = videos_clips_dict['normal'] 37 | normal_number_list = videos_clips_dict['normal_numbers'] 38 | abnormal_clips = videos_clips_dict['abnormal'] 39 | abnormal_number_list = videos_clips_dict['abnormal_numbers'] 40 | 41 | num_normal_videos = len(normal_clips) 42 | num_abnormal_videos = len(abnormal_clips) 43 | 44 | height, width = self.resize_height, self.resize_width 45 | 46 | def _sample_video_clip(clips_list, length): 47 | clip = np.empty(shape=[time_steps, height, width, 3], dtype=np.float32) 48 | 49 | sample_idx = RNG.randint(length) 50 | for t, filename in enumerate(clips_list[sample_idx]): 51 | clip[t, ...] = load_frame(filename, height, width) 52 | 53 | return clip 54 | 55 | def video_clip_generator(): 56 | while True: 57 | a_vid, p_vid = RNG.choice(num_normal_videos, size=2) 58 | n_vid = RNG.randint(num_abnormal_videos) 59 | 60 | video_clips = np.empty(shape=[3, time_steps, height, width, 3], dtype=np.float32) 61 | video_clips[0, ...] = _sample_video_clip(normal_clips[a_vid], normal_number_list[a_vid]) 62 | video_clips[1, ...] = _sample_video_clip(normal_clips[p_vid], normal_number_list[p_vid]) 63 | video_clips[2, ...] = _sample_video_clip(abnormal_clips[n_vid], abnormal_number_list[n_vid]) 64 | 65 | # show_triplet_video_clips(video_clips) 66 | yield video_clips 67 | 68 | # video clip paths 69 | dataset = tf.data.Dataset.from_generator(generator=video_clip_generator, 70 | output_types=tf.float32, 71 | output_shapes=[3, time_steps, height, width, 3]) 72 | print('generator dataset, {}'.format(dataset)) 73 | dataset = dataset.prefetch(buffer_size=256) 74 | dataset = dataset.shuffle(buffer_size=256).batch(batch_size) 75 | print('epoch dataset, {}'.format(dataset)) 76 | 77 | return dataset 78 | 79 | def convert_to_tf_dataset_debug(self, videos_clips_dict, batch_size, time_steps): 80 | normal_clips = videos_clips_dict['normal'] 81 | normal_number_list = videos_clips_dict['normal_numbers'] 82 | abnormal_clips = videos_clips_dict['abnormal'] 83 | abnormal_number_list = videos_clips_dict['abnormal_numbers'] 84 | 85 | num_normal_videos = len(normal_clips) 86 | num_abnormal_videos = len(abnormal_clips) 87 | 88 | height, width = self.resize_height, self.resize_width 89 | 90 | def _sample_video_clip(clips_list, length): 91 | clip = np.empty(shape=[time_steps, height, width, 3], dtype=np.float32) 92 | 93 | sample_idx = RNG.randint(length) 94 | for t, filename in enumerate(clips_list[sample_idx]): 95 | clip[t, ...] = load_frame(filename, height, width) 96 | 97 | return clip 98 | 99 | def video_clip_generator(): 100 | while True: 101 | a_vid, p_vid = RNG.choice(num_normal_videos, size=2) 102 | n_vid = RNG.randint(num_abnormal_videos) 103 | 104 | video_clips = np.empty(shape=[3, time_steps, height, width, 3], dtype=np.float32) 105 | video_clips[0, ...] = _sample_video_clip(normal_clips[a_vid], normal_number_list[a_vid]) 106 | video_clips[1, ...] = _sample_video_clip(normal_clips[p_vid], normal_number_list[p_vid]) 107 | video_clips[2, ...] = _sample_video_clip(abnormal_clips[n_vid], abnormal_number_list[n_vid]) 108 | 109 | # show_triplet_video_clips(video_clips) 110 | yield video_clips 111 | 112 | batch_video_clips = [] 113 | for i in range(batch_size): 114 | batch_video_clips.append(next(video_clip_generator())) 115 | batch_video_clips = np.stack(batch_video_clips, axis=0) 116 | return batch_video_clips 117 | 118 | def _setup(self): 119 | videos_info = self.parse_videos_folder(self.folder) 120 | frame_mask = self.load_frame_mask(videos_info, self.frame_mask_file) 121 | pixel_mask_list = self.load_image_mask_file_list(videos_info, self.pixel_mask_file) 122 | 123 | num_videos = len(videos_info) 124 | if self.k_folds != 0: 125 | k_folds_ids = np.array_split(np.arange(num_videos), self.k_folds) 126 | val_ids = k_folds_ids[self.kth - 1].tolist() 127 | test_ids = [] 128 | for k in range(self.k_folds): 129 | if k != (self.kth - 1): 130 | test_ids += k_folds_ids[k].tolist() 131 | else: 132 | val_ids = [] 133 | test_ids = np.arange(num_videos) 134 | 135 | def add_gt_to_info(ids): 136 | videos_info_gt = OrderedDict() 137 | videos_names = list(videos_info.keys()) 138 | 139 | for i in ids: 140 | v_name = videos_names[i] 141 | pixel_mask = pixel_mask_list[i] if pixel_mask_list else [] 142 | 143 | videos_info_gt[v_name] = { 144 | 'length': videos_info[v_name]['length'], 145 | 'images': videos_info[v_name]['images'], 146 | 'frame_mask': frame_mask[i], 147 | 'pixel_mask': pixel_mask 148 | } 149 | 150 | return videos_info_gt 151 | 152 | val_videos_info = add_gt_to_info(val_ids) 153 | test_videos_info = add_gt_to_info(test_ids) 154 | 155 | del videos_info 156 | 157 | self.val_videos_info = val_videos_info 158 | self.test_videos_info = test_videos_info 159 | 160 | def read_video_clip(self, images_paths): 161 | video_clip = [] 162 | for filename in images_paths: 163 | video_clip.append(load_frame(filename, self.resize_height, self.resize_width)) 164 | 165 | video_clip = np.stack(video_clip, axis=0) 166 | return video_clip 167 | 168 | def get_video_clip(self, video, start, end, interval=1): 169 | # assert video_name in self._videos_info, 'video {} is not in {}!'.format(video_name, self._videos_info.keys()) 170 | # assert (start >= 0) and (start <= end) and (end < self._videos_info[video_name]['length']) 171 | 172 | video_idx = np.arange(start, end, interval) 173 | video_clip = np.empty(shape=[len(video_idx), self.resize_height, self.resize_width, 3], dtype=np.float32) 174 | for idx, v_idx in enumerate(video_idx): 175 | filename = self.test_videos_info[video]['images'][v_idx] 176 | video_clip[idx, ...] = load_frame(filename, self.resize_height, self.resize_width) 177 | 178 | return video_clip 179 | 180 | def get_val_video_clip(self, video, start, end, interval=1): 181 | video_idx = np.arange(start, end, interval) 182 | video_clip = np.empty(shape=[len(video_idx), self.resize_height, self.resize_width, 3], dtype=np.float32) 183 | for idx, v_idx in enumerate(video_idx): 184 | filename = self.val_videos_info[video]['images'][v_idx] 185 | video_clip[idx, ...] = load_frame(filename, self.resize_height, self.resize_width) 186 | 187 | return video_clip 188 | 189 | def get_video_names(self): 190 | return list(self.test_videos_info) 191 | 192 | @LazyProperty 193 | def get_frame_mask(self): 194 | frame_mask = [] 195 | for video_info in self.test_videos_info.values(): 196 | frame_mask.append(video_info['frame_mask']) 197 | 198 | return frame_mask 199 | 200 | @LazyProperty 201 | def get_total_frames(self): 202 | total = 0 203 | for video_info in self.test_videos_info.values(): 204 | total += video_info['length'] 205 | return total 206 | 207 | 208 | class DataTemporalTripletLoader(DataTemporalGtLoader): 209 | def __init__(self, dataset, train_folder, test_folder, k_folds, kth, 210 | frame_mask_file, pixel_mask_file='', 211 | resize_height=256, resize_width=256): 212 | super().__init__(dataset, test_folder, resize_height, resize_width, 213 | k_folds=k_folds, kth=kth, frame_mask_file=frame_mask_file, pixel_mask_file=pixel_mask_file) 214 | 215 | self.train_folder = train_folder 216 | self.train_videos_info = self.parse_videos_folder(self.train_folder) 217 | 218 | def __call__(self, batch_size, time_steps, interval=1, is_training=True): 219 | train_clips_dict = self.sample_normal_abnormal_clips(self.train_videos_info, time_steps, interval) 220 | val_clips_dict = self.sample_normal_abnormal_clips(self.val_videos_info, time_steps, interval) 221 | test_clips_dict = self.sample_normal_abnormal_clips(self.test_videos_info, time_steps, interval) 222 | 223 | train_val_clips_dict = { 224 | 'normal': train_clips_dict['normal'] + val_clips_dict['normal'], 225 | 'normal_numbers': train_clips_dict['normal_numbers'] + val_clips_dict['normal_numbers'], 226 | 'abnormal': val_clips_dict['abnormal'], 227 | 'abnormal_numbers': train_clips_dict['abnormal_numbers'] + val_clips_dict['abnormal_numbers'] 228 | } 229 | 230 | if is_training: 231 | dataset = self.convert_to_tf_dataset(train_val_clips_dict, batch_size, time_steps) 232 | else: 233 | dataset = self.convert_to_tf_dataset(test_clips_dict, batch_size, time_steps) 234 | return dataset 235 | -------------------------------------------------------------------------------- /utils/dataloaders/test_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | from utils.util import load_frame 7 | from utils.dataloaders import BaseDataAbstractLoader, LazyProperty, RNG 8 | 9 | 10 | class DataTemporalGtLoader(BaseDataAbstractLoader): 11 | 12 | def __init__(self, dataset, folder, resize_height, resize_width, k_folds, kth, 13 | frame_mask_file, pixel_mask_file=''): 14 | super().__init__(dataset=dataset, folder=folder, resize_height=resize_height, resize_width=resize_width) 15 | 16 | self.k_folds = k_folds 17 | self.kth = kth 18 | self.frame_mask_file = frame_mask_file 19 | self.pixel_mask_file = pixel_mask_file 20 | 21 | self.val_videos_info = None 22 | self.test_videos_info = None 23 | 24 | self._setup() 25 | 26 | def __call__(self, batch_size, time_steps, interval=1): 27 | val_clips_list = self.sample_normal_abnormal_clips(self.val_videos_info, time_steps, interval) 28 | test_clips_list = self.sample_normal_abnormal_clips(self.test_videos_info, time_steps, interval) 29 | 30 | val_dataset = self.convert_to_tf_dataset(val_clips_list, batch_size, time_steps=time_steps) 31 | test_dataset = self.convert_to_tf_dataset(test_clips_list, batch_size, time_steps=time_steps) 32 | 33 | return val_dataset, test_dataset 34 | 35 | def convert_to_tf_dataset(self, videos_clips_dict, batch_size, time_steps): 36 | normal_clips = videos_clips_dict['normal'] 37 | normal_number_list = videos_clips_dict['normal_numbers'] 38 | abnormal_clips = videos_clips_dict['abnormal'] 39 | abnormal_number_list = videos_clips_dict['abnormal_numbers'] 40 | 41 | num_normal_videos = len(normal_clips) 42 | num_abnormal_videos = len(abnormal_clips) 43 | 44 | height, width = self.resize_height, self.resize_width 45 | 46 | def _sample_video_clip(clips_list, length): 47 | clip = np.empty(shape=[time_steps, height, width, 3], dtype=np.float32) 48 | 49 | sample_idx = RNG.randint(length) 50 | for t, filename in enumerate(clips_list[sample_idx]): 51 | clip[t, ...] = load_frame(filename, height, width) 52 | 53 | return clip 54 | 55 | def video_clip_generator(): 56 | while True: 57 | a_vid, p_vid = RNG.choice(num_normal_videos, size=2) 58 | n_vid = RNG.randint(num_abnormal_videos) 59 | 60 | video_clips = np.empty(shape=[3, time_steps, height, width, 3], dtype=np.float32) 61 | video_clips[0, ...] = _sample_video_clip(normal_clips[a_vid], normal_number_list[a_vid]) 62 | video_clips[1, ...] = _sample_video_clip(normal_clips[p_vid], normal_number_list[p_vid]) 63 | video_clips[2, ...] = _sample_video_clip(abnormal_clips[n_vid], abnormal_number_list[n_vid]) 64 | 65 | # show_triplet_video_clips(video_clips) 66 | yield video_clips 67 | 68 | # video clip paths 69 | dataset = tf.data.Dataset.from_generator(generator=video_clip_generator, 70 | output_types=tf.float32, 71 | output_shapes=[3, time_steps, height, width, 3]) 72 | print('generator dataset, {}'.format(dataset)) 73 | dataset = dataset.prefetch(buffer_size=256) 74 | dataset = dataset.shuffle(buffer_size=256).batch(batch_size) 75 | print('epoch dataset, {}'.format(dataset)) 76 | 77 | return dataset 78 | 79 | def convert_to_tf_dataset_debug(self, videos_clips_dict, batch_size, time_steps): 80 | normal_clips = videos_clips_dict['normal'] 81 | normal_number_list = videos_clips_dict['normal_numbers'] 82 | abnormal_clips = videos_clips_dict['abnormal'] 83 | abnormal_number_list = videos_clips_dict['abnormal_numbers'] 84 | 85 | num_normal_videos = len(normal_clips) 86 | num_abnormal_videos = len(abnormal_clips) 87 | 88 | height, width = self.resize_height, self.resize_width 89 | 90 | def _sample_video_clip(clips_list, length): 91 | clip = np.empty(shape=[time_steps, height, width, 3], dtype=np.float32) 92 | 93 | sample_idx = RNG.randint(length) 94 | for t, filename in enumerate(clips_list[sample_idx]): 95 | clip[t, ...] = load_frame(filename, height, width) 96 | 97 | return clip 98 | 99 | def video_clip_generator(): 100 | while True: 101 | a_vid, p_vid = RNG.choice(num_normal_videos, size=2) 102 | n_vid = RNG.randint(num_abnormal_videos) 103 | 104 | video_clips = np.empty(shape=[3, time_steps, height, width, 3], dtype=np.float32) 105 | video_clips[0, ...] = _sample_video_clip(normal_clips[a_vid], normal_number_list[a_vid]) 106 | video_clips[1, ...] = _sample_video_clip(normal_clips[p_vid], normal_number_list[p_vid]) 107 | video_clips[2, ...] = _sample_video_clip(abnormal_clips[n_vid], abnormal_number_list[n_vid]) 108 | 109 | # show_triplet_video_clips(video_clips) 110 | yield video_clips 111 | 112 | batch_video_clips = [] 113 | for i in range(batch_size): 114 | batch_video_clips.append(next(video_clip_generator())) 115 | batch_video_clips = np.stack(batch_video_clips, axis=0) 116 | return batch_video_clips 117 | 118 | def _setup(self): 119 | videos_info = self.parse_videos_folder(self.folder) 120 | frame_mask = self.load_frame_mask(videos_info, self.frame_mask_file) 121 | pixel_mask_list = self.load_image_mask_file_list(videos_info, self.pixel_mask_file) 122 | 123 | num_videos = len(videos_info) 124 | if self.k_folds != 0: 125 | k_folds_ids = np.array_split(np.arange(num_videos), self.k_folds) 126 | val_ids = k_folds_ids[self.kth - 1].tolist() 127 | test_ids = [] 128 | for k in range(self.k_folds): 129 | if k != (self.kth - 1): 130 | test_ids += k_folds_ids[k].tolist() 131 | else: 132 | val_ids = [] 133 | test_ids = np.arange(num_videos) 134 | 135 | def add_gt_to_info(ids): 136 | videos_info_gt = OrderedDict() 137 | videos_names = list(videos_info.keys()) 138 | 139 | for i in ids: 140 | v_name = videos_names[i] 141 | pixel_mask = pixel_mask_list[i] if pixel_mask_list else [] 142 | 143 | videos_info_gt[v_name] = { 144 | 'length': videos_info[v_name]['length'], 145 | 'images': videos_info[v_name]['images'], 146 | 'frame_mask': frame_mask[i], 147 | 'pixel_mask': pixel_mask 148 | } 149 | 150 | return videos_info_gt 151 | 152 | val_videos_info = add_gt_to_info(val_ids) 153 | test_videos_info = add_gt_to_info(test_ids) 154 | 155 | del videos_info 156 | 157 | self.val_videos_info = val_videos_info 158 | self.test_videos_info = test_videos_info 159 | 160 | def read_video_clip(self, images_paths): 161 | video_clip = [] 162 | for filename in images_paths: 163 | video_clip.append(load_frame(filename, self.resize_height, self.resize_width)) 164 | 165 | video_clip = np.stack(video_clip, axis=0) 166 | return video_clip 167 | 168 | def get_video_clip(self, video, start, end, interval=1): 169 | # assert video_name in self._videos_info, 'video {} is not in {}!'.format(video_name, self._videos_info.keys()) 170 | # assert (start >= 0) and (start <= end) and (end < self._videos_info[video_name]['length']) 171 | 172 | video_idx = np.arange(start, end, interval) 173 | video_clip = np.empty(shape=[len(video_idx), self.resize_height, self.resize_width, 3], dtype=np.float32) 174 | for idx, v_idx in enumerate(video_idx): 175 | filename = self.test_videos_info[video]['images'][v_idx] 176 | video_clip[idx, ...] = load_frame(filename, self.resize_height, self.resize_width) 177 | 178 | return video_clip 179 | 180 | def get_val_video_clip(self, video, start, end, interval=1): 181 | video_idx = np.arange(start, end, interval) 182 | video_clip = np.empty(shape=[len(video_idx), self.resize_height, self.resize_width, 3], dtype=np.float32) 183 | for idx, v_idx in enumerate(video_idx): 184 | filename = self.val_videos_info[video]['images'][v_idx] 185 | video_clip[idx, ...] = load_frame(filename, self.resize_height, self.resize_width) 186 | 187 | return video_clip 188 | 189 | def get_video_names(self): 190 | return list(self.test_videos_info) 191 | 192 | @LazyProperty 193 | def get_frame_mask(self): 194 | frame_mask = [] 195 | for video_info in self.test_videos_info.values(): 196 | frame_mask.append(video_info['frame_mask']) 197 | 198 | return frame_mask 199 | 200 | @LazyProperty 201 | def get_total_frames(self): 202 | total = 0 203 | for video_info in self.test_videos_info.values(): 204 | total += video_info['length'] 205 | return total 206 | -------------------------------------------------------------------------------- /utils/dataloaders/tune_video_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | from utils.util import load_frame 7 | from utils.dataloaders import BaseDataAbstractLoader, LazyProperty, RNG 8 | 9 | 10 | 11 | class DataTuneVideoGtLoaderImage(BaseDataAbstractLoader): 12 | def __init__(self, dataset, train_folder, test_folder, k_folds, kth, 13 | frame_mask_file, pixel_mask_file='', psnr_file='', 14 | resize_height=256, resize_width=256): 15 | 16 | super().__init__(dataset=dataset, folder=train_folder, 17 | resize_height=resize_height, resize_width=resize_width) 18 | 19 | self.k_folds = k_folds 20 | self.kth = kth 21 | self.frame_mask_file = frame_mask_file 22 | self.pixel_mask_file = pixel_mask_file 23 | self.psnr_file = psnr_file 24 | self.test_folder = test_folder 25 | 26 | self.score_min = 0.5 27 | self.score_max = 0.5 28 | 29 | self.train_videos_info = None 30 | self.val_videos_info = None 31 | self.test_videos_info = None 32 | 33 | self._setup() 34 | 35 | def __call__(self, batch_size, time_steps, interval=1, is_training=True): 36 | if is_training: 37 | train_clips_dict = self.sample_normal_abnormal_clips_scores(self.train_videos_info, time_steps, interval) 38 | val_clips_dict = self.sample_normal_abnormal_clips_scores(self.val_videos_info, time_steps, interval, 39 | max_scores=self.score_max, 40 | min_scores=self.score_min) 41 | dataset = self.convert_to_tf_dataset_training(train_clips_dict, val_clips_dict, batch_size, time_steps) 42 | else: 43 | test_clips_dict = self.sample_normal_abnormal_clips_scores(self.test_videos_info, time_steps, interval) 44 | dataset = self.convert_to_tf_dataset_testing(test_clips_dict, batch_size, time_steps) 45 | return dataset 46 | 47 | def debug(self, batch_size, time_steps, interval=1, is_training=True): 48 | train_clips_dict = self.sample_normal_abnormal_clips_scores(self.train_videos_info, time_steps, interval) 49 | val_clips_dict = self.sample_normal_abnormal_clips_scores(self.val_videos_info, time_steps, interval) 50 | 51 | train_normal_clips = train_clips_dict['normal'] 52 | train_normal_number_list = train_clips_dict['normal_numbers'] 53 | train_normal_scores = train_clips_dict['normal_scores'] 54 | 55 | val_abnormal_clips = val_clips_dict['abnormal'] 56 | val_abnormal_scores = val_clips_dict['abnormal_scores'] 57 | val_abnormal_number_list = val_clips_dict['abnormal_numbers'] 58 | 59 | val_normal_clips = val_clips_dict['normal'] 60 | val_normal_scores = val_clips_dict['normal_scores'] 61 | val_normal_number_list = val_clips_dict['normal_numbers'] 62 | 63 | num_train_normal_videos = len(train_normal_clips) 64 | num_val_normal_videos = len(val_normal_clips) 65 | num_val_abnormal_videos = len(val_abnormal_clips) 66 | 67 | height, width = self.resize_height, self.resize_width 68 | 69 | def _sample_video_clip_score(clips_list, length, scores_list): 70 | clip = np.empty(shape=[time_steps, height, width, 3], dtype=np.float32) 71 | 72 | sample_idx = RNG.randint(length) 73 | for t, filename in enumerate(clips_list[sample_idx]): 74 | clip[t, ...] = load_frame(filename, height, width) 75 | return clip, np.array(scores_list[sample_idx]) 76 | 77 | def video_clip_generator(): 78 | while True: 79 | batch_video_clips = [] 80 | batch_video_scores = [] 81 | 82 | for i in range(batch_size): 83 | a_vid = RNG.randint(num_train_normal_videos) 84 | p_vid = RNG.randint(num_val_normal_videos) 85 | n_vid = RNG.randint(num_val_abnormal_videos) 86 | 87 | video_clips = np.empty(shape=[3, time_steps, height, width, 3], dtype=np.float32) 88 | video_scores = np.empty(shape=[3, time_steps], dtype=np.float32) 89 | video_clips[0, ...], video_scores[0, ...] = _sample_video_clip_score(train_normal_clips[a_vid], 90 | train_normal_number_list[a_vid], 91 | train_normal_scores[a_vid]) 92 | video_clips[1, ...], video_scores[1, ...] = _sample_video_clip_score(val_normal_clips[p_vid], 93 | val_normal_number_list[p_vid], 94 | val_normal_scores[p_vid]) 95 | video_clips[2, ...], video_scores[2, ...] = _sample_video_clip_score(val_abnormal_clips[n_vid], 96 | val_abnormal_number_list[n_vid], 97 | val_abnormal_scores[n_vid]) 98 | 99 | batch_video_clips.append(video_clips) 100 | batch_video_scores.append(video_scores) 101 | 102 | batch_video_clips = np.stack(batch_video_clips, axis=0) 103 | batch_video_scores = np.stack(batch_video_scores, axis=0) 104 | 105 | yield batch_video_clips, batch_video_scores 106 | 107 | return video_clip_generator 108 | 109 | def convert_to_tf_dataset_training(self, train_clips_dict, val_clips_dict, batch_size, time_steps): 110 | train_normal_clips = train_clips_dict['normal'] 111 | train_normal_number_list = train_clips_dict['normal_numbers'] 112 | train_normal_scores = train_clips_dict['normal_scores'] 113 | 114 | val_abnormal_clips = val_clips_dict['abnormal'] 115 | val_abnormal_scores = val_clips_dict['abnormal_scores'] 116 | val_abnormal_number_list = val_clips_dict['abnormal_numbers'] 117 | 118 | val_normal_clips = val_clips_dict['normal'] 119 | val_normal_scores = val_clips_dict['normal_scores'] 120 | val_normal_number_list = val_clips_dict['normal_numbers'] 121 | 122 | num_train_normal_videos = len(train_normal_clips) 123 | num_val_normal_videos = len(val_normal_clips) 124 | num_val_abnormal_videos = len(val_abnormal_clips) 125 | 126 | height, width = self.resize_height, self.resize_width 127 | 128 | def _sample_video_clip_score(clips_list, length, scores_list): 129 | clip = np.empty(shape=[time_steps, height, width, 3], dtype=np.float32) 130 | 131 | sample_idx = RNG.randint(length) 132 | for t, filename in enumerate(clips_list[sample_idx]): 133 | clip[t, ...] = load_frame(filename, height, width) 134 | return clip, np.array(scores_list[sample_idx]) 135 | 136 | def video_clip_generator(): 137 | while True: 138 | a_vid = RNG.randint(num_train_normal_videos) 139 | p_vid = RNG.randint(num_val_normal_videos) 140 | n_vid = RNG.randint(num_val_abnormal_videos) 141 | 142 | video_clips = np.empty(shape=[3, time_steps, height, width, 3], dtype=np.float32) 143 | video_scores = np.empty(shape=[3, time_steps], dtype=np.float32) 144 | video_clips[0, ...], video_scores[0, ...] = _sample_video_clip_score(train_normal_clips[a_vid], 145 | train_normal_number_list[a_vid], 146 | train_normal_scores[a_vid]) 147 | video_clips[1, ...], video_scores[1, ...] = _sample_video_clip_score(val_normal_clips[p_vid], 148 | val_normal_number_list[p_vid], 149 | val_normal_scores[p_vid]) 150 | video_clips[2, ...], video_scores[2, ...] = _sample_video_clip_score(val_abnormal_clips[n_vid], 151 | val_abnormal_number_list[n_vid], 152 | val_abnormal_scores[n_vid]) 153 | 154 | # show_triplet_video_clips(video_clips) 155 | yield video_clips, video_scores 156 | 157 | # video clip paths 158 | dataset = tf.data.Dataset.from_generator(generator=video_clip_generator, 159 | output_types=(tf.float32, tf.float32), 160 | output_shapes=([3, time_steps, height, width, 3], [3, time_steps])) 161 | dataset = dataset.prefetch(buffer_size=128) 162 | dataset = dataset.shuffle(buffer_size=128).batch(batch_size) 163 | return dataset 164 | 165 | def convert_to_tf_dataset_testing(self, val_clips_dict, batch_size, time_steps): 166 | val_abnormal_clips = val_clips_dict['abnormal'] 167 | val_abnormal_scores = val_clips_dict['abnormal_scores'] 168 | val_abnormal_number_list = val_clips_dict['abnormal_numbers'] 169 | 170 | val_normal_clips = val_clips_dict['normal'] 171 | val_normal_scores = val_clips_dict['normal_scores'] 172 | val_normal_number_list = val_clips_dict['normal_numbers'] 173 | 174 | num_val_normal_videos = len(val_normal_clips) 175 | num_val_abnormal_videos = len(val_abnormal_clips) 176 | 177 | height, width = self.resize_height, self.resize_width 178 | 179 | def _sample_video_clip_score(clips_list, length, scores_list): 180 | clip = np.empty(shape=[time_steps, height, width, 3], dtype=np.float32) 181 | 182 | sample_idx = RNG.randint(length) 183 | for t, filename in enumerate(clips_list[sample_idx]): 184 | clip[t, ...] = load_frame(filename, height, width) 185 | return clip, np.array(scores_list[sample_idx]) 186 | 187 | def video_clip_generator(): 188 | while True: 189 | a_vid, p_vid = RNG.choice(num_val_normal_videos, size=2, replace=False) 190 | n_vid = RNG.randint(num_val_abnormal_videos) 191 | 192 | video_clips = np.empty(shape=[3, time_steps, height, width, 3], dtype=np.float32) 193 | video_scores = np.empty(shape=[3, time_steps], dtype=np.float32) 194 | video_clips[0, ...], video_scores[0, ...] = _sample_video_clip_score(val_normal_clips[a_vid], 195 | val_normal_number_list[a_vid], 196 | val_normal_scores[a_vid]) 197 | video_clips[1, ...], video_scores[1, ...] = _sample_video_clip_score(val_normal_clips[p_vid], 198 | val_normal_number_list[p_vid], 199 | val_normal_scores[p_vid]) 200 | video_clips[2, ...], video_scores[2, ...] = _sample_video_clip_score(val_abnormal_clips[n_vid], 201 | val_abnormal_number_list[n_vid], 202 | val_abnormal_scores[n_vid]) 203 | 204 | # show_triplet_video_clips(video_clips) 205 | yield video_clips, video_scores 206 | 207 | # video clip paths 208 | dataset = tf.data.Dataset.from_generator(generator=video_clip_generator, 209 | output_types=(tf.float32, tf.float32), 210 | output_shapes=([3, time_steps, height, width, 3], [3, time_steps])) 211 | dataset = dataset.prefetch(buffer_size=128) 212 | dataset = dataset.shuffle(buffer_size=128).batch(batch_size) 213 | return dataset 214 | 215 | def _setup(self): 216 | videos_info = self.parse_videos_folder(self.test_folder) 217 | frame_mask = self.load_frame_mask(videos_info, self.frame_mask_file) 218 | pixel_mask_list = self.load_pixel_mask_file_list(videos_info, self.pixel_mask_file) 219 | scores = self.load_frame_scores(self.psnr_file) 220 | 221 | num_videos = len(videos_info) 222 | k_folds_ids = np.array_split(np.arange(num_videos), self.k_folds) 223 | 224 | val_ids = k_folds_ids[self.kth - 1].tolist() 225 | test_ids = [] 226 | for k in range(self.k_folds): 227 | if k != (self.kth - 1): 228 | test_ids += k_folds_ids[k].tolist() 229 | 230 | def add_gt_to_info(ids): 231 | videos_info_gt = OrderedDict() 232 | videos_names = list(videos_info.keys()) 233 | 234 | for i in ids: 235 | v_name = videos_names[i] 236 | pixel_mask = pixel_mask_list[i] if pixel_mask_list else [] 237 | 238 | videos_info_gt[v_name] = { 239 | 'length': videos_info[v_name]['length'], 240 | 'images': videos_info[v_name]['images'], 241 | 'frame_mask': frame_mask[i], 242 | 'pixel_mask': pixel_mask, 243 | 'scores': scores[i] 244 | } 245 | 246 | return videos_info_gt 247 | 248 | val_videos_info = add_gt_to_info(val_ids) 249 | test_videos_info = add_gt_to_info(test_ids) 250 | 251 | del videos_info 252 | 253 | self.val_videos_info = val_videos_info 254 | self.test_videos_info = test_videos_info 255 | 256 | self.train_videos_info = self.parse_videos_folder(self.folder) 257 | 258 | def get_video_clip(self, video, start, end, interval=1): 259 | # assert video_name in self._videos_info, 'video {} is not in {}!'.format(video_name, self._videos_info.keys()) 260 | # assert (start >= 0) and (start <= end) and (end < self._videos_info[video_name]['length']) 261 | 262 | video_idx = np.arange(start, end, interval) 263 | video_clip = np.empty(shape=[len(video_idx), self.resize_height, self.resize_width, 3], dtype=np.float32) 264 | for idx, v_idx in enumerate(video_idx): 265 | filename = self.test_videos_info[video]['images'][v_idx] 266 | video_clip[idx, ...] = load_frame(filename, self.resize_height, self.resize_width) 267 | 268 | return video_clip 269 | 270 | def get_video_names(self): 271 | return list(self.test_videos_info) 272 | 273 | @LazyProperty 274 | def get_frame_mask(self): 275 | frame_mask = [] 276 | for video_info in self.test_videos_info.values(): 277 | frame_mask.append(video_info['frame_mask']) 278 | 279 | return frame_mask 280 | 281 | @LazyProperty 282 | def get_total_frames(self): 283 | total = 0 284 | for video_info in self.test_videos_info.values(): 285 | total += video_info['length'] 286 | return total 287 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | 7 | def load_frame(filename, height, width): 8 | image_decoded = cv2.imread(filename) 9 | image_decoded = cv2.cvtColor(image_decoded, cv2.COLOR_BGR2RGB) 10 | image_resized = cv2.resize(image_decoded, (width, height)) 11 | image_resized = image_resized.astype(dtype=np.float32) 12 | image_resized = (image_resized / 127.5) - 1.0 13 | return image_resized 14 | 15 | 16 | def inverse_transform(images): 17 | return (images + 1.) / 2. 18 | 19 | 20 | def multi_scale_crop_load_frame(filename, height, width, bbox): 21 | image = cv2.imread(filename) 22 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 23 | image = cv2.resize(image, (width, height)) 24 | image = image.astype(dtype=np.float32) 25 | image = (image / 127.5) - 1.0 26 | 27 | return image[bbox[0][0]: bbox[0][1], bbox[1][0]:bbox[1][1], :] 28 | 29 | 30 | def log10(t): 31 | """ 32 | Calculates the base-10 log of each element in t. 33 | 34 | @param t: The tensor from which to calculate the base-10 log. 35 | 36 | @return: A tensor with the base-10 log of each element in t. 37 | """ 38 | 39 | numerator = tf.log(t) 40 | denominator = tf.log(tf.constant(10, dtype=numerator.dtype)) 41 | return numerator / denominator 42 | 43 | 44 | def psnr_error(gen_frames, gt_frames): 45 | """ 46 | Computes the Peak Signal to Noise Ratio error between the generated images and the ground 47 | truth images. 48 | 49 | @param gen_frames: A tensor of shape [batch_size, height, width, 3]. The frames generated by the 50 | generator model. 51 | @param gt_frames: A tensor of shape [batch_size, height, width, 3]. The ground-truth frames for 52 | each frame in gen_frames. 53 | 54 | @return: A scalar tensor. The mean Peak Signal to Noise Ratio error over each frame in the 55 | batch. 56 | """ 57 | shape = tf.shape(gen_frames) 58 | num_pixels = tf.to_float(shape[1] * shape[2] * shape[3]) 59 | # gt_frames = (gt_frames + 1.0) / 2.0 60 | # gen_frames = (gen_frames + 1.0) / 2.0 61 | square_diff = tf.square((gt_frames - gen_frames) / 2.0) 62 | 63 | batch_errors = 10 * log10(1 / ((1 / num_pixels) * tf.reduce_sum(square_diff, [1, 2, 3]))) 64 | return tf.reduce_mean(batch_errors) 65 | 66 | 67 | def diff_square_mask(gen_frames, gt_frames): 68 | square_diff = tf.square((gt_frames - gen_frames) / 2.0) 69 | return square_diff 70 | 71 | 72 | def diff_gray_mask(gen_frames, gt_frames, min_value=-1, max_value=1): 73 | # normalize to [0, 1] 74 | delta = max_value - min_value 75 | gen_frames = (gen_frames - min_value) / delta 76 | gt_frames = (gt_frames - min_value) / delta 77 | 78 | gen_gray_frames = tf.image.rgb_to_grayscale(gen_frames) 79 | gt_gray_frames = tf.image.rgb_to_grayscale(gt_frames) 80 | 81 | diff = tf.abs(gen_gray_frames - gt_gray_frames) 82 | return diff 83 | 84 | 85 | def load(saver, sess, ckpt_path): 86 | saver.restore(sess, ckpt_path) 87 | print("Restored model parameters from {}".format(ckpt_path)) 88 | 89 | 90 | def save(saver, sess, logdir, step): 91 | model_name = 'model.ckpt' 92 | checkpoint_path = os.path.join(logdir, model_name) 93 | if not os.path.exists(logdir): 94 | os.makedirs(logdir) 95 | saver.save(sess, checkpoint_path, global_step=step) 96 | print('The checkpoint has been created.') --------------------------------------------------------------------------------