├── .gitignore ├── LICENSE ├── README.md ├── _init_paths.py ├── configs ├── __init__.py ├── base_configs.py ├── build.py ├── hdd │ ├── __init__.py │ └── trn_hdd.py └── thumos │ ├── __init__.py │ └── trn_thumos.py ├── data └── data_info.json ├── demo └── network.jpg ├── lib ├── datasets │ ├── __init__.py │ ├── datasets.py │ ├── hdd_data_layer.py │ └── thumos_data_layer.py ├── models │ ├── __init__.py │ ├── feature_extractor.py │ ├── generalized_trn.py │ └── models.py └── utils │ ├── __init__.py │ ├── eval_utils.py │ ├── logger.py │ ├── multicrossentropy_loss.py │ └── net_utils.py └── tools ├── trn_hdd ├── eval.py └── train.py └── trn_thumos ├── eval.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | */*.pth 2 | */**/*.pth 3 | */**/**/*.pth 4 | */**/**/**/*.pth 5 | */*.pyc 6 | */**/*.pyc 7 | */**/**/*.pyc 8 | */**/**/**/*.pyc 9 | */*.swp 10 | */**/*.swp 11 | */**/**/*.swp 12 | */**/**/**/*.swp 13 | */*.txt 14 | */**/*.txt 15 | */**/**/*.txt 16 | */**/**/**/*.txt 17 | 18 | data/HDD 19 | data/THUMOS 20 | 21 | model_zoo/* 22 | 23 | tools/*.json 24 | tools/**/*.json 25 | tools/**/**/*.json 26 | tools/**/**/**/*.json 27 | 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Mingze Xu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Temporal Recurrent Networks for Online Action Detection 2 | 3 | ## Updates 4 | 5 | :boom: **November 18th 2021**: The code of [`Long Short-Term Transformer (LSTR)`](https://arxiv.org/pdf/2107.03377.pdf) is released [`here`](https://github.com/amazon-research/long-short-term-transformer). 6 | 7 | :boom: **July 08th 2021**: We are releasing [`Long Short-Term Transformer (LSTR)`](https://arxiv.org/pdf/2107.03377.pdf), a more effective and efficient method for modeling prolonged sequence data! [`LSTR`](https://arxiv.org/pdf/2107.03377.pdf) achieves SoTA on Online Action Detection benchmarks. 8 | 9 | :boom: **May 25th 2021**: For future comparison with TRN using Kinetics pretrained features, we report our results on THUMOS as 62.1% in mAP, on TVSeries as 86.2% in cAP, and on HACS Segment as 78.9% in mAP. 
10 | 11 | For feature encoding, we use the [`ResNet-50`](https://arxiv.org/pdf/1512.03385.pdf) model for the RGB input, and the [`BN-Inception`](https://arxiv.org/pdf/1502.03167.pdf) model for the optical flow input. To replicate our results, please use the pretrained weights of ResNet-50 in [`MMAction2`](https://github.com/open-mmlab/mmaction2/blob/master/configs/recognition/tsn/README.md#kinetics-400) and BN-Inception in this [`repo`](http://yjxiong.me/others/kinetics_action/). 12 | 13 | ## Introduction 14 | 15 | This is a PyTorch **reimplementation** of our ICCV 2019 paper "[`Temporal Recurrent Networks for Online Action Detection`](https://arxiv.org/pdf/1811.07391.pdf)". 16 | 17 | ![network](demo/network.jpg?raw=true) 18 | 19 | ## Environment 20 | 21 | - The code is developed with CUDA 9.0, ***Python >= 3.6***, ***PyTorch >= 1.0*** 22 | 23 | ## Data Preparation 24 | 25 | #### Option 1: Prepare the features and targets yourself. 26 | 27 | 1. Download the [`HDD`](https://usa.honda-ri.com/hdd) and [`THUMOS'14`](https://www.crcv.ucf.edu/THUMOS14/) datasets. 28 | 29 | 2. Extract feature representations for video frames. 30 | 31 | * For the HDD dataset, we use the [`Inception-ResNet-V2`](https://arxiv.org/pdf/1602.07261.pdf) model pretrained on ImageNet for the RGB input. 32 | 33 | * For the THUMOS'14 dataset, we use the [`ResNet-200`](https://arxiv.org/pdf/1512.03385.pdf) model for the RGB input, and the [`BN-Inception`](https://arxiv.org/pdf/1502.03167.pdf) model for the optical flow input. To replicate our results, please follow this repo: [`https://github.com/yjxiong/anet2016-cuhk`](https://github.com/yjxiong/anet2016-cuhk). 34 | 35 | ***Note:*** We compute the optical flow for the THUMOS'14 dataset using [`FlowNet2.0`](https://arxiv.org/pdf/1612.01925.pdf). 36 | 37 | 3. If you want to use our [dataloaders](./lib/datasets), please make sure the files are organized in the following structure: 38 | 39 | * HDD dataset: 40 | ``` 41 | $YOUR_PATH_TO_HDD_DATASET 42 | ├── inceptionresnetv2/ 43 | | ├── 201702271017.npy (of size L x 1536 x 8 x 8) 44 | │   ├── ... 45 | ├── sensor/ 46 | | ├── 201702271017.npy (of size L x 8) 47 | | ├── ... 48 | ├── target/ 49 | | ├── 201702271017.npy (of size L) 50 | | ├── ... 51 | ``` 52 | 53 | * THUMOS'14 dataset: 54 | ``` 55 | $YOUR_PATH_TO_THUMOS_DATASET 56 | ├── resnet200-fc/ 57 | | ├── video_validation_0000051.npy (of size L x 2048) 58 | │   ├── ... 59 | ├── bn_inception/ 60 | | ├── video_validation_0000051.npy (of size L x 1024) 61 | | ├── ... 62 | ├── target/ 63 | | ├── video_validation_0000051.npy (of size L x 22) 64 | | ├── ... 65 | ``` 66 | 67 | #### Option 2: Directly download the pre-extracted features and targets from TeSTra. 68 | 69 | You can skip steps 1-3 above and directly use the pre-extracted features and targets from [TeSTra](https://github.com/zhaoyue-zephyrus/TeSTra). They exactly follow our data structure and should be able to reproduce TRN's performance. However, if you have any questions about the processing of these features and targets, please contact the authors of TeSTra directly. 70 | 71 | 4. 
Create softlinks of datasets: 72 | ``` 73 | cd TRN.pytorch 74 | ln -s $YOUR_PATH_TO_HDD_DATASET data/HDD 75 | ln -s $YOUR_PATH_TO_THUMOS_DATASET data/THUMOS 76 | ``` 77 | 78 | ## Training 79 | 80 | * Single GPU training on HDD dataset: 81 | ``` 82 | cd TRN.pytorch 83 | # Training from scratch 84 | python tools/trn_hdd/train.py --gpu $CUDA_VISIBLE_DEVICES 85 | # Finetuning from a pretrained model 86 | python tools/trn_hdd/train.py --checkpoint $PATH_TO_CHECKPOINT --gpu $CUDA_VISIBLE_DEVICES 87 | ``` 88 | 89 | * Multi-GPU training on HDD dataset: 90 | ``` 91 | cd TRN.pytorch 92 | # Training from scratch 93 | python tools/trn_hdd/train.py --gpu $CUDA_VISIBLE_DEVICES --distributed 94 | # Finetuning from a pretrained model 95 | python tools/trn_hdd/train.py --checkpoint $PATH_TO_CHECKPOINT --gpu $CUDA_VISIBLE_DEVICES --distributed 96 | ``` 97 | 98 | * Single GPU training on THUMOS'14 dataset: 99 | ``` 100 | cd TRN.pytorch 101 | # Training from scratch 102 | python tools/trn_thumos/train.py --gpu $CUDA_VISIBLE_DEVICES 103 | # Finetuning from a pretrained model 104 | python tools/trn_thumos/train.py --checkpoint $PATH_TO_CHECKPOINT --gpu $CUDA_VISIBLE_DEVICES 105 | ``` 106 | 107 | * Multi-GPU training on THUMOS'14 dataset: 108 | ``` 109 | cd TRN.pytorch 110 | # Training from scratch 111 | python tools/trn_thumos/train.py --gpu $CUDA_VISIBLE_DEVICES --distributed 112 | # Finetuning from a pretrained model 113 | python tools/trn_thumos/train.py --checkpoint $PATH_TO_CHECKPOINT --gpu $CUDA_VISIBLE_DEVICES --distributed 114 | ``` 115 | 116 | ## Evaluation 117 | 118 | * HDD dataset: 119 | ``` 120 | cd TRN.pytorch 121 | python tools/trn_hdd/eval.py --checkpoint $PATH_TO_CHECKPOINT --gpu $CUDA_VISIBLE_DEVICES 122 | ``` 123 | 124 | * THUMOS'14 dataset: 125 | ``` 126 | cd TRN.pytorch 127 | python tools/trn_thumos/eval.py --checkpoint $PATH_TO_CHECKPOINT --gpu $CUDA_VISIBLE_DEVICES 128 | ``` 129 | 130 | ***NOTE:*** There are two kinds of evaluation methods in our code. (1) Using `--debug` during training considers each short video clip (consisting of 90 and 64 consecutive frames for HDD and THUMOS'14 datasets, respectively) as one test sample, and separately runs inference and evaluates on all short video clips (even though some of them are from the same long video). (2) Using `eval.py` after training runs inference and evaluates on long videos (frame by frame, from the beginning to the end), which is the evaluation method we reported in the paper. 131 | 132 | ## Citations 133 | 134 | If you are using the data/code/model provided here in a publication, please cite our paper: 135 | 136 | @inproceedings{onlineaction2019iccv, 137 | title = {Temporal Recurrent Networks for Online Action Detection}, 138 | author = {Mingze Xu and Mingfei Gao and Yi-Ting Chen and Larry S. Davis and David J. 
Crandall}, 139 | booktitle = {IEEE International Conference on Computer Vision (ICCV)}, 140 | year = {2019} 141 | } 142 | -------------------------------------------------------------------------------- /_init_paths.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import sys 3 | 4 | def add_path(path): 5 | if path not in sys.path: 6 | sys.path.insert(0, path) 7 | 8 | this_dir = osp.dirname(__file__) 9 | 10 | # Add lib to PYTHONPATH 11 | lib_path = osp.join(this_dir, 'lib') 12 | add_path(lib_path) 13 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_configs import * 2 | from .build import * 3 | -------------------------------------------------------------------------------- /configs/base_configs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | __all__ = ['parse_base_args'] 4 | 5 | def parse_base_args(): 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--data_info', default='data/data_info.json', type=str) 8 | parser.add_argument('--checkpoint', default='', type=str) 9 | parser.add_argument('--start_epoch', default=1, type=int) 10 | parser.add_argument('--verbose', action='store_true') 11 | parser.add_argument('--debug', action='store_true') 12 | parser.add_argument('--distributed', action='store_true') 13 | parser.add_argument('--gpu', default='0', type=str) 14 | parser.add_argument('--num_workers', default=4, type=int) 15 | parser.add_argument('--epochs', default=21, type=int) 16 | parser.add_argument('--batch_size', default=32, type=int) 17 | parser.add_argument('--lr', default=5e-04, type=float) 18 | parser.add_argument('--weight_decay', default=5e-04, type=float) 19 | parser.add_argument('--seed', default=25, type=int) 20 | parser.add_argument('--phases', default=['train', 'test'], type=list) 21 | return parser 22 | -------------------------------------------------------------------------------- /configs/build.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import json 3 | 4 | __all__ = ['build_data_info'] 5 | 6 | def build_data_info(args): 7 | args.dataset = osp.basename(osp.normpath(args.data_root)) 8 | with open(args.data_info, 'r') as f: 9 | data_info = json.load(f)[args.dataset] 10 | args.train_session_set = data_info['train_session_set'] 11 | args.test_session_set = data_info['test_session_set'] 12 | args.class_index = data_info['class_index'] 13 | args.num_classes = len(args.class_index) 14 | return args 15 | -------------------------------------------------------------------------------- /configs/hdd/__init__.py: -------------------------------------------------------------------------------- 1 | from .trn_hdd import * 2 | -------------------------------------------------------------------------------- /configs/hdd/trn_hdd.py: -------------------------------------------------------------------------------- 1 | from configs import parse_base_args, build_data_info 2 | 3 | __all__ = ['parse_trn_args'] 4 | 5 | def parse_trn_args(): 6 | parser = parse_base_args() 7 | parser.add_argument('--data_root', default='data/HDD', type=str) 8 | parser.add_argument('--model', default='TRN', type=str) 9 | parser.add_argument('--inputs', default='multimodal', type=str) 10 | parser.add_argument('--hidden_size', default=2000, type=int) 11 | 
parser.add_argument('--camera_feature', default='inceptionresnetv2', type=str) 12 | parser.add_argument('--enc_steps', default=90, type=int) 13 | parser.add_argument('--dec_steps', default=6, type=int) 14 | parser.add_argument('--dropout', default=0.1, type=float) 15 | return build_data_info(parser.parse_args()) 16 | -------------------------------------------------------------------------------- /configs/thumos/__init__.py: -------------------------------------------------------------------------------- 1 | from .trn_thumos import * 2 | -------------------------------------------------------------------------------- /configs/thumos/trn_thumos.py: -------------------------------------------------------------------------------- 1 | from configs import parse_base_args, build_data_info 2 | 3 | __all__ = ['parse_trn_args'] 4 | 5 | def parse_trn_args(): 6 | parser = parse_base_args() 7 | parser.add_argument('--data_root', default='data/THUMOS', type=str) 8 | parser.add_argument('--model', default='TRN', type=str) 9 | parser.add_argument('--inputs', default='multistream', type=str) 10 | parser.add_argument('--hidden_size', default=4096, type=int) 11 | parser.add_argument('--camera_feature', default='resnet200-fc', type=str) 12 | parser.add_argument('--motion_feature', default='bn_inception', type=str) 13 | parser.add_argument('--enc_steps', default=64, type=int) 14 | parser.add_argument('--dec_steps', default=8, type=int) 15 | parser.add_argument('--dropout', default=0.1, type=float) 16 | return build_data_info(parser.parse_args()) 17 | -------------------------------------------------------------------------------- /data/data_info.json: -------------------------------------------------------------------------------- 1 | {"HDD": {"class_index": ["background", "intersection passing", "left turn", "right turn", "left lane change", "right lane change", "left lane branch", "right lane branch", "crosswalk passing", "railroad passing", "merge", "U-turn"], "train_session_set": ["201702271017", "201702271123", "201702271136", "201702271438", "201702271632", "201702281017", "201702281511", "201702281709", "201703011016", "201703061033", "201703061107", "201703061323", "201703061353", "201703061418", "201703061429", "201703061456", "201703061519", "201703061541", "201703061606", "201703061635", "201703061700", "201703061725", "201703080946", "201703081008", "201703081055", "201703081152", "201703081407", "201703081437", "201703081509", "201703081549", "201703081617", "201703081653", "201703081723", "201703081749", "201704101354", "201704101504", "201704101624", "201704101658", "201704110943", "201704111011", "201704111041", "201704111138", "201704111202", "201704111315", "201704111335", "201704111402", "201704111412", "201704111540", "201706061021", "201706070945", "201706071021", "201706071319", "201706071458", "201706071518", "201706071532", "201706071602", "201706071620", "201706071630", "201706071658", "201706071735", "201706071752", "201706080945", "201706081335", "201706081445", "201706081626", "201706081707", "201706130952", "201706131127", "201706131318", "201706141033", "201706141147", "201706141538", "201706141720", "201706141819", "201709200946", "201709201027", "201709201221", "201709201319", "201709201530", "201709201605", "201709201700", "201709210940", "201709211047", "201709211317", "201709211444", "201709211547", "201709220932", "201709221037", "201709221238", "201709221313", "201709221435", "201709221527", "201710031224", "201710031247", "201710031436", "201710040938", "201710060950", 
"201710061114", "201710061311", "201710061345"], "test_session_set": ["201704101118", "201704130952", "201704131020", "201704131047", "201704131123", "201704131537", "201704131634", "201704131655", "201704140944", "201704141033", "201704141055", "201704141117", "201704141145", "201704141243", "201704141420", "201704141608", "201704141639", "201704141725", "201704150933", "201704151035", "201704151103", "201704151140", "201704151315", "201704151347", "201704151502", "201706061140", "201706061309", "201706061536", "201706061647", "201706140912", "201710031458", "201710031645", "201710041102", "201710041209", "201710041351", "201710041448"]}, "TVSeries": {"class_index": ["Pick something up", "Point", "Drink", "Stand up", "Run", "Sit down", "Read", "Smoke", "Drive car", "Open door", "Give something", "Use computer", "Write", "Go down stairway", "Close door", "Throw something", "Go up stairway", "Get in/out of car", "Hang up phone", "Eat", "Answer phone", "Dress up", "Clap", "Undress", "Kiss", "Fall/trip", "Wave", "Pour", "Punch", "Fire weapon"], "train_session_set": ["24_ep1", "24_ep2", "24_ep3", "Breaking_Bad_ep1", "Breaking_Bad_ep2", "How_I_Met_Your_Mother_ep1", "How_I_Met_Your_Mother_ep2", "How_I_Met_Your_Mother_ep3", "How_I_Met_Your_Mother_ep4", "How_I_Met_Your_Mother_ep5", "How_I_Met_Your_Mother_ep6", "Mad_Men_ep1", "Mad_Men_ep2", "Modern_Family_ep1", "Modern_Family_ep2", "Modern_Family_ep3", "Modern_Family_ep4", "Modern_Family_ep6", "Sons_of_Anarchy_ep1", "Sons_of_Anarchy_ep2"], "test_session_set": ["24_ep4", "Breaking_Bad_ep3", "Mad_Men_ep3", "How_I_Met_Your_Mother_ep7", "How_I_Met_Your_Mother_ep8", "Modern_Family_ep5", "Sons_of_Anarchy_ep3"]}, "THUMOS": {"class_index": ["Background", "BaseballPitch", "BasketballDunk", "Billiards", "CleanAndJerk", "CliffDiving", "CricketBowling", "CricketShot", "Diving", "FrisbeeCatch", "GolfSwing", "HammerThrow", "HighJump", "JavelinThrow", "LongJump", "PoleVault", "Shotput", "SoccerPenalty", "TennisSwing", "ThrowDiscus", "VolleyballSpiking", "Ambiguous"], "train_session_set": ["video_validation_0000690", "video_validation_0000288", "video_validation_0000289", "video_validation_0000416", "video_validation_0000282", "video_validation_0000283", "video_validation_0000281", "video_validation_0000286", "video_validation_0000287", "video_validation_0000284", "video_validation_0000285", "video_validation_0000202", "video_validation_0000203", "video_validation_0000201", "video_validation_0000206", "video_validation_0000207", "video_validation_0000204", "video_validation_0000205", "video_validation_0000790", "video_validation_0000208", "video_validation_0000209", "video_validation_0000420", "video_validation_0000364", "video_validation_0000853", "video_validation_0000950", "video_validation_0000937", "video_validation_0000367", "video_validation_0000290", "video_validation_0000210", "video_validation_0000059", "video_validation_0000058", "video_validation_0000057", "video_validation_0000056", "video_validation_0000055", "video_validation_0000054", "video_validation_0000053", "video_validation_0000052", "video_validation_0000051", "video_validation_0000933", "video_validation_0000949", "video_validation_0000948", "video_validation_0000945", "video_validation_0000944", "video_validation_0000947", "video_validation_0000946", "video_validation_0000941", "video_validation_0000940", "video_validation_0000190", "video_validation_0000942", "video_validation_0000261", "video_validation_0000262", "video_validation_0000263", "video_validation_0000264", 
"video_validation_0000265", "video_validation_0000266", "video_validation_0000267", "video_validation_0000268", "video_validation_0000269", "video_validation_0000989", "video_validation_0000060", "video_validation_0000370", "video_validation_0000938", "video_validation_0000935", "video_validation_0000668", "video_validation_0000669", "video_validation_0000664", "video_validation_0000665", "video_validation_0000932", "video_validation_0000667", "video_validation_0000934", "video_validation_0000661", "video_validation_0000662", "video_validation_0000663", "video_validation_0000181", "video_validation_0000180", "video_validation_0000183", "video_validation_0000182", "video_validation_0000185", "video_validation_0000184", "video_validation_0000187", "video_validation_0000186", "video_validation_0000189", "video_validation_0000188", "video_validation_0000936", "video_validation_0000270", "video_validation_0000854", "video_validation_0000178", "video_validation_0000179", "video_validation_0000174", "video_validation_0000175", "video_validation_0000176", "video_validation_0000177", "video_validation_0000170", "video_validation_0000171", "video_validation_0000172", "video_validation_0000173", "video_validation_0000670", "video_validation_0000419", "video_validation_0000943", "video_validation_0000485", "video_validation_0000369", "video_validation_0000368", "video_validation_0000318", "video_validation_0000319", "video_validation_0000415", "video_validation_0000414", "video_validation_0000413", "video_validation_0000412", "video_validation_0000411", "video_validation_0000311", "video_validation_0000312", "video_validation_0000313", "video_validation_0000314", "video_validation_0000315", "video_validation_0000316", "video_validation_0000317", "video_validation_0000418", "video_validation_0000365", "video_validation_0000482", "video_validation_0000169", "video_validation_0000168", "video_validation_0000167", "video_validation_0000166", "video_validation_0000165", "video_validation_0000164", "video_validation_0000163", "video_validation_0000162", "video_validation_0000161", "video_validation_0000160", "video_validation_0000857", "video_validation_0000856", "video_validation_0000855", "video_validation_0000366", "video_validation_0000488", "video_validation_0000489", "video_validation_0000851", "video_validation_0000484", "video_validation_0000361", "video_validation_0000486", "video_validation_0000487", "video_validation_0000481", "video_validation_0000910", "video_validation_0000483", "video_validation_0000363", "video_validation_0000990", "video_validation_0000939", "video_validation_0000362", "video_validation_0000987", "video_validation_0000859", "video_validation_0000787", "video_validation_0000786", "video_validation_0000785", "video_validation_0000784", "video_validation_0000783", "video_validation_0000782", "video_validation_0000781", "video_validation_0000981", "video_validation_0000983", "video_validation_0000982", "video_validation_0000985", "video_validation_0000984", "video_validation_0000417", "video_validation_0000788", "video_validation_0000152", "video_validation_0000153", "video_validation_0000151", "video_validation_0000156", "video_validation_0000157", "video_validation_0000154", "video_validation_0000155", "video_validation_0000158", "video_validation_0000159", "video_validation_0000901", "video_validation_0000903", "video_validation_0000902", "video_validation_0000905", "video_validation_0000904", "video_validation_0000907", "video_validation_0000906", 
"video_validation_0000909", "video_validation_0000908", "video_validation_0000490", "video_validation_0000860", "video_validation_0000858", "video_validation_0000988", "video_validation_0000320", "video_validation_0000688", "video_validation_0000689", "video_validation_0000686", "video_validation_0000687", "video_validation_0000684", "video_validation_0000685", "video_validation_0000682", "video_validation_0000683", "video_validation_0000681", "video_validation_0000789", "video_validation_0000986", "video_validation_0000931", "video_validation_0000852", "video_validation_0000666"], "test_session_set": ["video_test_0000292", "video_test_0001078", "video_test_0000896", "video_test_0000897", "video_test_0000950", "video_test_0001159", "video_test_0001079", "video_test_0000807", "video_test_0000179", "video_test_0000173", "video_test_0001072", "video_test_0001075", "video_test_0000767", "video_test_0001076", "video_test_0000007", "video_test_0000006", "video_test_0000556", "video_test_0001307", "video_test_0001153", "video_test_0000718", "video_test_0000716", "video_test_0001309", "video_test_0000714", "video_test_0000558", "video_test_0001267", "video_test_0000367", "video_test_0001324", "video_test_0000085", "video_test_0000887", "video_test_0001281", "video_test_0000882", "video_test_0000671", "video_test_0000964", "video_test_0001164", "video_test_0001114", "video_test_0000771", "video_test_0001163", "video_test_0001118", "video_test_0001201", "video_test_0001040", "video_test_0001207", "video_test_0000723", "video_test_0000569", "video_test_0000672", "video_test_0000673", "video_test_0000278", "video_test_0001162", "video_test_0000405", "video_test_0000073", "video_test_0000560", "video_test_0001276", "video_test_0000270", "video_test_0000273", "video_test_0000374", "video_test_0000372", "video_test_0001168", "video_test_0000379", "video_test_0001446", "video_test_0001447", "video_test_0001098", "video_test_0000873", "video_test_0000039", "video_test_0000442", "video_test_0001219", "video_test_0000762", "video_test_0000611", "video_test_0000617", "video_test_0000615", "video_test_0001270", "video_test_0000740", "video_test_0000293", "video_test_0000504", "video_test_0000505", "video_test_0000665", "video_test_0000664", "video_test_0000577", "video_test_0000814", "video_test_0001369", "video_test_0001194", "video_test_0001195", "video_test_0001512", "video_test_0001235", "video_test_0001459", "video_test_0000691", "video_test_0000765", "video_test_0001452", "video_test_0000188", "video_test_0000591", "video_test_0001268", "video_test_0000593", "video_test_0000864", "video_test_0000601", "video_test_0001135", "video_test_0000004", "video_test_0000903", "video_test_0000285", "video_test_0001174", "video_test_0000046", "video_test_0000045", "video_test_0001223", "video_test_0001358", "video_test_0001134", "video_test_0000698", "video_test_0000461", "video_test_0001182", "video_test_0000450", "video_test_0000602", "video_test_0001229", "video_test_0000989", "video_test_0000357", "video_test_0001039", "video_test_0000355", "video_test_0000353", "video_test_0001508", "video_test_0000981", "video_test_0000242", "video_test_0000854", "video_test_0001484", "video_test_0000635", "video_test_0001129", "video_test_0001339", "video_test_0001483", "video_test_0001123", "video_test_0001127", "video_test_0000689", "video_test_0000756", "video_test_0001431", "video_test_0000129", "video_test_0001433", "video_test_0001343", "video_test_0000324", "video_test_0001064", "video_test_0001531", 
"video_test_0001532", "video_test_0000413", "video_test_0000991", "video_test_0001255", "video_test_0000464", "video_test_0001202", "video_test_0001080", "video_test_0001081", "video_test_0000847", "video_test_0000028", "video_test_0000844", "video_test_0000622", "video_test_0000026", "video_test_0001325", "video_test_0001496", "video_test_0001495", "video_test_0000624", "video_test_0000724", "video_test_0001409", "video_test_0000131", "video_test_0000448", "video_test_0000444", "video_test_0000443", "video_test_0001038", "video_test_0000238", "video_test_0001527", "video_test_0001522", "video_test_0000051", "video_test_0001058", "video_test_0001391", "video_test_0000429", "video_test_0000426", "video_test_0000785", "video_test_0000786", "video_test_0001314", "video_test_0000392", "video_test_0000423", "video_test_0001146", "video_test_0001313", "video_test_0001008", "video_test_0001247", "video_test_0000737", "video_test_0001319", "video_test_0000308", "video_test_0000730", "video_test_0000058", "video_test_0000538", "video_test_0001556", "video_test_0000113", "video_test_0000626", "video_test_0000839", "video_test_0000220", "video_test_0001389", "video_test_0000437", "video_test_0000940", "video_test_0000211", "video_test_0000946", "video_test_0001558", "video_test_0000796", "video_test_0000062", "video_test_0000793", "video_test_0000987", "video_test_0001066", "video_test_0000412", "video_test_0000798", "video_test_0001549", "video_test_0000011", "video_test_0001257", "video_test_0000541", "video_test_0000701", "video_test_0000250", "video_test_0000254", "video_test_0000549", "video_test_0001209", "video_test_0001463", "video_test_0001460", "video_test_0000319", "video_test_0001468", "video_test_0000846", "video_test_0001292"]}} -------------------------------------------------------------------------------- /demo/network.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xumingze0308/TRN.pytorch/d432c1eddfc02ae4f255bbb8db4acd798c2ebe74/demo/network.jpg -------------------------------------------------------------------------------- /lib/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import build_dataset 2 | -------------------------------------------------------------------------------- /lib/datasets/datasets.py: -------------------------------------------------------------------------------- 1 | from .hdd_data_layer import TRNHDDDataLayer 2 | from .thumos_data_layer import TRNTHUMOSDataLayer 3 | 4 | _DATA_LAYERS = { 5 | 'TRNHDD': TRNHDDDataLayer, 6 | 'TRNTHUMOS': TRNTHUMOSDataLayer, 7 | } 8 | 9 | def build_dataset(args, phase): 10 | data_layer = _DATA_LAYERS[args.model + args.dataset] 11 | return data_layer(args, phase) 12 | -------------------------------------------------------------------------------- /lib/datasets/hdd_data_layer.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import torch 4 | import torch.utils.data as data 5 | import numpy as np 6 | 7 | class TRNHDDDataLayer(data.Dataset): 8 | def __init__(self, args, phase='train'): 9 | self.data_root = args.data_root 10 | self.camera_feature = args.camera_feature 11 | self.sessions = getattr(args, phase+'_session_set') 12 | self.enc_steps = args.enc_steps 13 | self.dec_steps = args.dec_steps 14 | self.training = phase=='train' 15 | 16 | self.inputs = [] 17 | for session in self.sessions: 18 | sensor = 
np.load(osp.join(self.data_root, 'sensor', session+'.npy')) 19 | target = np.load(osp.join(self.data_root, 'target', session+'.npy')) 20 | seed = np.random.randint(self.enc_steps) if self.training else 90 21 | for start, end in zip( 22 | range(seed, target.shape[0] - self.dec_steps, self.enc_steps), 23 | range(seed + self.enc_steps, target.shape[0] - self.dec_steps, self.enc_steps)): 24 | enc_target = target[start:end] 25 | dec_target = self.get_dec_target(target[start:end + self.dec_steps]) 26 | self.inputs.append([ 27 | session, start, end, sensor[start:end], 28 | enc_target, dec_target, 29 | ]) 30 | 31 | def get_dec_target(self, target_vector): 32 | target_matrix = np.zeros((self.enc_steps, self.dec_steps)) 33 | for i in range(self.enc_steps): 34 | for j in range(self.dec_steps): 35 | # 0 -> [1, 2, 3] 36 | # target_matrix[i,j] = target_vector[i+j+1] 37 | # 0 -> [0, 1, 2] 38 | target_matrix[i,j] = target_vector[i+j] 39 | return target_matrix 40 | 41 | def __getitem__(self, index): 42 | session, start, end, sensor_inputs, enc_target, dec_target = self.inputs[index] 43 | 44 | camera_inputs = np.load( 45 | osp.join(self.data_root, self.camera_feature, session+'.npy'), mmap_mode='r')[start:end] 46 | camera_inputs = torch.as_tensor(camera_inputs.astype(np.float32)) 47 | sensor_inputs = torch.as_tensor(sensor_inputs.astype(np.float32)) 48 | enc_target = torch.as_tensor(enc_target.astype(np.int64)) 49 | dec_target = torch.as_tensor(dec_target.astype(np.int64)) 50 | 51 | return camera_inputs, sensor_inputs, enc_target, dec_target.view(-1) 52 | 53 | def __len__(self): 54 | return len(self.inputs) 55 | -------------------------------------------------------------------------------- /lib/datasets/thumos_data_layer.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import torch 4 | import torch.utils.data as data 5 | import numpy as np 6 | 7 | class TRNTHUMOSDataLayer(data.Dataset): 8 | def __init__(self, args, phase='train'): 9 | self.data_root = args.data_root 10 | self.camera_feature = args.camera_feature 11 | self.motion_feature = args.motion_feature 12 | self.sessions = getattr(args, phase+'_session_set') 13 | self.enc_steps = args.enc_steps 14 | self.dec_steps = args.dec_steps 15 | self.training = phase=='train' 16 | 17 | self.inputs = [] 18 | for session in self.sessions: 19 | target = np.load(osp.join(self.data_root, 'target', session+'.npy')) 20 | seed = np.random.randint(self.enc_steps) if self.training else 0 21 | for start, end in zip( 22 | range(seed, target.shape[0] - self.dec_steps, self.enc_steps), 23 | range(seed + self.enc_steps, target.shape[0] - self.dec_steps, self.enc_steps)): 24 | enc_target = target[start:end] 25 | dec_target = self.get_dec_target(target[start:end + self.dec_steps]) 26 | self.inputs.append([ 27 | session, start, end, enc_target, dec_target, 28 | ]) 29 | 30 | def get_dec_target(self, target_vector): 31 | target_matrix = np.zeros((self.enc_steps, self.dec_steps, target_vector.shape[-1])) 32 | for i in range(self.enc_steps): 33 | for j in range(self.dec_steps): 34 | # 0 -> [1, 2, 3] 35 | # target_matrix[i,j] = target_vector[i+j+1,:] 36 | # 0 -> [0, 1, 2] 37 | target_matrix[i,j] = target_vector[i+j,:] 38 | return target_matrix 39 | 40 | def __getitem__(self, index): 41 | session, start, end, enc_target, dec_target = self.inputs[index] 42 | 43 | camera_inputs = np.load( 44 | osp.join(self.data_root, self.camera_feature, session+'.npy'), mmap_mode='r')[start:end] 45 | camera_inputs = 
torch.as_tensor(camera_inputs.astype(np.float32)) 46 | motion_inputs = np.load( 47 | osp.join(self.data_root, self.motion_feature, session+'.npy'), mmap_mode='r')[start:end] 48 | motion_inputs = torch.as_tensor(motion_inputs.astype(np.float32)) 49 | enc_target = torch.as_tensor(enc_target.astype(np.float32)) 50 | dec_target = torch.as_tensor(dec_target.astype(np.float32)) 51 | 52 | return camera_inputs, motion_inputs, enc_target, dec_target.view(-1, enc_target.shape[-1]) 53 | 54 | def __len__(self): 55 | return len(self.inputs) 56 | -------------------------------------------------------------------------------- /lib/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import build_model 2 | -------------------------------------------------------------------------------- /lib/models/feature_extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Flatten(nn.Module): 5 | def __init__(self): 6 | super(Flatten, self).__init__() 7 | 8 | def forward(self, x): 9 | return x.view(x.shape[0], -1) 10 | 11 | class HDDFeatureExtractor(nn.Module): 12 | def __init__(self, args): 13 | super(HDDFeatureExtractor, self).__init__() 14 | 15 | if args.inputs in ['camera', 'sensor', 'multimodal']: 16 | self.with_camera = 'sensor' not in args.inputs 17 | self.with_sensor = 'camera' not in args.inputs 18 | else: 19 | raise(RuntimeError('Unknown inputs of {}'.format(args.inputs))) 20 | 21 | if self.with_camera and self.with_sensor: 22 | self.fusion_size = 1280 + 20 23 | elif self.with_camera: 24 | self.fusion_size = 1280 25 | elif self.with_sensor: 26 | self.fusion_size = 20 27 | 28 | self.camera_linear = nn.Sequential( 29 | nn.Conv2d(1536, 20, kernel_size=1), 30 | nn.ReLU(inplace=True), 31 | Flatten(), 32 | ) 33 | 34 | self.sensor_linear = nn.Sequential( 35 | nn.Linear(8, 20), 36 | nn.ReLU(inplace=True), 37 | ) 38 | 39 | def forward(self, camera_input, sensor_input): 40 | if self.with_camera: 41 | camera_input = self.camera_linear(camera_input) 42 | if self.with_sensor: 43 | sensor_input = self.sensor_linear(sensor_input) 44 | 45 | if self.with_camera and self.with_sensor: 46 | fusion_input = torch.cat((camera_input, sensor_input), 1) 47 | elif self.with_camera: 48 | fusion_input = camera_input 49 | elif self.with_sensor: 50 | fusion_input = sensor_input 51 | return fusion_input 52 | 53 | class THUMOSFeatureExtractor(nn.Module): 54 | def __init__(self, args): 55 | super(THUMOSFeatureExtractor, self).__init__() 56 | 57 | if args.inputs in ['camera', 'motion', 'multistream']: 58 | self.with_camera = 'motion' not in args.inputs 59 | self.with_motion = 'camera' not in args.inputs 60 | else: 61 | raise(RuntimeError('Unknown inputs of {}'.format(args.inputs))) 62 | 63 | if self.with_camera and self.with_motion: 64 | self.fusion_size = 2048 + 1024 65 | elif self.with_camera: 66 | self.fusion_size = 2048 67 | elif self.with_motion: 68 | self.fusion_size = 1024 69 | 70 | self.input_linear = nn.Sequential( 71 | nn.Linear(self.fusion_size, self.fusion_size), 72 | nn.ReLU(inplace=True), 73 | ) 74 | 75 | def forward(self, camera_input, motion_input): 76 | if self.with_camera and self.with_motion: 77 | fusion_input = torch.cat((camera_input, motion_input), 1) 78 | elif self.with_camera: 79 | fusion_input = camera_input 80 | elif self.with_motion: 81 | fusion_input = motion_input 82 | return self.input_linear(fusion_input) 83 | 84 | _FEATURE_EXTRACTORS = { 85 | 'HDD': 
HDDFeatureExtractor, 86 | 'THUMOS': THUMOSFeatureExtractor, 87 | } 88 | 89 | def build_feature_extractor(args): 90 | func = _FEATURE_EXTRACTORS[args.dataset] 91 | return func(args) 92 | -------------------------------------------------------------------------------- /lib/models/generalized_trn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .feature_extractor import build_feature_extractor 5 | 6 | def fc_relu(in_features, out_features, inplace=True): 7 | return nn.Sequential( 8 | nn.Linear(in_features, out_features), 9 | nn.ReLU(inplace=inplace), 10 | ) 11 | 12 | class GeneralizedTRN(nn.Module): 13 | def __init__(self, args): 14 | super(GeneralizedTRN, self).__init__() 15 | self.hidden_size = args.hidden_size 16 | self.enc_steps = args.enc_steps 17 | self.dec_steps = args.dec_steps 18 | self.num_classes = args.num_classes 19 | self.dropout = args.dropout 20 | 21 | self.feature_extractor = build_feature_extractor(args) 22 | # TODO: Support more fusion methods 23 | if True: 24 | self.future_size = self.feature_extractor.fusion_size 25 | self.fusion_size = self.feature_extractor.fusion_size * 2 26 | 27 | self.hx_trans = fc_relu(self.hidden_size, self.hidden_size) 28 | self.cx_trans = fc_relu(self.hidden_size, self.hidden_size) 29 | self.fusion_linear = fc_relu(self.num_classes, self.hidden_size) 30 | self.future_linear = fc_relu(self.hidden_size, self.future_size) 31 | 32 | self.enc_drop = nn.Dropout(self.dropout) 33 | self.enc_cell = nn.LSTMCell(self.fusion_size, self.hidden_size) 34 | self.dec_drop = nn.Dropout(self.dropout) 35 | self.dec_cell = nn.LSTMCell(self.hidden_size, self.hidden_size) 36 | 37 | self.classifier = nn.Linear(self.hidden_size, self.num_classes) 38 | 39 | def encoder(self, camera_input, sensor_input, future_input, enc_hx, enc_cx): 40 | fusion_input = self.feature_extractor(camera_input, sensor_input) 41 | fusion_input = torch.cat((fusion_input, future_input), 1) 42 | enc_hx, enc_cx = \ 43 | self.enc_cell(self.enc_drop(fusion_input), (enc_hx, enc_cx)) 44 | enc_score = self.classifier(self.enc_drop(enc_hx)) 45 | return enc_hx, enc_cx, enc_score 46 | 47 | def decoder(self, fusion_input, dec_hx, dec_cx): 48 | dec_hx, dec_cx = \ 49 | self.dec_cell(self.dec_drop(fusion_input), (dec_hx, dec_cx)) 50 | dec_score = self.classifier(self.dec_drop(dec_hx)) 51 | return dec_hx, dec_cx, dec_score 52 | 53 | def step(self, camera_input, sensor_input, future_input, enc_hx, enc_cx): 54 | # Encoder -> time t 55 | enc_hx, enc_cx, enc_score = \ 56 | self.encoder(camera_input, sensor_input, future_input, enc_hx, enc_cx) 57 | 58 | # Decoder -> time t + 1 59 | dec_score_stack = [] 60 | dec_hx = self.hx_trans(enc_hx) 61 | dec_cx = self.cx_trans(enc_cx) 62 | fusion_input = camera_input.new_zeros((camera_input.shape[0], self.hidden_size)) 63 | future_input = camera_input.new_zeros((camera_input.shape[0], self.future_size)) 64 | for dec_step in range(self.dec_steps): 65 | dec_hx, dec_cx, dec_score = self.decoder(fusion_input, dec_hx, dec_cx) 66 | dec_score_stack.append(dec_score) 67 | fusion_input = self.fusion_linear(dec_score) 68 | future_input = future_input + self.future_linear(dec_hx) 69 | future_input = future_input / self.dec_steps 70 | 71 | return future_input, enc_hx, enc_cx, enc_score, dec_score_stack 72 | 73 | def forward(self, camera_inputs, sensor_inputs): 74 | batch_size = camera_inputs.shape[0] 75 | enc_hx = camera_inputs.new_zeros((batch_size, self.hidden_size)) 76 | enc_cx = 
camera_inputs.new_zeros((batch_size, self.hidden_size)) 77 | future_input = camera_inputs.new_zeros((batch_size, self.future_size)) 78 | enc_score_stack = [] 79 | dec_score_stack = [] 80 | 81 | # Encoder -> time t 82 | for enc_step in range(self.enc_steps): 83 | enc_hx, enc_cx, enc_score = self.encoder( 84 | camera_inputs[:, enc_step], 85 | sensor_inputs[:, enc_step], 86 | future_input, enc_hx, enc_cx, 87 | ) 88 | enc_score_stack.append(enc_score) 89 | 90 | # Decoder -> time t + 1 91 | dec_hx = self.hx_trans(enc_hx) 92 | dec_cx = self.cx_trans(enc_cx) 93 | fusion_input = camera_inputs.new_zeros((batch_size, self.hidden_size)) 94 | future_input = camera_inputs.new_zeros((batch_size, self.future_size)) 95 | for dec_step in range(self.dec_steps): 96 | dec_hx, dec_cx, dec_score = self.decoder(fusion_input, dec_hx, dec_cx) 97 | dec_score_stack.append(dec_score) 98 | fusion_input = self.fusion_linear(dec_score) 99 | future_input = future_input + self.future_linear(dec_hx) 100 | future_input = future_input / self.dec_steps 101 | 102 | enc_scores = torch.stack(enc_score_stack, dim=1).view(-1, self.num_classes) 103 | dec_scores = torch.stack(dec_score_stack, dim=1).view(-1, self.num_classes) 104 | return enc_scores, dec_scores 105 | -------------------------------------------------------------------------------- /lib/models/models.py: -------------------------------------------------------------------------------- 1 | from .generalized_trn import GeneralizedTRN 2 | 3 | _META_ARCHITECTURES = { 4 | 'TRN': GeneralizedTRN, 5 | } 6 | 7 | def build_model(args): 8 | meta_arch = _META_ARCHITECTURES[args.model] 9 | return meta_arch(args) 10 | -------------------------------------------------------------------------------- /lib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .net_utils import * 2 | from .eval_utils import * 3 | from .logger import * 4 | from .multicrossentropy_loss import * 5 | -------------------------------------------------------------------------------- /lib/utils/eval_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import json 4 | from collections import OrderedDict 5 | 6 | import numpy as np 7 | from sklearn.metrics import average_precision_score 8 | from sklearn.metrics import confusion_matrix 9 | 10 | __all__ = [ 11 | 'compute_result_multilabel', 12 | 'compute_result', 13 | ] 14 | 15 | def compute_result_multilabel(class_index, score_metrics, target_metrics, save_dir, result_file, 16 | ignore_class=[0], save=True, verbose=False, smooth=False, switch=False): 17 | result = OrderedDict() 18 | score_metrics = np.array(score_metrics) 19 | pred_metrics = np.argmax(score_metrics, axis=1) 20 | target_metrics = np.array(target_metrics) 21 | 22 | ################################################################################################################### 23 | # We follow (Shou et al., 2017) and adopt their per-frame evaluation method of THUMOS'14 datset. 
24 | # Source: https://bitbucket.org/columbiadvmm/cdc/src/master/THUMOS14/eval/PreFrameLabeling/compute_framelevel_mAP.m 25 | ################################################################################################################### 26 | 27 | # Simple temporal smoothing via NMS of 5-frames window 28 | if smooth: 29 | prob = np.copy(score_metrics) 30 | prob1 = prob.reshape(1, prob.shape[0], prob.shape[1]) 31 | prob2 = np.append(prob[0, :].reshape(1, -1), prob[0:-1, :], axis=0).reshape(1, prob.shape[0], prob.shape[1]) 32 | prob3 = np.append(prob[1:, :], prob[-1, :].reshape(1, -1), axis=0).reshape(1, prob.shape[0], prob.shape[1]) 33 | prob4 = np.append(prob[0:2, :], prob[0:-2, :], axis=0).reshape(1, prob.shape[0], prob.shape[1]) 34 | prob5 = np.append(prob[2:, :], prob[-2:, :], axis=0).reshape(1, prob.shape[0], prob.shape[1]) 35 | probsmooth = np.squeeze(np.max(np.concatenate((prob1, prob2, prob3, prob4, prob5), axis=0), axis=0)) 36 | score_metrics = np.copy(probsmooth) 37 | 38 | # Assign cliff diving (5) as diving (8) 39 | if switch: 40 | switch_index = np.where(score_metrics[:, 5] > score_metrics[:, 8])[0] 41 | score_metrics[switch_index, 8] = score_metrics[switch_index, 5] 42 | 43 | # Remove ambiguous (21) 44 | valid_index = np.where(target_metrics[:, 21]!=1)[0] 45 | 46 | # Compute AP 47 | result['AP'] = OrderedDict() 48 | for cls in range(len(class_index)): 49 | if cls not in ignore_class: 50 | result['AP'][class_index[cls]] = average_precision_score( 51 | (target_metrics[valid_index, cls]==1).astype(np.int), 52 | score_metrics[valid_index, cls]) 53 | if verbose: 54 | print('{} AP: {:.5f}'.format(class_index[cls], result['AP'][class_index[cls]])) 55 | 56 | # Compute mAP 57 | result['mAP'] = np.mean(list(result['AP'].values())) 58 | if verbose: 59 | print('mAP: {:.5f}'.format(result['mAP'])) 60 | 61 | # Save 62 | if save: 63 | if not osp.isdir(save_dir): 64 | os.makedirs(save_dir) 65 | with open(osp.join(save_dir, result_file), 'w') as f: 66 | json.dump(result, f) 67 | if verbose: 68 | print('Saved the result to {}'.format(osp.join(save_dir, result_file))) 69 | 70 | return result['mAP'] 71 | 72 | def compute_result(class_index, score_metrics, target_metrics, save_dir, result_file, 73 | ignore_class=[0], save=True, verbose=False): 74 | result = OrderedDict() 75 | score_metrics = np.array(score_metrics) 76 | pred_metrics = np.argmax(score_metrics, axis=1) 77 | target_metrics = np.array(target_metrics) 78 | 79 | # Compute ACC 80 | correct = np.sum((target_metrics!=0) & (target_metrics==pred_metrics)) 81 | total = np.sum(target_metrics!=0) 82 | result['ACC'] = correct / total 83 | if verbose: 84 | print('ACC: {:.5f}'.format(result['ACC'])) 85 | 86 | # Compute confusion matrix 87 | result['confusion_matrix'] = \ 88 | confusion_matrix(target_metrics, pred_metrics).tolist() 89 | 90 | # Compute AP 91 | result['AP'] = OrderedDict() 92 | for cls in range(len(class_index)): 93 | if cls not in ignore_class: 94 | result['AP'][class_index[cls]] = average_precision_score( 95 | (target_metrics==cls).astype(np.int), 96 | score_metrics[:, cls]) 97 | if verbose: 98 | print('{} AP: {:.5f}'.format(class_index[cls], result['AP'][class_index[cls]])) 99 | 100 | # Compute mAP 101 | result['mAP'] = np.mean(list(result['AP'].values())) 102 | if verbose: 103 | print('mAP: {:.5f}'.format(result['mAP'])) 104 | 105 | # Save 106 | if save: 107 | if not osp.isdir(save_dir): 108 | os.makedirs(save_dir) 109 | with open(osp.join(save_dir, result_file), 'w') as f: 110 | json.dump(result, f) 111 | if verbose: 112 | 
print('Saved the result to {}'.format(osp.join(save_dir, result_file))) 113 | 114 | return result['mAP'] 115 | -------------------------------------------------------------------------------- /lib/utils/logger.py: -------------------------------------------------------------------------------- 1 | __all__ = ['setup_logger'] 2 | 3 | class Logger(object): 4 | def __init__(self, log_file, command): 5 | self.log_file = log_file 6 | if command: 7 | # self._print(command) 8 | self._write(command) 9 | 10 | def output(self, epoch, enc_losses, dec_losses, training_samples, testing_samples, 11 | enc_mAP, dec_mAP, running_time, debug=True, log=''): 12 | log += 'Epoch: {:2} | train enc_loss: {:.5f} dec_loss: {:.5f} | '.format( 13 | epoch, 14 | enc_losses['train'] / training_samples, 15 | dec_losses['train'] / training_samples, 16 | ) 17 | log += 'test enc_loss: {:.5f} dec_loss: {:.5f} enc_mAP: {:.5f} dec_mAP: {:.5f} | '.format( 18 | enc_losses['test'] / testing_samples, 19 | dec_losses['test'] / testing_samples, 20 | enc_mAP, 21 | dec_mAP, 22 | ) if debug else '' 23 | log += 'running time: {:.2f} sec'.format( 24 | running_time, 25 | ) 26 | 27 | self._print(log) 28 | self._write(log) 29 | 30 | def _print(self, log): 31 | print(log) 32 | 33 | def _write(self, log): 34 | with open(self.log_file, 'a+') as f: 35 | f.write(log + '\n') 36 | 37 | def setup_logger(log_file, command=''): 38 | return Logger(log_file, command) 39 | 40 | -------------------------------------------------------------------------------- /lib/utils/multicrossentropy_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | __all__ = ['MultiCrossEntropyLoss'] 5 | 6 | class MultiCrossEntropyLoss(nn.Module): 7 | def __init__(self, size_average=True, ignore_index=-500): 8 | super(MultiCrossEntropyLoss, self).__init__() 9 | 10 | self.size_average = size_average 11 | self.ignore_index = ignore_index 12 | 13 | def forward(self, input, target): 14 | logsoftmax = nn.LogSoftmax(dim=1).to(input.device) 15 | 16 | if self.ignore_index >= 0: 17 | notice_index = [i for i in range(target.shape[-1]) if i != self.ignore_index] 18 | output = torch.sum(-target[:, notice_index] * logsoftmax(input[:, notice_index]), 1) 19 | return torch.mean(output[target[:, self.ignore_index] != 1]) 20 | else: 21 | output = torch.sum(-target * logsoftmax(input), 1) 22 | if self.size_average: 23 | return torch.mean(output) 24 | else: 25 | return torch.sum(output) 26 | -------------------------------------------------------------------------------- /lib/utils/net_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.utils.data as data 7 | 8 | from datasets import build_dataset 9 | 10 | __all__ = [ 11 | 'set_seed', 12 | 'build_data_loader', 13 | 'weights_init', 14 | 'count_parameters', 15 | ] 16 | 17 | def set_seed(seed): 18 | random.seed(seed) 19 | np.random.seed(seed) 20 | torch.manual_seed(seed) 21 | if torch.cuda.is_available(): 22 | torch.cuda.manual_seed(seed) 23 | torch.cuda.manual_seed_all(seed) 24 | torch.backends.cudnn.benchmark = False 25 | torch.backends.cudnn.deterministic = True 26 | 27 | def build_data_loader(args, phase='train'): 28 | data_loaders = data.DataLoader( 29 | dataset=build_dataset(args, phase), 30 | batch_size=args.batch_size, 31 | shuffle=phase=='train', 32 | num_workers=args.num_workers, 33 | ) 34 | return data_loaders 35 | 
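# --- Usage sketch (annotation added here, not part of the original net_utils.py) ---
# `build_data_loader` above is the helper the training scripts in tools/ call to get
# one DataLoader per phase; the pattern below mirrors tools/trn_hdd/train.py and
# tools/trn_thumos/train.py. The unpacked names are illustrative and follow the
# THUMOS data layer, whose __getitem__ returns
# (camera_inputs, motion_inputs, enc_target, dec_target):
#
#     data_loaders = {phase: build_data_loader(args, phase) for phase in args.phases}
#     for camera_inputs, motion_inputs, enc_target, dec_target in data_loaders['train']:
#         ...  # forward/backward pass, as in the train.py scripts
#
# The function that follows, `weights_init`, is applied in those same scripts via
# `model.apply(utl.weights_init)` when no checkpoint is provided.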
36 | def weights_init(m): 37 | if isinstance(m, nn.Conv2d): 38 | m.weight.data.normal_(0.0, 0.001) 39 | elif isinstance(m, nn.Linear): 40 | m.weight.data.normal_(0.0, 0.001) 41 | elif isinstance(m, nn.LSTMCell): 42 | for param in m.parameters(): 43 | if len(param.shape) >= 2: 44 | nn.init.orthogonal_(param.data) 45 | else: 46 | nn.init.normal_(param.data) 47 | 48 | def count_parameters(model): 49 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 50 | -------------------------------------------------------------------------------- /tools/trn_hdd/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import sys 4 | import time 5 | 6 | import torch 7 | import torch.nn as nn 8 | import numpy as np 9 | 10 | import _init_paths 11 | import utils as utl 12 | from configs.hdd import parse_trn_args as parse_args 13 | from models import build_model 14 | 15 | def to_device(x, device): 16 | return x.unsqueeze(0).to(device) 17 | 18 | def main(args): 19 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 20 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 21 | 22 | enc_score_metrics = [] 23 | enc_target_metrics = [] 24 | dec_score_metrics = [[] for i in range(args.dec_steps)] 25 | dec_target_metrics = [[] for i in range(args.dec_steps)] 26 | 27 | if osp.isfile(args.checkpoint): 28 | checkpoint = torch.load(args.checkpoint) 29 | else: 30 | raise(RuntimeError('Cannot find the checkpoint {}'.format(args.checkpoint))) 31 | model = build_model(args).to(device) 32 | model.load_state_dict(checkpoint['model_state_dict']) 33 | model.train(False) 34 | 35 | softmax = nn.Softmax(dim=1).to(device) 36 | 37 | for session_idx, session in enumerate(args.test_session_set, start=1): 38 | start = time.time() 39 | with torch.set_grad_enabled(False): 40 | camera_inputs = np.load(osp.join(args.data_root, args.camera_feature, session+'.npy'), mmap_mode='r') 41 | sensor_inputs = np.load(osp.join(args.data_root, 'sensor', session+'.npy'), mmap_mode='r') 42 | target = np.load(osp.join(args.data_root, 'target', session+'.npy')) 43 | future_input = to_device(torch.zeros(model.future_size), device) 44 | enc_hx = to_device(torch.zeros(model.hidden_size), device) 45 | enc_cx = to_device(torch.zeros(model.hidden_size), device) 46 | 47 | for l in range(target.shape[0]): 48 | camera_input = to_device( 49 | torch.as_tensor(camera_inputs[l].astype(np.float32)), device) 50 | sensor_input = to_device( 51 | torch.as_tensor(sensor_inputs[l].astype(np.float32)), device) 52 | 53 | future_input, enc_hx, enc_cx, enc_score, dec_score_stack = \ 54 | model.step(camera_input, sensor_input, future_input, enc_hx, enc_cx) 55 | 56 | enc_score_metrics.append(softmax(enc_score).cpu().numpy()[0]) 57 | enc_target_metrics.append(target[l]) 58 | 59 | for step in range(args.dec_steps): 60 | dec_score_metrics[step].append(softmax(dec_score_stack[step]).cpu().numpy()[0]) 61 | dec_target_metrics[step].append(target[min(l + step, target.shape[0] - 1)]) 62 | end = time.time() 63 | 64 | print('Processed session {}, {:2} of {}, running time {:.2f} sec'.format( 65 | session, session_idx, len(args.test_session_set), end - start)) 66 | 67 | save_dir = osp.dirname(args.checkpoint) 68 | result_file = osp.basename(args.checkpoint).replace('.pth', '.json') 69 | # Compute result for encoder 70 | utl.compute_result(args.class_index, 71 | enc_score_metrics, enc_target_metrics, 72 | save_dir, result_file, save=True, verbose=True) 73 | 74 | # Compute result for 
decoder 75 | for step in range(args.dec_steps): 76 | utl.compute_result(args.class_index, 77 | dec_score_metrics[step], dec_target_metrics[step], 78 | save_dir, result_file, save=False, verbose=True) 79 | 80 | if __name__ == '__main__': 81 | main(parse_args()) 82 | -------------------------------------------------------------------------------- /tools/trn_hdd/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import sys 4 | import time 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | 10 | import _init_paths 11 | import utils as utl 12 | from configs.hdd import parse_trn_args as parse_args 13 | from models import build_model 14 | 15 | def main(args): 16 | this_dir = osp.join(osp.dirname(__file__), '.') 17 | save_dir = osp.join(this_dir, 'checkpoints') 18 | if not osp.isdir(save_dir): 19 | os.makedirs(save_dir) 20 | command = 'python ' + ' '.join(sys.argv) 21 | logger = utl.setup_logger(osp.join(this_dir, 'log.txt'), command=command) 22 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 23 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 24 | utl.set_seed(int(args.seed)) 25 | 26 | model = build_model(args) 27 | if osp.isfile(args.checkpoint): 28 | checkpoint = torch.load(args.checkpoint, map_location=torch.device('cpu')) 29 | model.load_state_dict(checkpoint['model_state_dict']) 30 | else: 31 | model.apply(utl.weights_init) 32 | if args.distributed: 33 | model = nn.DataParallel(model) 34 | model = model.to(device) 35 | 36 | criterion = nn.CrossEntropyLoss().to(device) 37 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 38 | if osp.isfile(args.checkpoint): 39 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 40 | for param_group in optimizer.param_groups: 41 | param_group['lr'] = args.lr 42 | args.start_epoch += checkpoint['epoch'] 43 | softmax = nn.Softmax(dim=1).to(device) 44 | 45 | for epoch in range(args.start_epoch, args.start_epoch + args.epochs): 46 | data_loaders = { 47 | phase: utl.build_data_loader(args, phase) 48 | for phase in args.phases 49 | } 50 | 51 | enc_losses = {phase: 0.0 for phase in args.phases} 52 | enc_score_metrics = [] 53 | enc_target_metrics = [] 54 | enc_mAP = 0.0 55 | dec_losses = {phase: 0.0 for phase in args.phases} 56 | dec_score_metrics = [] 57 | dec_target_metrics = [] 58 | dec_mAP = 0.0 59 | 60 | start = time.time() 61 | for phase in args.phases: 62 | training = phase=='train' 63 | if training: 64 | model.train(True) 65 | elif not training and args.debug: 66 | model.train(False) 67 | else: 68 | continue 69 | 70 | with torch.set_grad_enabled(training): 71 | for batch_idx, (camera_inputs, sensor_inputs, enc_target, dec_target) \ 72 | in enumerate(data_loaders[phase], start=1): 73 | batch_size = camera_inputs.shape[0] 74 | camera_inputs = camera_inputs.to(device) 75 | sensor_inputs = sensor_inputs.to(device) 76 | enc_target = enc_target.to(device).view(-1) 77 | dec_target = dec_target.to(device).view(-1) 78 | 79 | enc_score, dec_score = model(camera_inputs, sensor_inputs) 80 | enc_loss = criterion(enc_score, enc_target) 81 | dec_loss = criterion(dec_score, dec_target) 82 | enc_losses[phase] += enc_loss.item() * batch_size 83 | dec_losses[phase] += dec_loss.item() * batch_size 84 | if args.verbose: 85 | print('Epoch: {:2} | iteration: {:3} | enc_loss: {:.5f} dec_loss: {:.5f}'.format( 86 | epoch, batch_idx, enc_loss.item(), dec_loss.item() 87 | )) 88 | 89 | if training: 90 | 
--------------------------------------------------------------------------------
/tools/trn_thumos/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import sys
4 | import time
5 | 
6 | import torch
7 | import torch.nn as nn
8 | import numpy as np
9 | 
10 | import _init_paths
11 | import utils as utl
12 | from configs.thumos import parse_trn_args as parse_args
13 | from models import build_model
14 | 
15 | def to_device(x, device):
16 |     return x.unsqueeze(0).to(device)
17 | 
18 | def main(args):
19 |     os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
20 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21 | 
22 |     enc_score_metrics = []
23 |     enc_target_metrics = []
24 |     dec_score_metrics = [[] for i in range(args.dec_steps)]
25 |     dec_target_metrics = [[] for i in range(args.dec_steps)]
26 | 
27 |     if osp.isfile(args.checkpoint):
28 |         checkpoint = torch.load(args.checkpoint)
29 |     else:
30 |         raise(RuntimeError('Cannot find the checkpoint {}'.format(args.checkpoint)))
31 |     model = build_model(args).to(device)
32 |     model.load_state_dict(checkpoint['model_state_dict'])
33 |     model.train(False)
34 | 
35 |     softmax = nn.Softmax(dim=1).to(device)
36 | 
37 |     for session_idx, session in enumerate(args.test_session_set, start=1):
38 |         start = time.time()
39 |         with torch.set_grad_enabled(False):
40 |             camera_inputs = np.load(osp.join(args.data_root, args.camera_feature, session+'.npy'), mmap_mode='r')
41 |             motion_inputs = np.load(osp.join(args.data_root, args.motion_feature, session+'.npy'), mmap_mode='r')
42 |             target = np.load(osp.join(args.data_root, 'target', session+'.npy'))
43 |             future_input = to_device(torch.zeros(model.future_size), device)
44 |             enc_hx = to_device(torch.zeros(model.hidden_size), device)
45 |             enc_cx = to_device(torch.zeros(model.hidden_size), device)
46 | 
47 |             for l in range(target.shape[0]):
48 |                 camera_input = to_device(
49 |                     torch.as_tensor(camera_inputs[l].astype(np.float32)), device)
50 |                 motion_input = to_device(
51 |                     torch.as_tensor(motion_inputs[l].astype(np.float32)), device)
52 | 
53 |                 future_input, enc_hx, enc_cx, enc_score, dec_score_stack = \
54 |                     model.step(camera_input, motion_input, future_input, enc_hx, enc_cx)
55 | 
56 |                 enc_score_metrics.append(softmax(enc_score).cpu().numpy()[0])
57 |                 enc_target_metrics.append(target[l])
58 | 
59 |                 for step in range(args.dec_steps):
60 |                     dec_score_metrics[step].append(softmax(dec_score_stack[step]).cpu().numpy()[0])
61 |                     dec_target_metrics[step].append(target[min(l + step, target.shape[0] - 1)])
62 |         end = time.time()
63 | 
64 |         print('Processed session {}, {:2} of {}, running time {:.2f} sec'.format(
65 |             session, session_idx, len(args.test_session_set), end - start))
66 | 
67 |     save_dir = osp.dirname(args.checkpoint)
68 |     result_file = osp.basename(args.checkpoint).replace('.pth', '.json')
69 |     # Compute result for encoder
70 |     utl.compute_result_multilabel(args.class_index,
71 |                                   enc_score_metrics, enc_target_metrics,
72 |                                   save_dir, result_file, ignore_class=[0,21], save=True, verbose=True)
73 | 
74 |     # Compute result for decoder
75 |     for step in range(args.dec_steps):
76 |         utl.compute_result_multilabel(args.class_index,
77 |                                       dec_score_metrics[step], dec_target_metrics[step],
78 |                                       save_dir, result_file, ignore_class=[0,21], save=False, verbose=True)
79 | 
80 | if __name__ == '__main__':
81 |     main(parse_args())
82 | 
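The evaluation loop above runs online, one frame at a time: the encoder LSTM state (`enc_hx`, `enc_cx`) and the accumulated future context (`future_input`) are threaded from one `model.step` call to the next, and the decoder's `step`-th output is scored against the label `step` frames ahead of the current one (step 0 being the current frame), clamped at the last frame. A small standalone sketch of that target alignment, with a made-up label sequence:

```python
# Standalone sketch of how decoder targets are aligned in tools/trn_thumos/eval.py.
# The label sequence and dec_steps value below are made up for illustration.
target = ['bg', 'bg', 'jump', 'jump', 'bg']   # per-frame labels for a 5-frame clip
dec_steps = 3                                 # plays the role of args.dec_steps

for l in range(len(target)):
    # Mirrors: dec_target_metrics[step].append(target[min(l + step, target.shape[0] - 1)])
    future_labels = [target[min(l + step, len(target) - 1)] for step in range(dec_steps)]
    print('frame {}: decoder targets {}'.format(l, future_labels))
```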
--------------------------------------------------------------------------------
/tools/trn_thumos/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import sys
4 | import time
5 | 
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | 
10 | import _init_paths
11 | import utils as utl
12 | from configs.thumos import parse_trn_args as parse_args
13 | from models import build_model
14 | 
15 | def main(args):
16 |     this_dir = osp.join(osp.dirname(__file__), '.')
17 |     save_dir = osp.join(this_dir, 'checkpoints')
18 |     if not osp.isdir(save_dir):
19 |         os.makedirs(save_dir)
20 |     command = 'python ' + ' '.join(sys.argv)
21 |     logger = utl.setup_logger(osp.join(this_dir, 'log.txt'), command=command)
22 |     os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
23 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
24 |     utl.set_seed(int(args.seed))
25 | 
26 |     model = build_model(args)
27 |     if osp.isfile(args.checkpoint):
28 |         checkpoint = torch.load(args.checkpoint, map_location=torch.device('cpu'))
29 |         model.load_state_dict(checkpoint['model_state_dict'])
30 |     else:
31 |         model.apply(utl.weights_init)
32 |     if args.distributed:
33 |         model = nn.DataParallel(model)
34 |     model = model.to(device)
35 | 
36 |     criterion = utl.MultiCrossEntropyLoss(ignore_index=21).to(device)
37 |     optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
38 |     if osp.isfile(args.checkpoint):
39 |         optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
40 |         for param_group in optimizer.param_groups:
41 |             param_group['lr'] = args.lr
42 |         args.start_epoch += checkpoint['epoch']
43 |     softmax = nn.Softmax(dim=1).to(device)
44 | 
45 |     for epoch in range(args.start_epoch, args.start_epoch + args.epochs):
46 |         if epoch == 21:
47 |             args.lr = args.lr * 0.1
48 |             for param_group in optimizer.param_groups:
49 |                 param_group['lr'] = args.lr
50 | 
51 |         data_loaders = {
52 |             phase: utl.build_data_loader(args, phase)
53 |             for phase in args.phases
54 |         }
55 | 
56 |         enc_losses = {phase: 0.0 for phase in args.phases}
57 |         enc_score_metrics = []
58 |         enc_target_metrics = []
59 |         enc_mAP = 0.0
60 |         dec_losses = {phase: 0.0 for phase in args.phases}
61 |         dec_score_metrics = []
62 |         dec_target_metrics = []
63 |         dec_mAP = 0.0
64 | 
65 |         start = time.time()
66 |         for phase in args.phases:
67 |             training = phase=='train'
68 |             if training:
69 |                 model.train(True)
70 |             elif not training and args.debug:
71 |                 model.train(False)
72 |             else:
73 |                 continue
74 | 
75 |             with torch.set_grad_enabled(training):
76 |                 for batch_idx, (camera_inputs, motion_inputs, enc_target, dec_target) \
77 |                         in enumerate(data_loaders[phase], start=1):
78 |                     batch_size = camera_inputs.shape[0]
79 |                     camera_inputs = camera_inputs.to(device)
80 |                     motion_inputs = motion_inputs.to(device)
81 |                     enc_target = enc_target.to(device).view(-1, args.num_classes)
82 |                     dec_target = dec_target.to(device).view(-1, args.num_classes)
83 | 
84 |                     enc_score, dec_score = model(camera_inputs, motion_inputs)
85 |                     enc_loss = criterion(enc_score, enc_target)
86 |                     dec_loss = criterion(dec_score, dec_target)
87 |                     enc_losses[phase] += enc_loss.item() * batch_size
88 |                     dec_losses[phase] += dec_loss.item() * batch_size
89 |                     if args.verbose:
90 |                         print('Epoch: {:2} | iteration: {:3} | enc_loss: {:.5f} dec_loss: {:.5f}'.format(
91 |                             epoch, batch_idx, enc_loss.item(), dec_loss.item()
92 |                         ))
93 | 
94 |                     if training:
95 |                         optimizer.zero_grad()
96 |                         loss = enc_loss + dec_loss
97 |                         loss.backward()
98 |                         optimizer.step()
99 |                     else:
100 |                         # Prepare metrics for encoder
101 |                         enc_score = softmax(enc_score).cpu().numpy()
102 |                         enc_target = enc_target.cpu().numpy()
103 |                         enc_score_metrics.extend(enc_score)
104 |                         enc_target_metrics.extend(enc_target)
105 |                         # Prepare metrics for decoder
106 |                         dec_score = softmax(dec_score).cpu().numpy()
107 |                         dec_target = dec_target.cpu().numpy()
108 |                         dec_score_metrics.extend(dec_score)
109 |                         dec_target_metrics.extend(dec_target)
110 |         end = time.time()
111 | 
112 |         if args.debug:
113 |             result_file = 'inputs-{}-epoch-{}.json'.format(args.inputs, epoch)
114 |             # Compute result for encoder
115 |             enc_mAP = utl.compute_result_multilabel(
116 |                 args.class_index,
117 |                 enc_score_metrics,
118 |                 enc_target_metrics,
119 |                 save_dir,
120 |                 result_file,
121 |                 ignore_class=[0,21],
122 |                 save=True,
123 |             )
124 |             # Compute result for decoder
125 |             dec_mAP = utl.compute_result_multilabel(
126 |                 args.class_index,
127 |                 dec_score_metrics,
128 |                 dec_target_metrics,
129 |                 save_dir,
130 |                 result_file,
131 |                 ignore_class=[0,21],
132 |                 save=False,
133 |             )
134 | 
135 |         # Output result
136 |         logger.output(epoch, enc_losses, dec_losses,
137 |                       len(data_loaders['train'].dataset), len(data_loaders['test'].dataset),
138 |                       enc_mAP, dec_mAP, end - start, debug=args.debug)
139 | 
140 |         # Save model
141 |         checkpoint_file = 'inputs-{}-epoch-{}.pth'.format(args.inputs, epoch)
142 |         torch.save({
143 |             'epoch': epoch,
144 |             'model_state_dict': model.module.state_dict() if args.distributed else model.state_dict(),
145 |             'optimizer_state_dict': optimizer.state_dict(),
146 |         }, osp.join(save_dir, checkpoint_file))
147 | 
148 | if __name__ == '__main__':
149 |     main(parse_args())
150 | 
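Unlike the HDD script, which trains with `nn.CrossEntropyLoss` on flat class indices, the THUMOS script keeps its targets as per-frame label vectors of shape `(-1, args.num_classes)` and uses `utl.MultiCrossEntropyLoss(ignore_index=21)` from `lib/utils/multicrossentropy_loss.py`; evaluation likewise ignores classes 0 and 21. That loss is not reproduced here, but a minimal sketch of what a soft-target cross-entropy with an ignore index could look like (the masking and reduction in the repo's actual implementation may differ, and the function name below is illustrative):

```python
# Minimal sketch of a soft-target ("multi") cross-entropy with an ignore index.
# Illustration only; the repo's MultiCrossEntropyLoss may handle masking differently.
import torch
import torch.nn.functional as F

def multi_cross_entropy(scores, targets, ignore_index=21):
    """scores: (N, C) logits; targets: (N, C) soft label vectors."""
    keep = targets[:, ignore_index] != 1          # drop frames labelled with the ignored class
    log_probs = F.log_softmax(scores[keep], dim=1)
    return -(targets[keep] * log_probs).sum(dim=1).mean()

# Tiny smoke test with made-up values (3 frames, 22 classes).
scores = torch.randn(3, 22)
targets = torch.zeros(3, 22)
targets[0, 0] = 1.0     # background frame
targets[1, 5] = 1.0     # action frame
targets[2, 21] = 1.0    # ambiguous frame, removed by the mask
print(multi_cross_entropy(scores, targets))
```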
--------------------------------------------------------------------------------
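Both training scripts store the same three fields in every checkpoint they write, so a saved checkpoint can be inspected on its own with a few lines (the path below is hypothetical; real filenames follow the `'inputs-{}-epoch-{}.pth'` pattern used above):

```python
# Inspect a TRN checkpoint written by the training scripts above.
# The path is hypothetical and only meant to show the expected layout.
import torch

ckpt = torch.load('tools/trn_thumos/checkpoints/inputs-camera-epoch-21.pth', map_location='cpu')
print(ckpt['epoch'])                  # epoch at which the weights were saved
print(sorted(ckpt.keys()))            # ['epoch', 'model_state_dict', 'optimizer_state_dict']
print(len(ckpt['model_state_dict']))  # number of parameter tensors in the saved model
```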