├── LICENSE ├── README.md ├── THIRD-PARTY-LICENSES.txt ├── configs └── UCF101 │ ├── SLOW_32x8_R50.yaml │ └── SLOW_32x8_R50_CLSONLY.yaml ├── readme └── modist.gif ├── setup.cfg ├── setup.py ├── slowfast ├── __init__.py ├── config │ ├── __init__.py │ ├── custom_config.py │ └── defaults.py ├── datasets │ ├── DATASET.md │ ├── __init__.py │ ├── audioset.py │ ├── ava_dataset.py │ ├── ava_helper.py │ ├── build.py │ ├── charades.py │ ├── cv2_transform.py │ ├── decoder.py │ ├── folder.py │ ├── kinetics.py │ ├── loader.py │ ├── multigrid_helper.py │ ├── ssv2.py │ ├── transform.py │ ├── utils.py │ ├── video_container.py │ └── video_feature_dataset.py ├── models │ ├── __init__.py │ ├── attention.py │ ├── batchnorm_helper.py │ ├── build.py │ ├── common.py │ ├── head_helper.py │ ├── losses.py │ ├── nonlocal_helper.py │ ├── optimizer.py │ ├── resnet_helper.py │ ├── stem_helper.py │ ├── video_model_builder.py │ └── video_model_builder_transformer.py ├── utils │ ├── __init__.py │ ├── ava_eval_helper.py │ ├── ava_evaluation │ │ ├── README.md │ │ ├── __init__.py │ │ ├── ava_action_list_v2.1_for_activitynet_2018.pbtxt.txt │ │ ├── label_map_util.py │ │ ├── metrics.py │ │ ├── np_box_list.py │ │ ├── np_box_list_ops.py │ │ ├── np_box_mask_list.py │ │ ├── np_box_mask_list_ops.py │ │ ├── np_box_ops.py │ │ ├── np_mask_ops.py │ │ ├── object_detection_evaluation.py │ │ ├── per_image_evaluation.py │ │ └── standard_fields.py │ ├── benchmark.py │ ├── bn_helper.py │ ├── c2_model_loading.py │ ├── checkpoint.py │ ├── custom_platform.py │ ├── distributed.py │ ├── env.py │ ├── frame_timecode.py │ ├── futils.py │ ├── logging.py │ ├── lr_policy.py │ ├── meters.py │ ├── metrics.py │ ├── misc.py │ ├── multigrid.py │ ├── multiprocessing.py │ ├── parser.py │ ├── video_splitter.py │ ├── viterbi.py │ └── weight_init_helper.py └── visualization │ ├── __init__.py │ ├── async_predictor.py │ ├── ava_demo_precomputed_boxes.py │ ├── demo_loader.py │ ├── gradcam_utils.py │ ├── predictor.py │ ├── tensorboard_vis.py │ ├── utils.py │ └── video_visualizer.py └── tools ├── __pycache__ ├── test_net.cpython-38.pyc └── train_net.cpython-38.pyc ├── run_net.py ├── test_net.py ├── train_net.py └── video_io.py /README.md: -------------------------------------------------------------------------------- 1 | # MaCLR 2 | 3 | > [**MaCLR: Motion-aware Contrastive Learning of Representations for Videos**](https://arxiv.org/abs/2106.09703), 4 | > Fanyi Xiao, Joseph Tighe, Davide Modolo 5 | 6 | 7 | ![](readme/modist.gif) 8 | 9 | 10 | @inproceedings{xiao2022maclr, 11 | title={MaCLR: Motion-aware Contrastive Learning of Representations for Videos}, 12 | author={Xiao, Fanyi and Tighe, Joseph and Modolo, Davide}, 13 | booktitle={ECCV}, 14 | year={2022} 15 | } 16 | 17 | 18 | ## Abstract 19 | We present MaCLR as a novel method to explicitly perform cross-modal self-supervised video representations learning from visual and motion modalities.Compared to previous video representation learning methods that mostly focus on learning motion cues implicitly from RGB inputs, MaCLR enriches standard contrastive learning objectives for RGB video clips with a cross-modal learning objective between a Motion pathway and a Visual pathway. We show that the representation learned with our MaCLR method focuses more on foreground motion regions and thus generalizes better to downstream tasks. To demonstrate this, we evaluate MaCLR on five datasets for both action recognition and action detection, and demonstrate state-of-the-art self-supervised performance on all datasets. 
20 | Furthermore, we show that the MaCLR representation can be as effective as representations learned with full supervision on UCF101 and HMDB51 action recognition, while even outperforming the supervised representation for action recognition on VidSitu and SSv2, and action detection on AVA. 21 | 22 | ## Installation 23 | 24 | - Create the Python environment: `virtualenv -p python3 maclr` and activate the environment with `source /PATH/TO/ENV/maclr/bin/activate` 25 | - Install all dependencies: 26 | ``` 27 | pip install torch==1.7.0+cu101 torchvision==0.8.1+cu101 torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html 28 | pip install fvcore simplejson av psutil opencv-python tensorboard moviepy indexed kornia==0.5.7 matplotlib librosa einops timm 29 | ``` 30 | - Navigate to the code directory and install the repo with `python setup.py develop` 31 | 32 | 33 | ## Data preparation 34 | 35 | You need to download the UCF101 dataset from its official site and organize the files with the following structure: 36 | ``` 37 | maclr 38 | ├── data 39 | │ ├── pretrained 40 | │ │ └── maclr.pyth 41 | │ └── ucf101 42 | │ ├── data 43 | │ └── simple_annot 44 | │ ├── split1 45 | │ │ ├── train.csv 46 | │ │ └── val.csv 47 | │ ├── split2 48 | │ │ ├── train.csv 49 | │ │ └── val.csv 50 | │ └── split3 51 | │ ├── train.csv 52 | │ └── val.csv 53 | └── src 54 | ├── configs 55 | │ └── UCF101 56 | ├── slowfast 57 | │ ├── config 58 | │ ├── datasets 59 | │ ├── models 60 | │ ├── utils 61 | │ └── visualization 62 | └── tools 63 | ``` 64 | where each `.csv` file contains one row per video in the format `video_path class_idx` (see the example at the end of this README). 65 | 66 | 67 | ## Download pretrained MaCLR model 68 | 69 | You can download a pre-trained SLOW_R50 MaCLR model here: [MaCLR model](https://aws-cv-sci-motion-public.s3.us-west-2.amazonaws.com/MaCLR/model_zoos/maclr.pyth) 70 | 71 | 72 | ## Transfer pretrained MaCLR model to UCF101 73 | 74 | End-to-end finetuning: 75 | ``` 76 | python run_net.py \ 77 | --cfg ../configs/UCF101/SLOW_32x8_R50.yaml \ 78 | --output_dir ../../data/output/UCF101/SLOW_32x8_R50_Pretrain_MaCLR 79 | ``` 80 | 81 | Linear probing: 82 | ``` 83 | python run_net.py \ 84 | --cfg ../configs/UCF101/SLOW_32x8_R50_CLSONLY.yaml \ 85 | --output_dir ../../data/output/UCF101/SLOW_32x8_R50_CLSONLY_Pretrain_MaCLR 86 | ``` 87 | 88 | ## License 89 | This repository is released under the Apache License 2.0. It enables users to download a pre-trained MaCLR representation and fine-tune it on downstream tasks, such as action recognition on UCF101. It does not include training code for MaCLR, as MaCLR is currently patent pending.
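
## Annotation file format (example)

Each `train.csv`/`val.csv` under `simple_annot` is a plain-text file with one `video_path class_idx` pair per row. A minimal sketch of what such a file could look like (hypothetical entries; the exact paths depend on how you store the UCF101 videos under `DATA.PATH_PREFIX`, and the class indices depend on your label mapping):

```
ApplyEyeMakeup/v_ApplyEyeMakeup_g08_c01.avi 0
ApplyLipstick/v_ApplyLipstick_g01_c01.avi 1
Archery/v_Archery_g05_c02.avi 2
```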
90 | -------------------------------------------------------------------------------- /configs/UCF101/SLOW_32x8_R50.yaml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | ENABLE: True 3 | DATASET: kinetics 4 | SPLIT: train 5 | BATCH_SIZE: 64 6 | EVAL_PERIOD: 1 7 | CHECKPOINT_PERIOD: 1 8 | AUTO_RESUME: True 9 | CHECKPOINT_FILE_PATH: ../../data/pretrained/maclr.pyth 10 | CHECKPOINT_TYPE: pytorch_striphead_stripssl 11 | LOAD_TRAIN_STATE: False 12 | DATA: 13 | PATH_TO_DATA_DIR: ../../data/ucf101/simple_annot/split1 14 | PATH_PREFIX: ../../data/ucf101/data 15 | USE_BGR_ORDER: False 16 | NUM_FRAMES: 32 17 | SAMPLING_RATE: 8 18 | TRAIN_JITTER_SCALES: [256, 340] 19 | TRAIN_CROP_SIZE: 224 20 | TEST_CROP_SIZE: 256 21 | INPUT_CHANNEL_NUM: [3] 22 | EXPAND_DATASET: 20 # ACTUAL_EPOCHS = cfg.DATA.EXPAND_DATASET x cfg.SOLVER.MAX_EPOCH 23 | TRAIN_AUGMENTATION_STYLE: CropResize # ResizeCrop, CropResize 24 | TRAIN_JITTER_AREAS: [0.3, 1.0] 25 | TRAIN_JITTER_ASPECT_RATIOS: [0.5, 2.0] 26 | COLOR_JITTER: 27 | PROB: 0.6 28 | GRAYSCALE: 29 | PROB: 0.5 30 | RESNET: 31 | ZERO_INIT_FINAL_BN: True 32 | WIDTH_PER_GROUP: 64 33 | NUM_GROUPS: 1 34 | DEPTH: 50 35 | TRANS_FUNC: bottleneck_transform 36 | STRIDE_1X1: False 37 | NUM_BLOCK_TEMP_KERNEL: [[3], [4], [6], [3]] 38 | CONV1_TEMPORAL_STRIDE: 2 39 | NONLOCAL: 40 | LOCATION: [[[]], [[]], [[]], [[]]] 41 | GROUP: [[1], [1], [1], [1]] 42 | INSTANTIATION: dot_product 43 | BN: 44 | USE_PRECISE_STATS: True 45 | PRECISE_STATS_PERIOD: 1 46 | NUM_BATCHES_PRECISE: 200 47 | SOLVER: 48 | BASE_LR: 0.01 49 | LR_POLICY: cosine 50 | MAX_EPOCH: 10 51 | MOMENTUM: 0.9 52 | WEIGHT_DECAY: 1e-4 53 | WARMUP_EPOCHS: 1.7 54 | WARMUP_START_LR: 0.001 55 | OPTIMIZING_METHOD: sgd 56 | MODEL: 57 | NUM_CLASSES: 101 58 | ARCH: dense_slow 59 | MODEL_NAME: ResNet 60 | LOSS_FUNC: cross_entropy 61 | DROPOUT_RATE: 0.5 62 | TEST: 63 | ENABLE: True 64 | DATASET: kinetics 65 | SPLIT: val 66 | BATCH_SIZE: 32 67 | DATA_LOADER: 68 | NUM_WORKERS: 8 69 | PIN_MEMORY: True 70 | NUM_GPUS: 8 71 | NUM_SHARDS: 1 72 | RNG_SEED: 0 73 | OUTPUT_DIR: . 
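# Note: following the EXPAND_DATASET comment above, this schedule corresponds to
# DATA.EXPAND_DATASET (20) x SOLVER.MAX_EPOCH (10) = 200 effective passes over
# UCF101 split 1.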
-------------------------------------------------------------------------------- /configs/UCF101/SLOW_32x8_R50_CLSONLY.yaml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | ENABLE: True 3 | DATASET: kinetics 4 | SPLIT: train 5 | BATCH_SIZE: 256 6 | EVAL_PERIOD: 1 7 | CHECKPOINT_PERIOD: 1 8 | AUTO_RESUME: True 9 | CHECKPOINT_FILE_PATH: ../../data/pretrained/maclr.pyth 10 | CHECKPOINT_TYPE: pytorch_striphead_stripssl 11 | LOAD_TRAIN_STATE: False 12 | DATA: 13 | PATH_TO_DATA_DIR: ../../data/ucf101/simple_annot/split1 14 | PATH_PREFIX: ../../data/ucf101/data 15 | USE_BGR_ORDER: False 16 | NUM_FRAMES: 32 17 | SAMPLING_RATE: 8 18 | TRAIN_JITTER_SCALES: [256, 340] 19 | TRAIN_CROP_SIZE: 224 20 | TEST_CROP_SIZE: 256 21 | INPUT_CHANNEL_NUM: [3] 22 | EXPAND_DATASET: 50 # ACTUAL_EPOCHS = cfg.DATA.EXPAND_DATASET x cfg.SOLVER.MAX_EPOCH 23 | TRAIN_AUGMENTATION_STYLE: CropResize # ResizeCrop, CropResize 24 | TRAIN_JITTER_AREAS: [0.3, 1.0] 25 | TRAIN_JITTER_ASPECT_RATIOS: [0.5, 2.0] 26 | RESNET: 27 | ZERO_INIT_FINAL_BN: True 28 | WIDTH_PER_GROUP: 64 29 | NUM_GROUPS: 1 30 | DEPTH: 50 31 | TRANS_FUNC: bottleneck_transform 32 | STRIDE_1X1: False 33 | NUM_BLOCK_TEMP_KERNEL: [[3], [4], [6], [3]] 34 | CONV1_TEMPORAL_STRIDE: 2 35 | NONLOCAL: 36 | LOCATION: [[[]], [[]], [[]], [[]]] 37 | GROUP: [[1], [1], [1], [1]] 38 | INSTANTIATION: dot_product 39 | BN: 40 | USE_PRECISE_STATS: False 41 | PRECISE_STATS_PERIOD: 1 42 | NUM_BATCHES_PRECISE: 200 43 | SOLVER: 44 | BASE_LR: 32.0 45 | LR_POLICY: cosine 46 | MAX_EPOCH: 4 47 | MOMENTUM: 0.9 48 | WEIGHT_DECAY: 0.0 49 | # WARMUP_EPOCHS: 0.7 50 | # WARMUP_START_LR: 3.2 51 | OPTIMIZING_METHOD: sgd 52 | MODEL: 53 | CLS_ONLY: True 54 | NORMALIZE_FEATURE: False 55 | NUM_CLASSES: 101 56 | ARCH: dense_slow 57 | MODEL_NAME: ResNet 58 | LOSS_FUNC: cross_entropy 59 | DROPOUT_RATE: 0.0 60 | TEST: 61 | ENABLE: True 62 | DATASET: kinetics 63 | SPLIT: val 64 | BATCH_SIZE: 256 65 | DATA_LOADER: 66 | NUM_WORKERS: 8 67 | PIN_MEMORY: True 68 | NUM_GPUS: 8 69 | NUM_SHARDS: 1 70 | RNG_SEED: 0 71 | OUTPUT_DIR: . 
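# Note: this is the linear-probing config referenced in the README. With
# MODEL.CLS_ONLY: True only the classification head is trained on top of the
# pretrained backbone, which is why BASE_LR is large (32.0) and both
# WEIGHT_DECAY and DROPOUT_RATE are set to 0. As above, DATA.EXPAND_DATASET (50)
# x SOLVER.MAX_EPOCH (4) = 200 effective passes.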
-------------------------------------------------------------------------------- /readme/modist.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/self-supervised-maclr/8a92ef0586109ad3110376e61be7e97f61f08b0d/readme/modist.gif -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | line_length=100 3 | multi_line_output=4 4 | known_standard_library=numpy,setuptools 5 | known_myself=slowfast 6 | known_third_party=fvcore,av,torch,pycocotools,yacs,termcolor,scipy,simplejson,matplotlib,detectron2,torchvision,yaml,tqdm,psutil,opencv-python,pandas,tensorboard,moviepy 7 | no_lines_before=STDLIB,THIRDPARTY 8 | sections=FUTURE,STDLIB,THIRDPARTY,myself,FIRSTPARTY,LOCALFOLDER 9 | default_section=FIRSTPARTY 10 | 11 | [mypy] 12 | python_version=3.6 13 | ignore_missing_imports = True 14 | warn_unused_configs = True 15 | disallow_untyped_defs = True 16 | check_untyped_defs = True 17 | warn_unused_ignores = True 18 | warn_redundant_casts = True 19 | show_column_numbers = True 20 | follow_imports = silent 21 | allow_redefinition = True 22 | ; Require all functions to be annotated 23 | disallow_incomplete_defs = True 24 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | from setuptools import find_packages, setup 5 | 6 | setup( 7 | name="slowfast", 8 | version="1.0", 9 | author="FAIR", 10 | url="unknown", 11 | description="SlowFast Video Understanding", 12 | install_requires=[ 13 | "yacs>=0.1.6", 14 | "pyyaml>=5.1", 15 | "av", 16 | "matplotlib", 17 | "termcolor>=1.1", 18 | "simplejson", 19 | "tqdm", 20 | "psutil", 21 | "matplotlib", 22 | "detectron2", 23 | "opencv-python", 24 | "pandas", 25 | "torchvision>=0.4.2", 26 | "sklearn", 27 | ], 28 | extras_require={"tensorboard_video_visualization": ["moviepy"]}, 29 | packages=find_packages(exclude=("configs", "tests")), 30 | ) 31 | -------------------------------------------------------------------------------- /slowfast/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | from slowfast.utils.env import setup_environment 5 | 6 | setup_environment() 7 | -------------------------------------------------------------------------------- /slowfast/config/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | -------------------------------------------------------------------------------- /slowfast/config/custom_config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """Add custom configs and default values""" 5 | 6 | 7 | def add_custom_config(_C): 8 | # Add your own customized configs. 
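    # Example of a project-specific option (hypothetical key, shown only for
    # illustration and left commented out):
    #   _C.DATA.MY_CUSTOM_FLAG = False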
9 | pass 10 | -------------------------------------------------------------------------------- /slowfast/datasets/DATASET.md: -------------------------------------------------------------------------------- 1 | # Dataset Preparation 2 | 3 | ## Kinetics 4 | 5 | The Kinetics Dataset could be downloaded via the code released by ActivityNet: 6 | 7 | 1. Download the videos via the official [scripts](https://github.com/activitynet/ActivityNet/tree/master/Crawler/Kinetics). 8 | 9 | 2. After all the videos were downloaded, resize the video to the short edge size of 256, then prepare the csv files for training, validation, and testing set as `train.csv`, `val.csv`, `test.csv`. The format of the csv file is: 10 | 11 | ``` 12 | path_to_video_1 label_1 13 | path_to_video_2 label_2 14 | path_to_video_3 label_3 15 | ... 16 | path_to_video_N label_N 17 | ``` 18 | 19 | All the Kinetics models in the Model Zoo are trained and tested with the same data as [Non-local Network](https://github.com/facebookresearch/video-nonlocal-net/blob/master/DATASET.md). For dataset specific issues, please reach out to the [dataset provider](https://deepmind.com/research/open-source/kinetics). 20 | 21 | ## AVA 22 | 23 | The AVA Dataset could be downloaded from the [official site](https://research.google.com/ava/download.html#ava_actions_download) 24 | 25 | We followed the same [downloading and preprocessing procedure](https://github.com/facebookresearch/video-long-term-feature-banks/blob/master/DATASET.md) as the [Long-Term Feature Banks for Detailed Video Understanding](https://arxiv.org/abs/1812.05038) do. 26 | 27 | You could follow these steps to download and preprocess the data: 28 | 29 | 1. Download videos 30 | 31 | ``` 32 | DATA_DIR="../../data/ava/videos" 33 | 34 | if [[ ! -d "${DATA_DIR}" ]]; then 35 | echo "${DATA_DIR} doesn't exist. Creating it."; 36 | mkdir -p ${DATA_DIR} 37 | fi 38 | 39 | wget https://s3.amazonaws.com/ava-dataset/annotations/ava_file_names_trainval_v2.1.txt 40 | 41 | for line in $(cat ava_file_names_trainval_v2.1.txt) 42 | do 43 | wget https://s3.amazonaws.com/ava-dataset/trainval/$line -P ${DATA_DIR} 44 | done 45 | ``` 46 | 47 | 2. Cut each video from its 15th to 30th minute 48 | 49 | ``` 50 | IN_DATA_DIR="../../data/ava/videos" 51 | OUT_DATA_DIR="../../data/ava/videos_15min" 52 | 53 | if [[ ! -d "${OUT_DATA_DIR}" ]]; then 54 | echo "${OUT_DATA_DIR} doesn't exist. Creating it."; 55 | mkdir -p ${OUT_DATA_DIR} 56 | fi 57 | 58 | for video in $(ls -A1 -U ${IN_DATA_DIR}/*) 59 | do 60 | out_name="${OUT_DATA_DIR}/${video##*/}" 61 | if [ ! -f "${out_name}" ]; then 62 | ffmpeg -ss 900 -t 901 -i "${video}" "${out_name}" 63 | fi 64 | done 65 | ``` 66 | 67 | 3. Extract frames 68 | 69 | ``` 70 | IN_DATA_DIR="../../data/ava/videos_15min" 71 | OUT_DATA_DIR="../../data/ava/frames" 72 | 73 | if [[ ! -d "${OUT_DATA_DIR}" ]]; then 74 | echo "${OUT_DATA_DIR} doesn't exist. Creating it."; 75 | mkdir -p ${OUT_DATA_DIR} 76 | fi 77 | 78 | for video in $(ls -A1 -U ${IN_DATA_DIR}/*) 79 | do 80 | video_name=${video##*/} 81 | 82 | if [[ $video_name = *".webm" ]]; then 83 | video_name=${video_name::-5} 84 | else 85 | video_name=${video_name::-4} 86 | fi 87 | 88 | out_video_dir=${OUT_DATA_DIR}/${video_name}/ 89 | mkdir -p "${out_video_dir}" 90 | 91 | out_name="${out_video_dir}/${video_name}_%06d.jpg" 92 | 93 | ffmpeg -i "${video}" -r 30 -q:v 1 "${out_name}" 94 | done 95 | ``` 96 | 97 | 4. Download annotations 98 | 99 | ``` 100 | DATA_DIR="../../data/ava/annotations" 101 | 102 | if [[ ! 
-d "${DATA_DIR}" ]]; then 103 | echo "${DATA_DIR} doesn't exist. Creating it."; 104 | mkdir -p ${DATA_DIR} 105 | fi 106 | 107 | wget https://research.google.com/ava/download/ava_train_v2.1.csv -P ${DATA_DIR} 108 | wget https://research.google.com/ava/download/ava_val_v2.1.csv -P ${DATA_DIR} 109 | wget https://research.google.com/ava/download/ava_action_list_v2.1_for_activitynet_2018.pbtxt -P ${DATA_DIR} 110 | wget https://research.google.com/ava/download/ava_train_excluded_timestamps_v2.1.csv -P ${DATA_DIR} 111 | wget https://research.google.com/ava/download/ava_val_excluded_timestamps_v2.1.csv -P ${DATA_DIR} 112 | ``` 113 | 114 | 5. Download "frame lists" ([train](https://dl.fbaipublicfiles.com/video-long-term-feature-banks/data/ava/frame_lists/train.csv), [val](https://dl.fbaipublicfiles.com/video-long-term-feature-banks/data/ava/frame_lists/val.csv)) and put them in 115 | the `frame_lists` folder (see structure above). 116 | 117 | 6. Download person boxes ([train](https://dl.fbaipublicfiles.com/video-long-term-feature-banks/data/ava/annotations/ava_train_predicted_boxes.csv), [val](https://dl.fbaipublicfiles.com/video-long-term-feature-banks/data/ava/annotations/ava_val_predicted_boxes.csv), [test](https://dl.fbaipublicfiles.com/video-long-term-feature-banks/data/ava/annotations/ava_test_predicted_boxes.csv)) and put them in the `annotations` folder (see structure above). 118 | If you prefer to use your own person detector, please see details 119 | in [here](https://github.com/facebookresearch/video-long-term-feature-banks/blob/master/GETTING_STARTED.md#ava-person-detector). 120 | 121 | 122 | Download the ava dataset with the following structure: 123 | 124 | ``` 125 | ava 126 | |_ frames 127 | | |_ [video name 0] 128 | | | |_ [video name 0]_000001.jpg 129 | | | |_ [video name 0]_000002.jpg 130 | | | |_ ... 131 | | |_ [video name 1] 132 | | |_ [video name 1]_000001.jpg 133 | | |_ [video name 1]_000002.jpg 134 | | |_ ... 135 | |_ frame_lists 136 | | |_ train.csv 137 | | |_ val.csv 138 | |_ annotations 139 | |_ [official AVA annotation files] 140 | |_ ava_train_predicted_boxes.csv 141 | |_ ava_val_predicted_boxes.csv 142 | ``` 143 | 144 | You could also replace the `v2.1` by `v2.2` if you need the AVA v2.2 annotation. You can also download some pre-prepared annotations from [here](https://dl.fbaipublicfiles.com/pyslowfast/annotation/ava/ava_annotations.tar). 145 | 146 | 147 | ## Charades 148 | 1. Please download the Charades RGB frames from [dataset provider](http://ai2-website.s3.amazonaws.com/data/Charades_v1_rgb.tar). 149 | 150 | 2. Download the *frame list* from the following links: ([train](https://dl.fbaipublicfiles.com/pyslowfast/dataset/charades/frame_lists/train.csv), [val](https://dl.fbaipublicfiles.com/pyslowfast/dataset/charades/frame_lists/val.csv)). 151 | 152 | Please set `DATA.PATH_TO_DATA_DIR` to point to the folder containing the frame lists, and `DATA.PATH_PREFIX` to the folder containing RGB frames. 153 | 154 | 155 | ## Something-Something V2 156 | 1. Please download the dataset and annotations from [dataset provider](https://20bn.com/datasets/something-something). 157 | 158 | 2. Download the *frame list* from the following links: ([train](https://dl.fbaipublicfiles.com/pyslowfast/dataset/ssv2/frame_lists/train.csv), [val](https://dl.fbaipublicfiles.com/pyslowfast/dataset/ssv2/frame_lists/val.csv)). 159 | 160 | 3. Extract the frames at 30 FPS. (We used ffmpeg-4.1.3 with command 161 | `ffmpeg -i "${video}" -r 30 -q:v 1 "${out_name}"` 162 | in experiments.) 
Please put the frames in a structure consistent with the frame lists. 163 | 164 | 165 | Please put all annotation json files and the frame lists in the same folder, and set `DATA.PATH_TO_DATA_DIR` to the path. Set `DATA.PATH_PREFIX` to be the path to the folder containing extracted frames. 166 | -------------------------------------------------------------------------------- /slowfast/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Modified by AWS AI Labs on 07/15/2022 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 4 | 5 | from .ava_dataset import Ava # noqa 6 | from .build import DATASET_REGISTRY, build_dataset # noqa 7 | from .charades import Charades # noqa 8 | from .kinetics import Kinetics # noqa 9 | from .audioset import Audioset # noqa 10 | from .ssl_dataset import Ssl_video # noqa 11 | from .ssv2 import Ssv2 # noqa 12 | from .meprod import Meprod # noqa 13 | from .meprod_v2 import Meprod_v2 # noqa 14 | from .video_feature_dataset import Video_feature_dataset # noqa 15 | -------------------------------------------------------------------------------- /slowfast/datasets/ava_helper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | import logging 5 | import os 6 | from collections import defaultdict 7 | from fvcore.common.file_io import PathManager 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | FPS = 30 12 | AVA_VALID_FRAMES = range(902, 1799) 13 | 14 | 15 | def load_image_lists(cfg, is_train): 16 | """ 17 | Loading image paths from corresponding files. 18 | 19 | Args: 20 | cfg (CfgNode): config. 21 | is_train (bool): if it is training dataset or not. 22 | 23 | Returns: 24 | image_paths (list[list]): a list of items. Each item (also a list) 25 | corresponds to one video and contains the paths of images for 26 | this video. 27 | video_idx_to_name (list): a list which stores video names. 28 | """ 29 | list_filenames = [ 30 | os.path.join(cfg.AVA.FRAME_LIST_DIR, filename) 31 | for filename in ( 32 | cfg.AVA.TRAIN_LISTS if is_train else cfg.AVA.TEST_LISTS 33 | ) 34 | ] 35 | image_paths = defaultdict(list) 36 | video_name_to_idx = {} 37 | video_idx_to_name = [] 38 | for list_filename in list_filenames: 39 | with PathManager.open(list_filename, "r") as f: 40 | f.readline() 41 | for line in f: 42 | row = line.split() 43 | # The format of each row should follow: 44 | # original_vido_id video_id frame_id path labels. 45 | assert len(row) == 5 46 | video_name = row[0] 47 | 48 | if video_name not in video_name_to_idx: 49 | idx = len(video_name_to_idx) 50 | video_name_to_idx[video_name] = idx 51 | video_idx_to_name.append(video_name) 52 | 53 | data_key = video_name_to_idx[video_name] 54 | 55 | image_paths[data_key].append( 56 | os.path.join(cfg.AVA.FRAME_DIR, row[3]) 57 | ) 58 | 59 | image_paths = [image_paths[i] for i in range(len(image_paths))] 60 | 61 | logger.info( 62 | "Finished loading image paths from: %s" % ", ".join(list_filenames) 63 | ) 64 | 65 | return image_paths, video_idx_to_name 66 | 67 | 68 | def load_boxes_and_labels(cfg, mode): 69 | """ 70 | Loading boxes and labels from csv files. 71 | 72 | Args: 73 | cfg (CfgNode): config. 74 | mode (str): 'train', 'val', or 'test' mode. 75 | Returns: 76 | all_boxes (dict): a dict which maps from `video_name` and 77 | `frame_sec` to a list of `box`. 
Each `box` is a 78 | [`box_coord`, `box_labels`] where `box_coord` is the 79 | coordinates of box and 'box_labels` are the corresponding 80 | labels for the box. 81 | """ 82 | gt_lists = cfg.AVA.TRAIN_GT_BOX_LISTS if mode == "train" else [] 83 | pred_lists = ( 84 | cfg.AVA.TRAIN_PREDICT_BOX_LISTS 85 | if mode == "train" 86 | else cfg.AVA.TEST_PREDICT_BOX_LISTS 87 | ) 88 | ann_filenames = [ 89 | os.path.join(cfg.AVA.ANNOTATION_DIR, filename) 90 | for filename in gt_lists + pred_lists 91 | ] 92 | ann_is_gt_box = [True] * len(gt_lists) + [False] * len(pred_lists) 93 | 94 | detect_thresh = cfg.AVA.DETECTION_SCORE_THRESH 95 | # Only select frame_sec % 4 = 0 samples for validation if not 96 | # set FULL_TEST_ON_VAL. 97 | boxes_sample_rate = ( 98 | 4 if mode == "val" and not cfg.AVA.FULL_TEST_ON_VAL else 1 99 | ) 100 | all_boxes, count, unique_box_count = parse_bboxes_file( 101 | ann_filenames=ann_filenames, 102 | ann_is_gt_box=ann_is_gt_box, 103 | detect_thresh=detect_thresh, 104 | boxes_sample_rate=boxes_sample_rate, 105 | ) 106 | logger.info( 107 | "Finished loading annotations from: %s" % ", ".join(ann_filenames) 108 | ) 109 | logger.info("Detection threshold: {}".format(detect_thresh)) 110 | logger.info("Number of unique boxes: %d" % unique_box_count) 111 | logger.info("Number of annotations: %d" % count) 112 | 113 | return all_boxes 114 | 115 | 116 | def get_keyframe_data(boxes_and_labels): 117 | """ 118 | Getting keyframe indices, boxes and labels in the dataset. 119 | 120 | Args: 121 | boxes_and_labels (list[dict]): a list which maps from video_idx to a dict. 122 | Each dict `frame_sec` to a list of boxes and corresponding labels. 123 | 124 | Returns: 125 | keyframe_indices (list): a list of indices of the keyframes. 126 | keyframe_boxes_and_labels (list[list[list]]): a list of list which maps from 127 | video_idx and sec_idx to a list of boxes and corresponding labels. 128 | """ 129 | 130 | def sec_to_frame(sec): 131 | """ 132 | Convert time index (in second) to frame index. 133 | 0: 900 134 | 30: 901 135 | """ 136 | return (sec - 900) * FPS 137 | 138 | keyframe_indices = [] 139 | keyframe_boxes_and_labels = [] 140 | count = 0 141 | for video_idx in range(len(boxes_and_labels)): 142 | sec_idx = 0 143 | keyframe_boxes_and_labels.append([]) 144 | for sec in boxes_and_labels[video_idx].keys(): 145 | if sec not in AVA_VALID_FRAMES: 146 | continue 147 | 148 | if len(boxes_and_labels[video_idx][sec]) > 0: 149 | keyframe_indices.append( 150 | (video_idx, sec_idx, sec, sec_to_frame(sec)) 151 | ) 152 | keyframe_boxes_and_labels[video_idx].append( 153 | boxes_and_labels[video_idx][sec] 154 | ) 155 | sec_idx += 1 156 | count += 1 157 | logger.info("%d keyframes used." % count) 158 | 159 | return keyframe_indices, keyframe_boxes_and_labels 160 | 161 | 162 | def get_num_boxes_used(keyframe_indices, keyframe_boxes_and_labels): 163 | """ 164 | Get total number of used boxes. 165 | 166 | Args: 167 | keyframe_indices (list): a list of indices of the keyframes. 168 | keyframe_boxes_and_labels (list[list[list]]): a list of list which maps from 169 | video_idx and sec_idx to a list of boxes and corresponding labels. 170 | 171 | Returns: 172 | count (int): total number of used boxes. 
173 | """ 174 | 175 | count = 0 176 | for video_idx, sec_idx, _, _ in keyframe_indices: 177 | count += len(keyframe_boxes_and_labels[video_idx][sec_idx]) 178 | return count 179 | 180 | 181 | def parse_bboxes_file( 182 | ann_filenames, ann_is_gt_box, detect_thresh, boxes_sample_rate=1 183 | ): 184 | """ 185 | Parse AVA bounding boxes files. 186 | Args: 187 | ann_filenames (list of str(s)): a list of AVA bounding boxes annotation files. 188 | ann_is_gt_box (list of bools): a list of boolean to indicate whether the corresponding 189 | ann_file is ground-truth. `ann_is_gt_box[i]` correspond to `ann_filenames[i]`. 190 | detect_thresh (float): threshold for accepting predicted boxes, range [0, 1]. 191 | boxes_sample_rate (int): sample rate for test bounding boxes. Get 1 every `boxes_sample_rate`. 192 | """ 193 | all_boxes = {} 194 | count = 0 195 | unique_box_count = 0 196 | for filename, is_gt_box in zip(ann_filenames, ann_is_gt_box): 197 | with PathManager.open(filename, "r") as f: 198 | for line in f: 199 | row = line.strip().split(",") 200 | # When we use predicted boxes to train/eval, we need to 201 | # ignore the boxes whose scores are below the threshold. 202 | if not is_gt_box: 203 | score = float(row[7]) 204 | if score < detect_thresh: 205 | continue 206 | 207 | video_name, frame_sec = row[0], int(row[1]) 208 | if frame_sec % boxes_sample_rate != 0: 209 | continue 210 | 211 | # Box with format [x1, y1, x2, y2] with a range of [0, 1] as float. 212 | box_key = ",".join(row[2:6]) 213 | box = list(map(float, row[2:6])) 214 | label = -1 if row[6] == "" else int(row[6]) 215 | 216 | if video_name not in all_boxes: 217 | all_boxes[video_name] = {} 218 | for sec in AVA_VALID_FRAMES: 219 | all_boxes[video_name][sec] = {} 220 | 221 | if box_key not in all_boxes[video_name][frame_sec]: 222 | all_boxes[video_name][frame_sec][box_key] = [box, []] 223 | unique_box_count += 1 224 | 225 | all_boxes[video_name][frame_sec][box_key][1].append(label) 226 | if label != -1: 227 | count += 1 228 | 229 | for video_name in all_boxes.keys(): 230 | for frame_sec in all_boxes[video_name].keys(): 231 | # Save in format of a list of [box_i, box_i_labels]. 232 | all_boxes[video_name][frame_sec] = list( 233 | all_boxes[video_name][frame_sec].values() 234 | ) 235 | 236 | return all_boxes, count, unique_box_count 237 | -------------------------------------------------------------------------------- /slowfast/datasets/build.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | from fvcore.common.registry import Registry 5 | 6 | DATASET_REGISTRY = Registry("DATASET") 7 | DATASET_REGISTRY.__doc__ = """ 8 | Registry for dataset. 9 | 10 | The registered object will be called with `obj(cfg, split)`. 11 | The call should return a `torch.utils.data.Dataset` object. 12 | """ 13 | 14 | 15 | def build_dataset(dataset_name, cfg, split, mode): 16 | """ 17 | Build a dataset, defined by `dataset_name`. 18 | Args: 19 | dataset_name (str): the name of the dataset to be constructed. 20 | cfg (CfgNode): configs. Details can be found in 21 | slowfast/config/defaults.py 22 | split (str): the split of the data loader. Options include `train`, 23 | `val`, and `test`. 24 | Returns: 25 | Dataset: a constructed dataset specified by dataset_name. 
26 | """ 27 | # Capitalize the the first letter of the dataset_name since the dataset_name 28 | # in configs may be in lowercase but the name of dataset class should always 29 | # start with an uppercase letter. 30 | name = dataset_name.capitalize() 31 | return DATASET_REGISTRY.get(name)(cfg, split, mode) 32 | -------------------------------------------------------------------------------- /slowfast/datasets/charades.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | import os 5 | import random 6 | from itertools import chain as chain 7 | import torch 8 | import torch.utils.data 9 | from fvcore.common.file_io import PathManager 10 | 11 | import slowfast.utils.logging as logging 12 | 13 | from . import utils as utils 14 | from .build import DATASET_REGISTRY 15 | 16 | logger = logging.get_logger(__name__) 17 | 18 | 19 | @DATASET_REGISTRY.register() 20 | class Charades(torch.utils.data.Dataset): 21 | """ 22 | Charades video loader. Construct the Charades video loader, then sample 23 | clips from the videos. For training and validation, a single clip is randomly 24 | sampled from every video with random cropping, scaling, and flipping. For 25 | testing, multiple clips are uniformaly sampled from every video with uniform 26 | cropping. For uniform cropping, we take the left, center, and right crop if 27 | the width is larger than height, or take top, center, and bottom crop if the 28 | height is larger than the width. 29 | """ 30 | 31 | def __init__(self, cfg, mode, num_retries=10): 32 | """ 33 | Load Charades data (frame paths, labels, etc. ) to a given Dataset object. 34 | The dataset could be downloaded from Chrades official website 35 | (https://allenai.org/plato/charades/). 36 | Please see datasets/DATASET.md for more information about the data format. 37 | Args: 38 | dataset (Dataset): a Dataset object to load Charades data to. 39 | mode (string): 'train', 'val', or 'test'. 40 | Args: 41 | cfg (CfgNode): configs. 42 | mode (string): Options includes `train`, `val`, or `test` mode. 43 | For the train and val mode, the data loader will take data 44 | from the train or val set, and sample one clip per video. 45 | For the test mode, the data loader will take data from test set, 46 | and sample multiple clips per video. 47 | num_retries (int): number of retries. 48 | """ 49 | # Only support train, val, and test mode. 50 | assert mode in [ 51 | "train", 52 | "val", 53 | "test", 54 | ], "Split '{}' not supported for Charades ".format(mode) 55 | self.mode = mode 56 | self.cfg = cfg 57 | 58 | self._video_meta = {} 59 | self._num_retries = num_retries 60 | # For training or validation mode, one single clip is sampled from every 61 | # video. For testing, NUM_ENSEMBLE_VIEWS clips are sampled from every 62 | # video. For every clip, NUM_SPATIAL_CROPS is cropped spatially from 63 | # the frames. 64 | if self.mode in ["train", "val"]: 65 | self._num_clips = 1 66 | elif self.mode in ["test"]: 67 | self._num_clips = ( 68 | cfg.TEST.NUM_ENSEMBLE_VIEWS * cfg.TEST.NUM_SPATIAL_CROPS 69 | ) 70 | 71 | logger.info("Constructing Charades {}...".format(mode)) 72 | self._construct_loader() 73 | 74 | def _construct_loader(self): 75 | """ 76 | Construct the video loader. 
77 | """ 78 | path_to_file = os.path.join( 79 | self.cfg.DATA.PATH_TO_DATA_DIR, 80 | "{}.csv".format("train" if self.mode == "train" else "val"), 81 | ) 82 | assert PathManager.exists(path_to_file), "{} dir not found".format( 83 | path_to_file 84 | ) 85 | (self._path_to_videos, self._labels) = utils.load_image_lists( 86 | path_to_file, self.cfg.DATA.PATH_PREFIX, return_list=True 87 | ) 88 | 89 | if self.mode != "train": 90 | # Form video-level labels from frame level annotations. 91 | self._labels = utils.convert_to_video_level_labels(self._labels) 92 | 93 | self._path_to_videos = list( 94 | chain.from_iterable( 95 | [[x] * self._num_clips for x in self._path_to_videos] 96 | ) 97 | ) 98 | self._labels = list( 99 | chain.from_iterable([[x] * self._num_clips for x in self._labels]) 100 | ) 101 | self._spatial_temporal_idx = list( 102 | chain.from_iterable( 103 | [range(self._num_clips) for _ in range(len(self._labels))] 104 | ) 105 | ) 106 | 107 | logger.info( 108 | "Charades dataloader constructed (size: {}) from {}".format( 109 | len(self._path_to_videos), path_to_file 110 | ) 111 | ) 112 | 113 | def __getitem__(self, index): 114 | """ 115 | Given the video index, return the list of frames, label, and video 116 | index if the video frames can be fetched. 117 | Args: 118 | index (int): the video index provided by the pytorch sampler. 119 | Returns: 120 | frames (tensor): the frames of sampled from the video. The dimension 121 | is `channel` x `num frames` x `height` x `width`. 122 | label (int): the label of the current video. 123 | index (int): the index of the video. 124 | """ 125 | short_cycle_idx = None 126 | # When short cycle is used, input index is a tupple. 127 | if isinstance(index, tuple): 128 | index, short_cycle_idx = index 129 | 130 | if self.mode in ["train", "val"]: 131 | # -1 indicates random sampling. 132 | temporal_sample_index = -1 133 | spatial_sample_index = -1 134 | min_scale = self.cfg.DATA.TRAIN_JITTER_SCALES[0] 135 | max_scale = self.cfg.DATA.TRAIN_JITTER_SCALES[1] 136 | crop_size = self.cfg.DATA.TRAIN_CROP_SIZE 137 | if short_cycle_idx in [0, 1]: 138 | crop_size = int( 139 | round( 140 | self.cfg.MULTIGRID.SHORT_CYCLE_FACTORS[short_cycle_idx] 141 | * self.cfg.MULTIGRID.DEFAULT_S 142 | ) 143 | ) 144 | if self.cfg.MULTIGRID.DEFAULT_S > 0: 145 | # Decreasing the scale is equivalent to using a larger "span" 146 | # in a sampling grid. 147 | min_scale = int( 148 | round( 149 | float(min_scale) 150 | * crop_size 151 | / self.cfg.MULTIGRID.DEFAULT_S 152 | ) 153 | ) 154 | elif self.mode in ["test"]: 155 | temporal_sample_index = ( 156 | self._spatial_temporal_idx[index] 157 | // self.cfg.TEST.NUM_SPATIAL_CROPS 158 | ) 159 | # spatial_sample_index is in [0, 1, 2]. Corresponding to left, 160 | # center, or right if width is larger than height, and top, middle, 161 | # or bottom if height is larger than width. 162 | spatial_sample_index = ( 163 | self._spatial_temporal_idx[index] 164 | % self.cfg.TEST.NUM_SPATIAL_CROPS 165 | ) 166 | min_scale, max_scale, crop_size = [self.cfg.DATA.TEST_CROP_SIZE] * 3 167 | # The testing is deterministic and no jitter should be performed. 168 | # min_scale, max_scale, and crop_size are expect to be the same. 
169 | assert len({min_scale, max_scale, crop_size}) == 1 170 | else: 171 | raise NotImplementedError( 172 | "Does not support {} mode".format(self.mode) 173 | ) 174 | 175 | num_frames = self.cfg.DATA.NUM_FRAMES 176 | sampling_rate = utils.get_random_sampling_rate( 177 | self.cfg.MULTIGRID.LONG_CYCLE_SAMPLING_RATE, 178 | self.cfg.DATA.SAMPLING_RATE, 179 | ) 180 | video_length = len(self._path_to_videos[index]) 181 | assert video_length == len(self._labels[index]) 182 | 183 | clip_length = (num_frames - 1) * sampling_rate + 1 184 | if temporal_sample_index == -1: 185 | if clip_length > video_length: 186 | start = random.randint(video_length - clip_length, 0) 187 | else: 188 | start = random.randint(0, video_length - clip_length) 189 | else: 190 | gap = float(max(video_length - clip_length, 0)) / ( 191 | self.cfg.TEST.NUM_ENSEMBLE_VIEWS - 1 192 | ) 193 | start = int(round(gap * temporal_sample_index)) 194 | 195 | seq = [ 196 | max(min(start + i * sampling_rate, video_length - 1), 0) 197 | for i in range(num_frames) 198 | ] 199 | frames = torch.as_tensor( 200 | utils.retry_load_images( 201 | [self._path_to_videos[index][frame] for frame in seq], 202 | self._num_retries, 203 | ) 204 | ) 205 | 206 | label = utils.aggregate_labels( 207 | [self._labels[index][i] for i in range(seq[0], seq[-1] + 1)] 208 | ) 209 | label = torch.as_tensor( 210 | utils.as_binary_vector(label, self.cfg.MODEL.NUM_CLASSES) 211 | ) 212 | 213 | # Perform color normalization. 214 | frames = utils.tensor_normalize( 215 | frames, self.cfg.DATA.MEAN, self.cfg.DATA.STD 216 | ) 217 | # T H W C -> C T H W. 218 | frames = frames.permute(3, 0, 1, 2) 219 | # Perform data augmentation. 220 | frames = utils.spatial_sampling( 221 | frames, 222 | spatial_idx=spatial_sample_index, 223 | min_scale=min_scale, 224 | max_scale=max_scale, 225 | crop_size=crop_size, 226 | random_horizontal_flip=self.cfg.DATA.RANDOM_FLIP, 227 | inverse_uniform_sampling=self.cfg.DATA.INV_UNIFORM_SAMPLE, 228 | ) 229 | frames = utils.pack_pathway_output(self.cfg, frames) 230 | return frames, label, index, {} 231 | 232 | def __len__(self): 233 | """ 234 | Returns: 235 | (int): the number of videos in the dataset. 236 | """ 237 | return len(self._path_to_videos) 238 | -------------------------------------------------------------------------------- /slowfast/datasets/folder.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets.vision import VisionDataset 2 | 3 | from PIL import Image 4 | 5 | import os 6 | import os.path 7 | 8 | 9 | def has_file_allowed_extension(filename, extensions): 10 | """Checks if a file is an allowed extension. 11 | 12 | Args: 13 | filename (string): path to a file 14 | extensions (tuple of strings): extensions to consider (lowercase) 15 | 16 | Returns: 17 | bool: True if the filename ends with one of given extensions 18 | """ 19 | return filename.lower().endswith(extensions) 20 | 21 | 22 | def is_image_file(filename): 23 | """Checks if a file is an allowed image extension. 
24 | 25 | Args: 26 | filename (string): path to a file 27 | 28 | Returns: 29 | bool: True if the filename ends with a known image extension 30 | """ 31 | return has_file_allowed_extension(filename, IMG_EXTENSIONS) 32 | 33 | 34 | def make_dataset(directory, class_to_idx, extensions=None, is_valid_file=None): 35 | instances = [] 36 | directory = os.path.expanduser(directory) 37 | both_none = extensions is None and is_valid_file is None 38 | both_something = extensions is not None and is_valid_file is not None 39 | if both_none or both_something: 40 | raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time") 41 | if extensions is not None: 42 | def is_valid_file(x): 43 | return has_file_allowed_extension(x, extensions) 44 | for target_class in sorted(class_to_idx.keys()): 45 | class_index = class_to_idx[target_class] 46 | target_dir = os.path.join(directory, target_class) 47 | if not os.path.isdir(target_dir): 48 | continue 49 | for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)): 50 | for fname in sorted(fnames): 51 | path = os.path.join(root, fname) 52 | if is_valid_file(path): 53 | item = path, class_index 54 | instances.append(item) 55 | return instances 56 | 57 | 58 | class DatasetFolder(VisionDataset): 59 | """A generic data loader where the samples are arranged in this way: :: 60 | 61 | root/class_x/xxx.ext 62 | root/class_x/xxy.ext 63 | root/class_x/xxz.ext 64 | 65 | root/class_y/123.ext 66 | root/class_y/nsdf3.ext 67 | root/class_y/asd932_.ext 68 | 69 | Args: 70 | root (string): Root directory path. 71 | loader (callable): A function to load a sample given its path. 72 | extensions (tuple[string]): A list of allowed extensions. 73 | both extensions and is_valid_file should not be passed. 74 | transform (callable, optional): A function/transform that takes in 75 | a sample and returns a transformed version. 76 | E.g, ``transforms.RandomCrop`` for images. 77 | target_transform (callable, optional): A function/transform that takes 78 | in the target and transforms it. 79 | is_valid_file (callable, optional): A function that takes path of a file 80 | and check if the file is a valid file (used to check of corrupt files) 81 | both extensions and is_valid_file should not be passed. 82 | 83 | Attributes: 84 | classes (list): List of the class names sorted alphabetically. 85 | class_to_idx (dict): Dict with items (class_name, class_index). 
86 | samples (list): List of (sample path, class_index) tuples 87 | targets (list): The class_index value for each image in the dataset 88 | """ 89 | 90 | def __init__(self, root, loader, extensions=None, transform=None, 91 | target_transform=None, is_valid_file=None, class_to_idx=None): 92 | super(DatasetFolder, self).__init__(root, transform=transform, 93 | target_transform=target_transform) 94 | if class_to_idx is None: 95 | classes, class_to_idx = self._find_classes(self.root) 96 | else: 97 | classes = list(class_to_idx.keys()) 98 | samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file) 99 | if len(samples) == 0: 100 | msg = "Found 0 files in subfolders of: {}\n".format(self.root) 101 | if extensions is not None: 102 | msg += "Supported extensions are: {}".format(",".join(extensions)) 103 | raise RuntimeError(msg) 104 | 105 | self.loader = loader 106 | self.extensions = extensions 107 | 108 | self.classes = classes 109 | self.class_to_idx = class_to_idx 110 | self.samples = samples 111 | self.targets = [s[1] for s in samples] 112 | 113 | def _find_classes(self, dir): 114 | """ 115 | Finds the class folders in a dataset. 116 | 117 | Args: 118 | dir (string): Root directory path. 119 | 120 | Returns: 121 | tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary. 122 | 123 | Ensures: 124 | No class is a subdirectory of another. 125 | """ 126 | classes = [d.name for d in os.scandir(dir) if d.is_dir()] 127 | classes.sort() 128 | class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} 129 | return classes, class_to_idx 130 | 131 | def __getitem__(self, index): 132 | """ 133 | Args: 134 | index (int): Index 135 | 136 | Returns: 137 | tuple: (sample, target) where target is class_index of the target class. 138 | """ 139 | path, target = self.samples[index] 140 | sample = self.loader(path) 141 | if self.transform is not None: 142 | sample = self.transform(sample) 143 | if self.target_transform is not None: 144 | target = self.target_transform(target) 145 | 146 | return sample, target 147 | 148 | def __len__(self): 149 | return len(self.samples) 150 | 151 | 152 | IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp') 153 | 154 | 155 | def pil_loader(path): 156 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 157 | with open(path, 'rb') as f: 158 | img = Image.open(f) 159 | return img.convert('RGB') 160 | 161 | 162 | def accimage_loader(path): 163 | import accimage 164 | try: 165 | return accimage.Image(path) 166 | except IOError: 167 | # Potentially a decoding problem, fall back to PIL.Image 168 | return pil_loader(path) 169 | 170 | 171 | def default_loader(path): 172 | from torchvision import get_image_backend 173 | if get_image_backend() == 'accimage': 174 | return accimage_loader(path) 175 | else: 176 | return pil_loader(path) 177 | 178 | 179 | class ImageFolder(DatasetFolder): 180 | """A generic data loader where the images are arranged in this way: :: 181 | 182 | root/dog/xxx.png 183 | root/dog/xxy.png 184 | root/dog/xxz.png 185 | 186 | root/cat/123.png 187 | root/cat/nsdf3.png 188 | root/cat/asd932_.png 189 | 190 | Args: 191 | root (string): Root directory path. 192 | transform (callable, optional): A function/transform that takes in an PIL image 193 | and returns a transformed version. 
E.g, ``transforms.RandomCrop`` 194 | target_transform (callable, optional): A function/transform that takes in the 195 | target and transforms it. 196 | loader (callable, optional): A function to load an image given its path. 197 | is_valid_file (callable, optional): A function that takes path of an Image file 198 | and check if the file is a valid file (used to check of corrupt files) 199 | 200 | Attributes: 201 | classes (list): List of the class names sorted alphabetically. 202 | class_to_idx (dict): Dict with items (class_name, class_index). 203 | imgs (list): List of (image path, class_index) tuples 204 | """ 205 | 206 | def __init__(self, root, transform=None, target_transform=None, 207 | loader=default_loader, is_valid_file=None, class_to_idx=None): 208 | super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None, 209 | transform=transform, 210 | target_transform=target_transform, 211 | is_valid_file=is_valid_file, 212 | class_to_idx=class_to_idx) 213 | self.imgs = self.samples 214 | -------------------------------------------------------------------------------- /slowfast/datasets/loader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Modified by AWS AI Labs on 07/15/2022 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 4 | 5 | """Data loader.""" 6 | 7 | import itertools 8 | import numpy as np 9 | import torch 10 | from torch.utils.data._utils.collate import default_collate 11 | from torch.utils.data.distributed import DistributedSampler 12 | from torch.utils.data.sampler import RandomSampler 13 | from torch import distributions 14 | 15 | from slowfast.datasets.multigrid_helper import ShortCycleBatchSampler 16 | 17 | from .build import build_dataset 18 | 19 | # FOR MIXUP EXPERIMENT 20 | # mixup_alpha = 0.4 21 | # beta_sampler = distributions.beta.Beta(mixup_alpha, mixup_alpha) 22 | 23 | def detection_collate(batch): 24 | """ 25 | Collate function for detection task. Concatanate bboxes, labels and 26 | metadata from different samples in the first dimension instead of 27 | stacking them to have a batch-size dimension. 28 | Args: 29 | batch (tuple or list): data batch to collate. 30 | Returns: 31 | (tuple): collated detection data batch. 32 | """ 33 | inputs, labels, video_idx, extra_data = zip(*batch) 34 | inputs, video_idx = default_collate(inputs), default_collate(video_idx) 35 | labels = torch.tensor(np.concatenate(labels, axis=0)).float() 36 | 37 | collated_extra_data = {} 38 | for key in extra_data[0].keys(): 39 | data = [d[key] for d in extra_data] 40 | if key == "boxes" or key == "ori_boxes": 41 | # Append idx info to the bboxes before concatenating them. 
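            # After collation, each row has the form
            # [index of the sample within the batch, x1, y1, x2, y2].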
42 | bboxes = [ 43 | np.concatenate( 44 | [np.full((data[i].shape[0], 1), float(i)), data[i]], axis=1 45 | ) 46 | for i in range(len(data)) 47 | ] 48 | bboxes = np.concatenate(bboxes, axis=0) 49 | collated_extra_data[key] = torch.tensor(bboxes).float() 50 | elif key == "metadata": 51 | collated_extra_data[key] = torch.tensor( 52 | list(itertools.chain(*data)) 53 | ).view(-1, 2) 54 | else: 55 | collated_extra_data[key] = default_collate(data) 56 | 57 | return inputs, labels, video_idx, collated_extra_data 58 | 59 | 60 | def mixup_collate(batch): 61 | def _perturb(x, shift): 62 | bsz = x.shape[0] 63 | shift = shift % bsz 64 | assert shift != 0, 'Invalid shift' 65 | idx = torch.arange(bsz) 66 | idx = torch.cat([idx[-shift:], idx[:-shift]], dim=0) 67 | y = x[idx] 68 | return y 69 | 70 | def _proc_beta_label(beta, label_a, label_b): 71 | label = label_a.clone() 72 | label.zero_() 73 | # when label_a == 0 and label_b == 0 74 | indicator = torch.logical_and(label_a == 0, label_b == 0) 75 | label[indicator] = label_a[indicator] 76 | # when label_a == 0 and label_b > 0 77 | indicator = torch.logical_and(label_a == 0, label_b > 0) 78 | label[indicator] = label_a[indicator] 79 | beta[indicator] = 1.0 80 | # when label_a > 0 and label_b == 0 81 | indicator = torch.logical_and(label_a > 0, label_b == 0) 82 | label[indicator] = label_a[indicator] 83 | # when label_a > 0 and label_b > 0 84 | indicator = torch.logical_and(label_a > 0, label_b > 0) 85 | label[indicator] = label_a[indicator] 86 | beta[indicator] = 1.0 87 | return beta, label 88 | 89 | shift = 1 90 | batch = default_collate(batch) 91 | x, label = batch[0][0], batch[1] 92 | beta = beta_sampler.sample((x.size(0), )) 93 | beta = torch.max(beta, 1 - beta) 94 | shifted_label = _perturb(label, shift) 95 | beta, label = _proc_beta_label(beta, label, shifted_label) 96 | beta = beta.reshape(-1, 1, 1, 1, 1) 97 | x = beta * x + (1 - beta) * _perturb(x, shift) 98 | batch[0][0] = x 99 | batch[1] = label 100 | return batch 101 | 102 | 103 | def shuffle_misaligned_audio(epoch, inputs, cfg): 104 | """ 105 | Shuffle the misaligned (negative) input audio clips, 106 | such that creating positive/negative pairs that are 107 | from different videos. 108 | 109 | Args: 110 | epoch (int): the current epoch number. 111 | inputs (list of tensors): inputs to model, 112 | inputs[2] corresponds to audio inputs. 113 | cfg (CfgNode): configs. Details can be found in 114 | slowfast/config/defaults.py 115 | """ 116 | 117 | if len(inputs) > 2 and cfg.DATA.GET_MISALIGNED_AUDIO: 118 | N = inputs[2].size(0) 119 | # We only leave "hard negatives" after 120 | # cfg.DATA.MIX_NEG_EPOCH epochs 121 | SN = max(int(cfg.DATA.EASY_NEG_RATIO * N), 1) if \ 122 | epoch >= cfg.DATA.MIX_NEG_EPOCH else N 123 | with torch.no_grad(): 124 | idx = torch.arange(N) 125 | idx[:SN] = torch.arange(1, SN+1) % SN 126 | inputs[2][:, 1, ...] = inputs[2][idx, 1, ...] 127 | return inputs 128 | 129 | 130 | def construct_loader(cfg, split, mode=None, is_precise_bn=False): 131 | """ 132 | Constructs the data loader for the given dataset. 133 | Args: 134 | cfg (CfgNode): configs. Details can be found in 135 | slowfast/config/defaults.py 136 | split (str): the split of the data loader. Options include `train`, 137 | `val`, and `test`. 
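        mode (str, optional): explicitly sets the loader behavior to `train`,
            `val`, or `test`; if None, it defaults to the value of `split`.
        is_precise_bn (bool): if True, the loader is used for computing precise
            BatchNorm statistics, and short-cycle batch sampling is disabled.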
138 | """ 139 | if mode is None: 140 | # read out mode from split name 141 | mode = split 142 | 143 | assert mode in ["train", "val", "test"] 144 | if mode in ["train"]: 145 | dataset_name = cfg.TRAIN.DATASET 146 | batch_size = int(cfg.TRAIN.BATCH_SIZE / max(1, cfg.NUM_GPUS)) 147 | shuffle = True 148 | drop_last = True 149 | elif mode in ["val"]: 150 | dataset_name = cfg.TRAIN.DATASET 151 | batch_size = int(cfg.TRAIN.BATCH_SIZE / max(1, cfg.NUM_GPUS)) 152 | shuffle = False 153 | drop_last = False 154 | elif mode in ["test"]: 155 | dataset_name = cfg.TEST.DATASET 156 | batch_size = int(cfg.TEST.BATCH_SIZE / max(1, cfg.NUM_GPUS)) 157 | shuffle = False 158 | drop_last = False 159 | 160 | # Construct the dataset 161 | dataset = build_dataset(dataset_name, cfg, split, mode) 162 | 163 | if cfg.MULTIGRID.SHORT_CYCLE and mode in ["train"] and not is_precise_bn: 164 | # Create a sampler for multi-process training 165 | sampler = ( 166 | DistributedSampler(dataset) 167 | if cfg.NUM_GPUS > 1 168 | else RandomSampler(dataset) 169 | ) 170 | batch_sampler = ShortCycleBatchSampler( 171 | sampler, batch_size=batch_size, drop_last=drop_last, cfg=cfg 172 | ) 173 | # Create a loader 174 | loader = torch.utils.data.DataLoader( 175 | dataset, 176 | batch_sampler=batch_sampler, 177 | num_workers=cfg.DATA_LOADER.NUM_WORKERS, 178 | pin_memory=cfg.DATA_LOADER.PIN_MEMORY, 179 | ) 180 | else: 181 | # Create a sampler for multi-process training 182 | sampler = DistributedSampler(dataset) if cfg.NUM_GPUS > 1 else None 183 | # Create a loader 184 | loader = torch.utils.data.DataLoader( 185 | dataset, 186 | batch_size=batch_size, 187 | shuffle=(False if sampler else shuffle), 188 | sampler=sampler, 189 | num_workers=cfg.DATA_LOADER.NUM_WORKERS, 190 | pin_memory=cfg.DATA_LOADER.PIN_MEMORY, 191 | drop_last=drop_last, 192 | collate_fn=detection_collate if cfg.DETECTION.ENABLE else None, 193 | # FOR MIXUP EXPERIMENT 194 | # collate_fn=mixup_collate if mode == 'train' else None, 195 | ) 196 | return loader 197 | 198 | 199 | def shuffle_dataset(loader, cur_epoch): 200 | """" 201 | Shuffles the data. 202 | Args: 203 | loader (loader): data loader to perform shuffle. 204 | cur_epoch (int): number of the current epoch. 205 | """ 206 | sampler = ( 207 | loader.batch_sampler.sampler 208 | if isinstance(loader.batch_sampler, ShortCycleBatchSampler) 209 | else loader.sampler 210 | ) 211 | assert isinstance( 212 | sampler, (RandomSampler, DistributedSampler) 213 | ), "Sampler type '{}' not supported".format(type(sampler)) 214 | # RandomSampler handles shuffling automatically 215 | if isinstance(sampler, DistributedSampler): 216 | # DistributedSampler shuffles data based on epoch 217 | sampler.set_epoch(cur_epoch) 218 | -------------------------------------------------------------------------------- /slowfast/datasets/multigrid_helper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """Helper functions for multigrid training.""" 5 | 6 | import numpy as np 7 | from torch.utils.data.sampler import Sampler 8 | # from torch._six import int_classes as _int_classes 9 | int_classes = int 10 | 11 | 12 | class ShortCycleBatchSampler(Sampler): 13 | """ 14 | Extend Sampler to support "short cycle" sampling. 15 | See paper "A Multigrid Method for Efficiently Training Video Models", 16 | Wu et al., 2019 (https://arxiv.org/abs/1912.00998) for details. 
17 | """ 18 | 19 | def __init__(self, sampler, batch_size, drop_last, cfg): 20 | if not isinstance(sampler, Sampler): 21 | raise ValueError( 22 | "sampler should be an instance of " 23 | "torch.utils.data.Sampler, but got sampler={}".format(sampler) 24 | ) 25 | if ( 26 | not isinstance(batch_size, _int_classes) 27 | or isinstance(batch_size, bool) 28 | or batch_size <= 0 29 | ): 30 | raise ValueError( 31 | "batch_size should be a positive integer value, " 32 | "but got batch_size={}".format(batch_size) 33 | ) 34 | if not isinstance(drop_last, bool): 35 | raise ValueError( 36 | "drop_last should be a boolean value, but got " 37 | "drop_last={}".format(drop_last) 38 | ) 39 | self.sampler = sampler 40 | self.drop_last = drop_last 41 | 42 | bs_factor = [ 43 | int( 44 | round( 45 | ( 46 | float(cfg.DATA.TRAIN_CROP_SIZE) 47 | / (s * cfg.MULTIGRID.DEFAULT_S) 48 | ) 49 | ** 2 50 | ) 51 | ) 52 | for s in cfg.MULTIGRID.SHORT_CYCLE_FACTORS 53 | ] 54 | 55 | self.batch_sizes = [ 56 | batch_size * bs_factor[0], 57 | batch_size * bs_factor[1], 58 | batch_size, 59 | ] 60 | 61 | def __iter__(self): 62 | counter = 0 63 | batch_size = self.batch_sizes[0] 64 | batch = [] 65 | for idx in self.sampler: 66 | batch.append((idx, counter % 3)) 67 | if len(batch) == batch_size: 68 | yield batch 69 | counter += 1 70 | batch_size = self.batch_sizes[counter % 3] 71 | batch = [] 72 | if len(batch) > 0 and not self.drop_last: 73 | yield batch 74 | 75 | def __len__(self): 76 | avg_batch_size = sum(self.batch_sizes) / 3.0 77 | if self.drop_last: 78 | return int(np.floor(len(self.sampler) / avg_batch_size)) 79 | else: 80 | return int(np.ceil(len(self.sampler) / avg_batch_size)) 81 | -------------------------------------------------------------------------------- /slowfast/datasets/video_container.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | import av 5 | 6 | 7 | def get_video_container(path_to_vid, multi_thread_decode=False, backend="pyav"): 8 | """ 9 | Given the path to the video, return the pyav video container. 10 | Args: 11 | path_to_vid (str): path to the video. 12 | multi_thread_decode (bool): if True, perform multi-thread decoding. 13 | backend (str): decoder backend, options include `pyav` and 14 | `torchvision`, default is `pyav`. 15 | Returns: 16 | container (container): video container. 17 | """ 18 | if backend == "torchvision": 19 | with open(path_to_vid, "rb") as fp: 20 | container = fp.read() 21 | return container 22 | elif backend == "pyav": 23 | container = av.open(path_to_vid) 24 | if multi_thread_decode: 25 | # Enable multiple threads for decoding. 
26 | container.streams.video[0].thread_type = "AUTO" 27 | return container 28 | else: 29 | raise NotImplementedError("Unknown backend {}".format(backend)) 30 | -------------------------------------------------------------------------------- /slowfast/datasets/video_feature_dataset.py: -------------------------------------------------------------------------------- 1 | # Modified by AWS AI Labs on 07/15/2022 2 | import random 3 | import numpy as np 4 | import os.path as osp 5 | 6 | import torch 7 | from torch.utils.data import Dataset 8 | 9 | from .build import DATASET_REGISTRY 10 | import slowfast.utils.logging as logging 11 | from fvcore.common.file_io import PathManager 12 | 13 | logger = logging.get_logger(__name__) 14 | 15 | 16 | @DATASET_REGISTRY.register() 17 | class Video_feature_dataset(Dataset): 18 | def __init__(self, cfg, split, mode): 19 | assert mode in [ 20 | "train", 21 | "val", 22 | "test", 23 | ], "Split '{}' not supported for Video_feature_dataset".format(mode) 24 | self.split = split 25 | self.mode = mode 26 | self.cfg = cfg 27 | 28 | # Extra arguments. 29 | self._num_retries = 10 30 | self.ssl_temporal_shift_max = float('inf') 31 | self.ssl_temporal_shift_min = self.cfg.TRANSFORMER.INPUT_LENGTH * self.cfg.TRANSFORMER.SHOT_STRIDE * 10 32 | # self.ssl_temporal_shift_max = 3 33 | # self.ssl_temporal_shift_min = 1 34 | 35 | self._construct_loader() 36 | print('Video_feature_dataset constructed.') 37 | 38 | 39 | def _construct_loader(self): 40 | path_to_file = osp.join(self.cfg.DATA.PATH_TO_DATA_DIR, "{}.csv".format(self.split)) 41 | path_prefix = self.cfg.DATA.PATH_PREFIX 42 | path_label_separator = self.cfg.DATA.PATH_LABEL_SEPARATOR 43 | 44 | assert PathManager.exists(path_to_file), "{} dir not found".format(path_to_file) 45 | 46 | self._video_feature_list = [] 47 | with PathManager.open(path_to_file, "r") as f: 48 | for video_idx, path_label in enumerate(f.read().splitlines()): 49 | path = path_label.split(path_label_separator)[0] 50 | # get the path composing mode. 51 | if video_idx == 0: 52 | if PathManager.exists('{}.npy'.format(path)): 53 | path_mode = 'no_prefix' 54 | elif PathManager.exists(osp.join(path_prefix, '{}.npy'.format(osp.basename(path)))): 55 | path_mode = 'standard' 56 | elif PathManager.exists(osp.join(path_prefix, '{}_feats.npy'.format(osp.splitext(osp.basename(path))[0]))): 57 | path_mode = 'remove_postfix_feats' 58 | elif PathManager.exists(osp.join(path_prefix, '{}.npy'.format(path))): 59 | path_mode = 'standard_path' 60 | else: 61 | raise RuntimeError('Cannot find the feature file.') 62 | # compose the feature path. 
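# For reference, the four modes detected above compose the feature path as
# follows (illustrative layouts only, no additional options are introduced):
#   'no_prefix'            -> "<csv_path>.npy"
#   'standard'             -> "<PATH_PREFIX>/<basename(csv_path)>.npy"
#   'remove_postfix_feats' -> "<PATH_PREFIX>/<stem(csv_path)>_feats.npy"
#   'standard_path'        -> "<PATH_PREFIX>/<csv_path>.npy"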
63 | if path_mode == 'no_prefix': 64 | video_feature_path = '{}.npy'.format(path) 65 | elif path_mode == 'remove_postfix_feats': 66 | video_feature_path = osp.join(path_prefix, '{}_feats.npy'.format(osp.splitext(osp.basename(path))[0])) 67 | elif path_mode == 'standard_path': 68 | video_feature_path = osp.join(path_prefix, '{}.npy'.format(path)) 69 | elif path_mode == 'standard': 70 | video_feature_path = osp.join(path_prefix, '{}.npy'.format(osp.basename(path))) 71 | self._video_feature_list.append(video_feature_path) 72 | 73 | logger.info("Constructing Video_feature_dataset dataloader (size: {}) from {}".format(len(self._video_feature_list), path_to_file)) 74 | 75 | 76 | def __len__(self): 77 | if self.mode in ['train']: 78 | sample_num = len(self._video_feature_list) * self.cfg.DATA.EXPAND_DATASET 79 | elif self.mode in ['val', 'test']: 80 | sample_num = len(self._video_feature_list) 81 | else: 82 | raise RuntimeError('Unknown mode.') 83 | return sample_num 84 | 85 | 86 | def __getitem__(self, index): 87 | num_shots = (self.cfg.TRANSFORMER.INPUT_LENGTH - 1) * self.cfg.TRANSFORMER.SHOT_STRIDE + 1 88 | for _ in range(self._num_retries): 89 | if self.cfg.DEBUG: index = 0 90 | video_index = index % len(self._video_feature_list) 91 | video_feature_path = self._video_feature_list[video_index] 92 | video_feat = np.load(video_feature_path) 93 | total_shots = video_feat.shape[0] 94 | 95 | if total_shots < num_shots: 96 | tmp_out = video_feat[::self.cfg.TRANSFORMER.SHOT_STRIDE, :] 97 | out = np.concatenate([tmp_out, np.zeros([self.cfg.TRANSFORMER.INPUT_LENGTH - tmp_out.shape[0], tmp_out.shape[1]])]) 98 | neg = np.concatenate([tmp_out, np.zeros([self.cfg.TRANSFORMER.INPUT_LENGTH - tmp_out.shape[0], tmp_out.shape[1]])]) 99 | else: 100 | # sample self.cfg.TRANSFORMER.INPUT_LENGTH steps 101 | st = random.randint(0, total_shots - num_shots) 102 | out = video_feat[st: (st + num_shots): self.cfg.TRANSFORMER.SHOT_STRIDE, :] 103 | # sample negative shots 104 | pre_start = max(st - self.ssl_temporal_shift_max, 0) 105 | pre_end = st - self.ssl_temporal_shift_min 106 | post_start = st + self.ssl_temporal_shift_min 107 | post_end = min(st + self.ssl_temporal_shift_max, total_shots - num_shots) 108 | pre_win = max(pre_end - pre_start + 1, 0) 109 | post_win = max(post_end - post_start + 1, 0) 110 | sampling_win = pre_win + post_win 111 | if sampling_win <= 0: 112 | # print('Unable to find a sampling window in %s. Anchor start: %d.' % (video_feature_path, st)) 113 | neg_st = random.randint(0, total_shots - num_shots) 114 | else: 115 | randval = random.randrange(sampling_win) 116 | if randval >= pre_win: 117 | neg_st = randval - pre_win + post_start 118 | else: 119 | neg_st = pre_start + randval 120 | neg = video_feat[neg_st: (neg_st + num_shots): self.cfg.TRANSFORMER.SHOT_STRIDE, :] 121 | out = torch.FloatTensor(out) 122 | neg = torch.FloatTensor(neg) 123 | video_index = torch.FloatTensor([video_index]) 124 | return out, neg, video_index 125 | else: 126 | raise RuntimeError("Failed to fetch video feature after {} retries.".format(self._num_retries)) 127 | -------------------------------------------------------------------------------- /slowfast/models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
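# Note on Video_feature_dataset.__getitem__ in the file above (a summary, not
# new behaviour): the positive clip `out` is a strided window of num_shots
# features starting at a random offset st, while the negative clip `neg`
# starts at least ssl_temporal_shift_min shots away from st (drawn from the
# windows before or after the anchor); if no such window exists, a uniformly
# random start is used as a fallback.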
3 | 4 | from .build import MODEL_REGISTRY, build_model, wrap_ssl_model, wrap_distributed_model # noqa 5 | # from .custom_video_model_builder import * # noqa 6 | from .video_model_builder import ResNet, SlowFast # noqa 7 | -------------------------------------------------------------------------------- /slowfast/models/batchnorm_helper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """BatchNorm (BN) utility functions and custom batch-size BN implementations""" 5 | 6 | from functools import partial 7 | import torch 8 | import torch.distributed as dist 9 | import torch.nn as nn 10 | from torch.autograd.function import Function 11 | 12 | import slowfast.utils.distributed as du 13 | 14 | 15 | def get_norm(cfg): 16 | """ 17 | Args: 18 | cfg (CfgNode): model building configs, details are in the comments of 19 | the config file. 20 | Returns: 21 | nn.Module: the normalization layer. 22 | """ 23 | if cfg.BN.NORM_TYPE == "batchnorm": 24 | return nn.BatchNorm3d 25 | elif cfg.BN.NORM_TYPE == "sub_batchnorm": 26 | return partial(SubBatchNorm3d, num_splits=cfg.BN.NUM_SPLITS) 27 | elif cfg.BN.NORM_TYPE == "sync_batchnorm": 28 | return partial( 29 | NaiveSyncBatchNorm3d, num_sync_devices=cfg.BN.NUM_SYNC_DEVICES 30 | ) 31 | else: 32 | raise NotImplementedError( 33 | "Norm type {} is not supported".format(cfg.BN.NORM_TYPE) 34 | ) 35 | 36 | 37 | class SubBatchNorm3d(nn.Module): 38 | """ 39 | The standard BN layer computes stats across all examples in a GPU. In some 40 | cases it is desirable to compute stats across only a subset of examples 41 | (e.g., in multigrid training https://arxiv.org/abs/1912.00998). 42 | SubBatchNorm3d splits the batch dimension into N splits, and run BN on 43 | each of them separately (so that the stats are computed on each subset of 44 | examples (1/N of batch) independently. During evaluation, it aggregates 45 | the stats from all splits into one BN. 46 | """ 47 | 48 | def __init__(self, num_splits, **args): 49 | """ 50 | Args: 51 | num_splits (int): number of splits. 52 | args (list): other arguments. 53 | """ 54 | super(SubBatchNorm3d, self).__init__() 55 | self.num_splits = num_splits 56 | num_features = args["num_features"] 57 | # Keep only one set of weight and bias. 58 | if args.get("affine", True): 59 | self.affine = True 60 | args["affine"] = False 61 | self.weight = torch.nn.Parameter(torch.ones(num_features)) 62 | self.bias = torch.nn.Parameter(torch.zeros(num_features)) 63 | else: 64 | self.affine = False 65 | self.bn = nn.BatchNorm3d(**args) 66 | args["num_features"] = num_features * num_splits 67 | self.split_bn = nn.BatchNorm3d(**args) 68 | 69 | def _get_aggregated_mean_std(self, means, stds, n): 70 | """ 71 | Calculate the aggregated mean and stds. 72 | Args: 73 | means (tensor): mean values. 74 | stds (tensor): standard deviations. 75 | n (int): number of sets of means and stds. 76 | """ 77 | mean = means.view(n, -1).sum(0) / n 78 | std = ( 79 | stds.view(n, -1).sum(0) / n 80 | + ((means.view(n, -1) - mean) ** 2).view(n, -1).sum(0) / n 81 | ) 82 | return mean.detach(), std.detach() 83 | 84 | def aggregate_stats(self): 85 | """ 86 | Synchronize running_mean, and running_var. Call this before eval. 
87 | """ 88 | if self.split_bn.track_running_stats: 89 | ( 90 | self.bn.running_mean.data, 91 | self.bn.running_var.data, 92 | ) = self._get_aggregated_mean_std( 93 | self.split_bn.running_mean, 94 | self.split_bn.running_var, 95 | self.num_splits, 96 | ) 97 | 98 | def forward(self, x): 99 | if self.training: 100 | n, c, t, h, w = x.shape 101 | x = x.view(n // self.num_splits, c * self.num_splits, t, h, w) 102 | x = self.split_bn(x) 103 | x = x.view(n, c, t, h, w) 104 | else: 105 | x = self.bn(x) 106 | if self.affine: 107 | x = x * self.weight.view((-1, 1, 1, 1)) 108 | x = x + self.bias.view((-1, 1, 1, 1)) 109 | return x 110 | 111 | 112 | class GroupGather(Function): 113 | """ 114 | GroupGather performs all gather on each of the local process/ GPU groups. 115 | """ 116 | 117 | @staticmethod 118 | def forward(ctx, input, num_sync_devices, num_groups): 119 | """ 120 | Perform forwarding, gathering the stats across different process/ GPU 121 | group. 122 | """ 123 | ctx.num_sync_devices = num_sync_devices 124 | ctx.num_groups = num_groups 125 | 126 | input_list = [ 127 | torch.zeros_like(input) for k in range(du.get_local_size()) 128 | ] 129 | dist.all_gather( 130 | input_list, input, async_op=False, group=du._LOCAL_PROCESS_GROUP 131 | ) 132 | 133 | inputs = torch.stack(input_list, dim=0) 134 | if num_groups > 1: 135 | rank = du.get_local_rank() 136 | group_idx = rank // num_sync_devices 137 | inputs = inputs[ 138 | group_idx 139 | * num_sync_devices : (group_idx + 1) 140 | * num_sync_devices 141 | ] 142 | inputs = torch.sum(inputs, dim=0) 143 | return inputs 144 | 145 | @staticmethod 146 | def backward(ctx, grad_output): 147 | """ 148 | Perform backwarding, gathering the gradients across different process/ GPU 149 | group. 150 | """ 151 | grad_output_list = [ 152 | torch.zeros_like(grad_output) for k in range(du.get_local_size()) 153 | ] 154 | dist.all_gather( 155 | grad_output_list, 156 | grad_output, 157 | async_op=False, 158 | group=du._LOCAL_PROCESS_GROUP, 159 | ) 160 | 161 | grads = torch.stack(grad_output_list, dim=0) 162 | if ctx.num_groups > 1: 163 | rank = du.get_local_rank() 164 | group_idx = rank // ctx.num_sync_devices 165 | grads = grads[ 166 | group_idx 167 | * ctx.num_sync_devices : (group_idx + 1) 168 | * ctx.num_sync_devices 169 | ] 170 | grads = torch.sum(grads, dim=0) 171 | return grads, None, None 172 | 173 | 174 | class NaiveSyncBatchNorm3d(nn.BatchNorm3d): 175 | def __init__(self, num_sync_devices, **args): 176 | """ 177 | Naive version of Synchronized 3D BatchNorm. 178 | Args: 179 | num_sync_devices (int): number of device to sync. 180 | args (list): other arguments. 
181 | """ 182 | self.num_sync_devices = num_sync_devices 183 | if self.num_sync_devices > 0: 184 | assert du.get_local_size() % self.num_sync_devices == 0, ( 185 | du.get_local_size(), 186 | self.num_sync_devices, 187 | ) 188 | self.num_groups = du.get_local_size() // self.num_sync_devices 189 | else: 190 | self.num_sync_devices = du.get_local_size() 191 | self.num_groups = 1 192 | super(NaiveSyncBatchNorm3d, self).__init__(**args) 193 | 194 | def forward(self, input): 195 | if du.get_local_size() == 1 or not self.training: 196 | return super().forward(input) 197 | 198 | assert input.shape[0] > 0, "SyncBatchNorm does not support empty inputs" 199 | C = input.shape[1] 200 | mean = torch.mean(input, dim=[0, 2, 3, 4]) 201 | meansqr = torch.mean(input * input, dim=[0, 2, 3, 4]) 202 | 203 | vec = torch.cat([mean, meansqr], dim=0) 204 | vec = GroupGather.apply(vec, self.num_sync_devices, self.num_groups) * ( 205 | 1.0 / self.num_sync_devices 206 | ) 207 | 208 | mean, meansqr = torch.split(vec, C) 209 | var = meansqr - mean * mean 210 | self.running_mean += self.momentum * (mean.detach() - self.running_mean) 211 | self.running_var += self.momentum * (var.detach() - self.running_var) 212 | 213 | invstd = torch.rsqrt(var + self.eps) 214 | scale = self.weight * invstd 215 | bias = self.bias - mean * scale 216 | scale = scale.reshape(1, -1, 1, 1, 1) 217 | bias = bias.reshape(1, -1, 1, 1, 1) 218 | return input * scale + bias 219 | -------------------------------------------------------------------------------- /slowfast/models/build.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """Model construction functions.""" 5 | 6 | import torch 7 | from fvcore.common.registry import Registry 8 | 9 | try: 10 | from apex import amp 11 | except ImportError: 12 | print("Please install apex from https://www.github.com/nvidia/apex to use mixed-precision training.") 13 | amp = None 14 | 15 | MODEL_REGISTRY = Registry("MODEL") 16 | MODEL_REGISTRY.__doc__ = """ 17 | Registry for video model. 18 | 19 | The registered object will be called with `obj(cfg)`. 20 | The call should return a `torch.nn.Module` object. 21 | """ 22 | 23 | 24 | def build_model(cfg, gpu_id=None, wrap_model=True): 25 | """ 26 | Builds the video model. 27 | Args: 28 | cfg (configs): configs that contains the hyper-parameters to build the 29 | backbone. Details can be seen in slowfast/config/defaults.py. 30 | gpu_id (Optional[int]): specify the gpu index to build model. 31 | """ 32 | if torch.cuda.is_available(): 33 | assert ( 34 | cfg.NUM_GPUS <= torch.cuda.device_count() 35 | ), "Cannot use more GPU devices than available" 36 | else: 37 | assert ( 38 | cfg.NUM_GPUS == 0 39 | ), "Cuda is not available. Please set `NUM_GPUS: 0 for running on CPUs." 
40 | 41 | # Construct the model 42 | name = cfg.MODEL.MODEL_NAME 43 | model = MODEL_REGISTRY.get(name)(cfg) 44 | 45 | # Use multi-process data parallel model in the multi-gpu setting 46 | if wrap_model: 47 | wrap_distributed_model(cfg, model, gpu_id) 48 | 49 | return model 50 | 51 | 52 | def wrap_ssl_model(cfg, model, wrapper): 53 | model = MODEL_REGISTRY.get(wrapper)(cfg, model) 54 | return model 55 | 56 | 57 | def wrap_distributed_model(cfg, model, optimizer, gpu_id=None): 58 | if cfg.NUM_GPUS: 59 | if gpu_id is None: 60 | # Determine the GPU used by the current process 61 | cur_device = torch.cuda.current_device() 62 | else: 63 | cur_device = gpu_id 64 | # Transfer the model to the current GPU device 65 | model = model.cuda(device=cur_device) 66 | 67 | if cfg.TRAIN.MIX_PRECISION_LEVEL != 'O0' and amp is not None: 68 | model, optimizer = amp.initialize(model, optimizer, opt_level=cfg.TRAIN.MIX_PRECISION_LEVEL) 69 | 70 | if cfg.NUM_GPUS > 1: 71 | # Make model replica operate on the current device 72 | model = torch.nn.parallel.DistributedDataParallel( 73 | module=model, device_ids=[cur_device], output_device=cur_device, 74 | # find_unused_parameters=True, 75 | ) 76 | return model, optimizer 77 | -------------------------------------------------------------------------------- /slowfast/models/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class Mlp(nn.Module): 8 | def __init__( 9 | self, 10 | in_features, 11 | hidden_features=None, 12 | out_features=None, 13 | act_layer=nn.GELU, 14 | drop_rate=0.0, 15 | ): 16 | super().__init__() 17 | self.drop_rate = drop_rate 18 | out_features = out_features or in_features 19 | hidden_features = hidden_features or in_features 20 | self.fc1 = nn.Linear(in_features, hidden_features) 21 | self.act = act_layer() 22 | self.fc2 = nn.Linear(hidden_features, out_features) 23 | if self.drop_rate > 0.0: 24 | self.drop = nn.Dropout(drop_rate) 25 | 26 | def forward(self, x): 27 | x = self.fc1(x) 28 | x = self.act(x) 29 | if self.drop_rate > 0.0: 30 | x = self.drop(x) 31 | x = self.fc2(x) 32 | if self.drop_rate > 0.0: 33 | x = self.drop(x) 34 | return x 35 | 36 | 37 | class Permute(nn.Module): 38 | def __init__(self, dims): 39 | super().__init__() 40 | self.dims = dims 41 | 42 | def forward(self, x): 43 | return x.permute(*self.dims) 44 | 45 | 46 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 47 | """ 48 | Stochastic Depth per sample. 49 | """ 50 | if drop_prob == 0.0 or not training: 51 | return x 52 | keep_prob = 1 - drop_prob 53 | shape = (x.shape[0],) + (1,) * ( 54 | x.ndim - 1 55 | ) # work with diff dim tensors, not just 2D ConvNets 56 | mask = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 57 | mask.floor_() # binarize 58 | output = x.div(keep_prob) * mask 59 | return output 60 | 61 | 62 | class DropPath(nn.Module): 63 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 64 | 65 | def __init__(self, drop_prob=None): 66 | super(DropPath, self).__init__() 67 | self.drop_prob = drop_prob 68 | 69 | def forward(self, x): 70 | return drop_path(x, self.drop_prob, self.training) -------------------------------------------------------------------------------- /slowfast/models/losses.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. 
and its affiliates. All Rights Reserved. 3 | 4 | """Loss functions.""" 5 | 6 | import torch.nn as nn 7 | 8 | _LOSSES = { 9 | "cross_entropy": nn.CrossEntropyLoss, 10 | "bce": nn.BCELoss, 11 | "bce_logit": nn.BCEWithLogitsLoss, 12 | } 13 | 14 | 15 | def get_loss_func(loss_name): 16 | """ 17 | Retrieve the loss given the loss name. 18 | Args (int): 19 | loss_name: the name of the loss to use. 20 | """ 21 | if loss_name not in _LOSSES.keys(): 22 | raise NotImplementedError("Loss {} is not supported".format(loss_name)) 23 | return _LOSSES[loss_name] 24 | -------------------------------------------------------------------------------- /slowfast/models/nonlocal_helper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """Non-local helper""" 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class Nonlocal(nn.Module): 11 | """ 12 | Builds Non-local Neural Networks as a generic family of building 13 | blocks for capturing long-range dependencies. Non-local Network 14 | computes the response at a position as a weighted sum of the 15 | features at all positions. This building block can be plugged into 16 | many computer vision architectures. 17 | More details in the paper: https://arxiv.org/pdf/1711.07971.pdf 18 | """ 19 | 20 | def __init__( 21 | self, 22 | dim, 23 | dim_inner, 24 | pool_size=None, 25 | instantiation="softmax", 26 | zero_init_final_conv=False, 27 | zero_init_final_norm=True, 28 | norm_eps=1e-5, 29 | norm_momentum=0.1, 30 | norm_module=nn.BatchNorm3d, 31 | ): 32 | """ 33 | Args: 34 | dim (int): number of dimension for the input. 35 | dim_inner (int): number of dimension inside of the Non-local block. 36 | pool_size (list): the kernel size of spatial temporal pooling, 37 | temporal pool kernel size, spatial pool kernel size, spatial 38 | pool kernel size in order. By default pool_size is None, 39 | then there would be no pooling used. 40 | instantiation (string): supports two different instantiation method: 41 | "dot_product": normalizing correlation matrix with L2. 42 | "softmax": normalizing correlation matrix with Softmax. 43 | zero_init_final_conv (bool): If true, zero initializing the final 44 | convolution of the Non-local block. 45 | zero_init_final_norm (bool): 46 | If true, zero initializing the final batch norm of the Non-local 47 | block. 48 | norm_module (nn.Module): nn.Module for the normalization layer. The 49 | default is nn.BatchNorm3d. 50 | """ 51 | super(Nonlocal, self).__init__() 52 | self.dim = dim 53 | self.dim_inner = dim_inner 54 | self.pool_size = pool_size 55 | self.instantiation = instantiation 56 | self.use_pool = ( 57 | False 58 | if pool_size is None 59 | else any((size > 1 for size in pool_size)) 60 | ) 61 | self.norm_eps = norm_eps 62 | self.norm_momentum = norm_momentum 63 | self._construct_nonlocal( 64 | zero_init_final_conv, zero_init_final_norm, norm_module 65 | ) 66 | 67 | def _construct_nonlocal( 68 | self, zero_init_final_conv, zero_init_final_norm, norm_module 69 | ): 70 | # Three convolution heads: theta, phi, and g. 71 | self.conv_theta = nn.Conv3d( 72 | self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0 73 | ) 74 | self.conv_phi = nn.Conv3d( 75 | self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0 76 | ) 77 | self.conv_g = nn.Conv3d( 78 | self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0 79 | ) 80 | 81 | # Final convolution output. 
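# conv_out projects the aggregated features from dim_inner back to dim so they
# can be added to the identity path in forward(); together with the zero-init
# options handled below, the block starts out close to an identity mapping.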
82 | self.conv_out = nn.Conv3d( 83 | self.dim_inner, self.dim, kernel_size=1, stride=1, padding=0 84 | ) 85 | # Zero initializing the final convolution output. 86 | self.conv_out.zero_init = zero_init_final_conv 87 | 88 | # TODO: change the name to `norm` 89 | self.bn = norm_module( 90 | num_features=self.dim, 91 | eps=self.norm_eps, 92 | momentum=self.norm_momentum, 93 | ) 94 | # Zero initializing the final bn. 95 | self.bn.transform_final_bn = zero_init_final_norm 96 | 97 | # Optional to add the spatial-temporal pooling. 98 | if self.use_pool: 99 | self.pool = nn.MaxPool3d( 100 | kernel_size=self.pool_size, 101 | stride=self.pool_size, 102 | padding=[0, 0, 0], 103 | ) 104 | 105 | def forward(self, x): 106 | x_identity = x 107 | N, C, T, H, W = x.size() 108 | 109 | theta = self.conv_theta(x) 110 | 111 | # Perform temporal-spatial pooling to reduce the computation. 112 | if self.use_pool: 113 | x = self.pool(x) 114 | 115 | phi = self.conv_phi(x) 116 | g = self.conv_g(x) 117 | 118 | theta = theta.view(N, self.dim_inner, -1) 119 | phi = phi.view(N, self.dim_inner, -1) 120 | g = g.view(N, self.dim_inner, -1) 121 | 122 | # (N, C, TxHxW) * (N, C, TxHxW) => (N, TxHxW, TxHxW). 123 | theta_phi = torch.einsum("nct,ncp->ntp", (theta, phi)) 124 | # For original Non-local paper, there are two main ways to normalize 125 | # the affinity tensor: 126 | # 1) Softmax normalization (norm on exp). 127 | # 2) dot_product normalization. 128 | if self.instantiation == "softmax": 129 | # Normalizing the affinity tensor theta_phi before softmax. 130 | theta_phi = theta_phi * (self.dim_inner ** -0.5) 131 | theta_phi = nn.functional.softmax(theta_phi, dim=2) 132 | elif self.instantiation == "dot_product": 133 | spatial_temporal_dim = theta_phi.shape[2] 134 | theta_phi = theta_phi / spatial_temporal_dim 135 | else: 136 | raise NotImplementedError( 137 | "Unknown norm type {}".format(self.instantiation) 138 | ) 139 | 140 | # (N, TxHxW, TxHxW) * (N, C, TxHxW) => (N, C, TxHxW). 141 | theta_phi_g = torch.einsum("ntg,ncg->nct", (theta_phi, g)) 142 | 143 | # (N, C, TxHxW) => (N, C, T, H, W). 144 | theta_phi_g = theta_phi_g.view(N, self.dim_inner, T, H, W) 145 | 146 | p = self.conv_out(theta_phi_g) 147 | p = self.bn(p) 148 | return x_identity + p 149 | -------------------------------------------------------------------------------- /slowfast/models/optimizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """Optimizer.""" 5 | 6 | import torch 7 | 8 | import slowfast.utils.lr_policy as lr_policy 9 | 10 | 11 | def construct_optimizer(model, cfg, opt_params=None): 12 | """ 13 | Construct a stochastic gradient descent or ADAM optimizer with momentum. 14 | Details can be found in: 15 | Herbert Robbins, and Sutton Monro. "A stochastic approximation method." 16 | and 17 | Diederik P.Kingma, and Jimmy Ba. 18 | "Adam: A Method for Stochastic Optimization." 19 | 20 | Args: 21 | model (model): model to perform stochastic gradient descent 22 | optimization or ADAM optimization. 23 | cfg (config): configs of hyper-parameters of SGD or ADAM, includes base 24 | learning rate, momentum, weight_decay, dampening, and etc. 25 | """ 26 | # Batchnorm parameters. 27 | bn_params = [] 28 | # Non-batchnorm parameters. 29 | non_bn_parameters = [] 30 | 31 | # The set of params to optimize. 32 | if opt_params is None: 33 | opt_params = list(model.named_parameters()) 34 | 35 | # Add bn and non-bn params. 
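# This is a plain substring match on parameter names: any parameter whose
# qualified name contains "bn" (e.g. a stem parameter such as
# "s1.pathway0_stem.bn.weight"; illustrative name) receives the BatchNorm
# weight decay configured further below.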
36 | for name, p in opt_params: 37 | if "bn" in name: 38 | bn_params.append(p) 39 | else: 40 | non_bn_parameters.append(p) 41 | 42 | # Check all parameters will be passed into optimizer. 43 | if len(list(model.parameters())) != len(non_bn_parameters) + len(bn_params): 44 | print( 45 | "Warning: parameter size does not match: {} + {} != {}".format( 46 | len(non_bn_parameters), len(bn_params), len(list(model.parameters()))) 47 | ) 48 | # # Set requires_grad to False for params that are not under optimization. 49 | # opt_params_ptrs = set([x[1].data_ptr() for x in opt_params]) 50 | # grad_disabled = [] 51 | # for name, p in model.named_parameters(): 52 | # if p.data_ptr() not in opt_params_ptrs: 53 | # p.requires_grad = False 54 | # grad_disabled.append(name) 55 | # print('Gradients are disabled for the following params:') 56 | # print(grad_disabled) 57 | 58 | # Apply different weight decay to Batchnorm and non-batchnorm parameters. 59 | # In Caffe2 classification codebase the weight decay for batchnorm is 0.0. 60 | # Having a different weight decay on batchnorm might cause a performance 61 | # drop. 62 | if cfg.MODEL.CLS_ONLY: 63 | optim_params = list(filter(lambda p: p.requires_grad, model.parameters())) 64 | assert len(optim_params) == 2 # fc.weight, fc.bias 65 | else: 66 | optim_params = [ 67 | {"params": bn_params, "weight_decay": cfg.BN.WEIGHT_DECAY}, 68 | {"params": non_bn_parameters, "weight_decay": cfg.SOLVER.WEIGHT_DECAY}, 69 | ] 70 | 71 | if cfg.SOLVER.OPTIMIZING_METHOD == "sgd": 72 | return torch.optim.SGD( 73 | optim_params, 74 | lr=cfg.SOLVER.BASE_LR, 75 | momentum=cfg.SOLVER.MOMENTUM, 76 | weight_decay=cfg.SOLVER.WEIGHT_DECAY, 77 | dampening=cfg.SOLVER.DAMPENING, 78 | nesterov=cfg.SOLVER.NESTEROV, 79 | ) 80 | elif cfg.SOLVER.OPTIMIZING_METHOD == "adam": 81 | return torch.optim.Adam( 82 | optim_params, 83 | lr=cfg.SOLVER.BASE_LR, 84 | betas=(0.9, 0.999), 85 | weight_decay=cfg.SOLVER.WEIGHT_DECAY, 86 | ) 87 | elif cfg.SOLVER.OPTIMIZING_METHOD == "adamw": 88 | from transformers import AdamW 89 | return AdamW( 90 | optim_params, 91 | lr=cfg.SOLVER.BASE_LR, 92 | betas=(0.9, 0.999), 93 | eps=1e-8, 94 | ) 95 | else: 96 | raise NotImplementedError( 97 | "Does not support {} optimizer".format(cfg.SOLVER.OPTIMIZING_METHOD) 98 | ) 99 | 100 | 101 | def get_epoch_lr(cur_epoch, cfg): 102 | """ 103 | Retrieves the lr for the given epoch (as specified by the lr policy). 104 | Args: 105 | cfg (config): configs of hyper-parameters of ADAM, includes base 106 | learning rate, betas, and weight decays. 107 | cur_epoch (float): the number of epoch of the current training stage. 108 | """ 109 | return lr_policy.get_lr_at_epoch(cfg, cur_epoch) 110 | 111 | 112 | def set_lr(optimizer, new_lr): 113 | """ 114 | Sets the optimizer lr to the specified value. 115 | Args: 116 | optimizer (optim): the optimizer using to optimize the current network. 117 | new_lr (float): the new learning rate to set. 118 | """ 119 | for param_group in optimizer.param_groups: 120 | param_group["lr"] = new_lr 121 | -------------------------------------------------------------------------------- /slowfast/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
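# How the helpers in slowfast/models/optimizer.py above are typically combined
# (a sketch; the actual loop lives in tools/train_net.py, not shown here):
#   optimizer = construct_optimizer(model, cfg)
#   for cur_iter, batch in enumerate(train_loader):
#       lr = get_epoch_lr(cur_epoch + float(cur_iter) / data_size, cfg)
#       set_lr(optimizer, lr)
#       # ...forward, backward, optimizer.step()...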
3 | -------------------------------------------------------------------------------- /slowfast/utils/ava_evaluation/README.md: -------------------------------------------------------------------------------- 1 | The code under this folder is from the official [ActivityNet repo](https://github.com/activitynet/ActivityNet). 2 | -------------------------------------------------------------------------------- /slowfast/utils/ava_evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/self-supervised-maclr/8a92ef0586109ad3110376e61be7e97f61f08b0d/slowfast/utils/ava_evaluation/__init__.py -------------------------------------------------------------------------------- /slowfast/utils/ava_evaluation/ava_action_list_v2.1_for_activitynet_2018.pbtxt.txt: -------------------------------------------------------------------------------- 1 | item { 2 | name: "bend/bow (at the waist)" 3 | id: 1 4 | } 5 | item { 6 | name: "crouch/kneel" 7 | id: 3 8 | } 9 | item { 10 | name: "dance" 11 | id: 4 12 | } 13 | item { 14 | name: "fall down" 15 | id: 5 16 | } 17 | item { 18 | name: "get up" 19 | id: 6 20 | } 21 | item { 22 | name: "jump/leap" 23 | id: 7 24 | } 25 | item { 26 | name: "lie/sleep" 27 | id: 8 28 | } 29 | item { 30 | name: "martial art" 31 | id: 9 32 | } 33 | item { 34 | name: "run/jog" 35 | id: 10 36 | } 37 | item { 38 | name: "sit" 39 | id: 11 40 | } 41 | item { 42 | name: "stand" 43 | id: 12 44 | } 45 | item { 46 | name: "swim" 47 | id: 13 48 | } 49 | item { 50 | name: "walk" 51 | id: 14 52 | } 53 | item { 54 | name: "answer phone" 55 | id: 15 56 | } 57 | item { 58 | name: "carry/hold (an object)" 59 | id: 17 60 | } 61 | item { 62 | name: "climb (e.g., a mountain)" 63 | id: 20 64 | } 65 | item { 66 | name: "close (e.g., a door, a box)" 67 | id: 22 68 | } 69 | item { 70 | name: "cut" 71 | id: 24 72 | } 73 | item { 74 | name: "dress/put on clothing" 75 | id: 26 76 | } 77 | item { 78 | name: "drink" 79 | id: 27 80 | } 81 | item { 82 | name: "drive (e.g., a car, a truck)" 83 | id: 28 84 | } 85 | item { 86 | name: "eat" 87 | id: 29 88 | } 89 | item { 90 | name: "enter" 91 | id: 30 92 | } 93 | item { 94 | name: "hit (an object)" 95 | id: 34 96 | } 97 | item { 98 | name: "lift/pick up" 99 | id: 36 100 | } 101 | item { 102 | name: "listen (e.g., to music)" 103 | id: 37 104 | } 105 | item { 106 | name: "open (e.g., a window, a car door)" 107 | id: 38 108 | } 109 | item { 110 | name: "play musical instrument" 111 | id: 41 112 | } 113 | item { 114 | name: "point to (an object)" 115 | id: 43 116 | } 117 | item { 118 | name: "pull (an object)" 119 | id: 45 120 | } 121 | item { 122 | name: "push (an object)" 123 | id: 46 124 | } 125 | item { 126 | name: "put down" 127 | id: 47 128 | } 129 | item { 130 | name: "read" 131 | id: 48 132 | } 133 | item { 134 | name: "ride (e.g., a bike, a car, a horse)" 135 | id: 49 136 | } 137 | item { 138 | name: "sail boat" 139 | id: 51 140 | } 141 | item { 142 | name: "shoot" 143 | id: 52 144 | } 145 | item { 146 | name: "smoke" 147 | id: 54 148 | } 149 | item { 150 | name: "take a photo" 151 | id: 56 152 | } 153 | item { 154 | name: "text on/look at a cellphone" 155 | id: 57 156 | } 157 | item { 158 | name: "throw" 159 | id: 58 160 | } 161 | item { 162 | name: "touch (an object)" 163 | id: 59 164 | } 165 | item { 166 | name: "turn (e.g., a screwdriver)" 167 | id: 60 168 | } 169 | item { 170 | name: "watch (e.g., TV)" 171 | id: 61 172 | } 173 | item { 174 | name: "work on a computer" 
175 | id: 62 176 | } 177 | item { 178 | name: "write" 179 | id: 63 180 | } 181 | item { 182 | name: "fight/hit (a person)" 183 | id: 64 184 | } 185 | item { 186 | name: "give/serve (an object) to (a person)" 187 | id: 65 188 | } 189 | item { 190 | name: "grab (a person)" 191 | id: 66 192 | } 193 | item { 194 | name: "hand clap" 195 | id: 67 196 | } 197 | item { 198 | name: "hand shake" 199 | id: 68 200 | } 201 | item { 202 | name: "hand wave" 203 | id: 69 204 | } 205 | item { 206 | name: "hug (a person)" 207 | id: 70 208 | } 209 | item { 210 | name: "kiss (a person)" 211 | id: 72 212 | } 213 | item { 214 | name: "lift (a person)" 215 | id: 73 216 | } 217 | item { 218 | name: "listen to (a person)" 219 | id: 74 220 | } 221 | item { 222 | name: "push (another person)" 223 | id: 76 224 | } 225 | item { 226 | name: "sing to (e.g., self, a person, a group)" 227 | id: 77 228 | } 229 | item { 230 | name: "take (an object) from (a person)" 231 | id: 78 232 | } 233 | item { 234 | name: "talk to (e.g., self, a person, a group)" 235 | id: 79 236 | } 237 | item { 238 | name: "watch (a person)" 239 | id: 80 240 | } 241 | -------------------------------------------------------------------------------- /slowfast/utils/ava_evaluation/label_map_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Label map utility functions.""" 16 | 17 | from __future__ import ( 18 | absolute_import, 19 | division, 20 | print_function, 21 | unicode_literals, 22 | ) 23 | import logging 24 | 25 | # from google.protobuf import text_format 26 | # from google3.third_party.tensorflow_models.object_detection.protos import string_int_label_map_pb2 27 | 28 | 29 | def _validate_label_map(label_map): 30 | """Checks if a label map is valid. 31 | 32 | Args: 33 | label_map: StringIntLabelMap to validate. 34 | 35 | Raises: 36 | ValueError: if label map is invalid. 37 | """ 38 | for item in label_map.item: 39 | if item.id < 1: 40 | raise ValueError("Label map ids should be >= 1.") 41 | 42 | 43 | def create_category_index(categories): 44 | """Creates dictionary of COCO compatible categories keyed by category id. 45 | 46 | Args: 47 | categories: a list of dicts, each of which has the following keys: 48 | 'id': (required) an integer id uniquely identifying this category. 49 | 'name': (required) string representing category name 50 | e.g., 'cat', 'dog', 'pizza'. 51 | 52 | Returns: 53 | category_index: a dict containing the same entries as categories, but keyed 54 | by the 'id' field of each category. 55 | """ 56 | category_index = {} 57 | for cat in categories: 58 | category_index[cat["id"]] = cat 59 | return category_index 60 | 61 | 62 | def get_max_label_map_index(label_map): 63 | """Get maximum index in label map. 
64 | 65 | Args: 66 | label_map: a StringIntLabelMapProto 67 | 68 | Returns: 69 | an integer 70 | """ 71 | return max([item.id for item in label_map.item]) 72 | 73 | 74 | def convert_label_map_to_categories( 75 | label_map, max_num_classes, use_display_name=True 76 | ): 77 | """Loads label map proto and returns categories list compatible with eval. 78 | 79 | This function loads a label map and returns a list of dicts, each of which 80 | has the following keys: 81 | 'id': (required) an integer id uniquely identifying this category. 82 | 'name': (required) string representing category name 83 | e.g., 'cat', 'dog', 'pizza'. 84 | We only allow class into the list if its id-label_id_offset is 85 | between 0 (inclusive) and max_num_classes (exclusive). 86 | If there are several items mapping to the same id in the label map, 87 | we will only keep the first one in the categories list. 88 | 89 | Args: 90 | label_map: a StringIntLabelMapProto or None. If None, a default categories 91 | list is created with max_num_classes categories. 92 | max_num_classes: maximum number of (consecutive) label indices to include. 93 | use_display_name: (boolean) choose whether to load 'display_name' field 94 | as category name. If False or if the display_name field does not exist, 95 | uses 'name' field as category names instead. 96 | Returns: 97 | categories: a list of dictionaries representing all possible categories. 98 | """ 99 | categories = [] 100 | list_of_ids_already_added = [] 101 | if not label_map: 102 | label_id_offset = 1 103 | for class_id in range(max_num_classes): 104 | categories.append( 105 | { 106 | "id": class_id + label_id_offset, 107 | "name": "category_{}".format(class_id + label_id_offset), 108 | } 109 | ) 110 | return categories 111 | for item in label_map.item: 112 | if not 0 < item.id <= max_num_classes: 113 | logging.info( 114 | "Ignore item %d since it falls outside of requested " 115 | "label range.", 116 | item.id, 117 | ) 118 | continue 119 | if use_display_name and item.HasField("display_name"): 120 | name = item.display_name 121 | else: 122 | name = item.name 123 | if item.id not in list_of_ids_already_added: 124 | list_of_ids_already_added.append(item.id) 125 | categories.append({"id": item.id, "name": name}) 126 | return categories 127 | 128 | 129 | def load_labelmap(path): 130 | """Loads label map proto. 131 | 132 | Args: 133 | path: path to StringIntLabelMap proto text file. 134 | Returns: 135 | a StringIntLabelMapProto 136 | """ 137 | with open(path, "r") as fid: 138 | label_map_string = fid.read() 139 | label_map = string_int_label_map_pb2.StringIntLabelMap() 140 | try: 141 | text_format.Merge(label_map_string, label_map) 142 | except text_format.ParseError: 143 | label_map.ParseFromString(label_map_string) 144 | _validate_label_map(label_map) 145 | return label_map 146 | 147 | 148 | def get_label_map_dict(label_map_path, use_display_name=False): 149 | """Reads a label map and returns a dictionary of label names to id. 150 | 151 | Args: 152 | label_map_path: path to label_map. 153 | use_display_name: whether to use the label map items' display names as keys. 154 | 155 | Returns: 156 | A dictionary mapping label names to id. 
157 | """ 158 | label_map = load_labelmap(label_map_path) 159 | label_map_dict = {} 160 | for item in label_map.item: 161 | if use_display_name: 162 | label_map_dict[item.display_name] = item.id 163 | else: 164 | label_map_dict[item.name] = item.id 165 | return label_map_dict 166 | 167 | 168 | def create_category_index_from_labelmap(label_map_path): 169 | """Reads a label map and returns a category index. 170 | 171 | Args: 172 | label_map_path: Path to `StringIntLabelMap` proto text file. 173 | 174 | Returns: 175 | A category index, which is a dictionary that maps integer ids to dicts 176 | containing categories, e.g. 177 | {1: {'id': 1, 'name': 'dog'}, 2: {'id': 2, 'name': 'cat'}, ...} 178 | """ 179 | label_map = load_labelmap(label_map_path) 180 | max_num_classes = max(item.id for item in label_map.item) 181 | categories = convert_label_map_to_categories(label_map, max_num_classes) 182 | return create_category_index(categories) 183 | 184 | 185 | def create_class_agnostic_category_index(): 186 | """Creates a category index with a single `object` class.""" 187 | return {1: {"id": 1, "name": "object"}} 188 | -------------------------------------------------------------------------------- /slowfast/utils/ava_evaluation/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Functions for computing metrics like precision, recall, CorLoc and etc.""" 17 | from __future__ import division 18 | import numpy as np 19 | 20 | 21 | def compute_precision_recall(scores, labels, num_gt): 22 | """Compute precision and recall. 23 | 24 | Args: 25 | scores: A float numpy array representing detection score 26 | labels: A boolean numpy array representing true/false positive labels 27 | num_gt: Number of ground truth instances 28 | 29 | Raises: 30 | ValueError: if the input is not of the correct format 31 | 32 | Returns: 33 | precision: Fraction of positive instances over detected ones. This value is 34 | None if no ground truth labels are present. 35 | recall: Fraction of detected positive instance over all positive instances. 36 | This value is None if no ground truth labels are present. 37 | 38 | """ 39 | if ( 40 | not isinstance(labels, np.ndarray) 41 | or labels.dtype != np.bool 42 | or len(labels.shape) != 1 43 | ): 44 | raise ValueError("labels must be single dimension bool numpy array") 45 | 46 | if not isinstance(scores, np.ndarray) or len(scores.shape) != 1: 47 | raise ValueError("scores must be single dimension numpy array") 48 | 49 | if num_gt < np.sum(labels): 50 | raise ValueError( 51 | "Number of true positives must be smaller than num_gt." 
52 | ) 53 | 54 | if len(scores) != len(labels): 55 | raise ValueError("scores and labels must be of the same size.") 56 | 57 | if num_gt == 0: 58 | return None, None 59 | 60 | sorted_indices = np.argsort(scores) 61 | sorted_indices = sorted_indices[::-1] 62 | labels = labels.astype(int) 63 | true_positive_labels = labels[sorted_indices] 64 | false_positive_labels = 1 - true_positive_labels 65 | cum_true_positives = np.cumsum(true_positive_labels) 66 | cum_false_positives = np.cumsum(false_positive_labels) 67 | precision = cum_true_positives.astype(float) / ( 68 | cum_true_positives + cum_false_positives 69 | ) 70 | recall = cum_true_positives.astype(float) / num_gt 71 | return precision, recall 72 | 73 | 74 | def compute_average_precision(precision, recall): 75 | """Compute Average Precision according to the definition in VOCdevkit. 76 | 77 | Precision is modified to ensure that it does not decrease as recall 78 | decrease. 79 | 80 | Args: 81 | precision: A float [N, 1] numpy array of precisions 82 | recall: A float [N, 1] numpy array of recalls 83 | 84 | Raises: 85 | ValueError: if the input is not of the correct format 86 | 87 | Returns: 88 | average_precison: The area under the precision recall curve. NaN if 89 | precision and recall are None. 90 | 91 | """ 92 | if precision is None: 93 | if recall is not None: 94 | raise ValueError("If precision is None, recall must also be None") 95 | return np.NAN 96 | 97 | if not isinstance(precision, np.ndarray) or not isinstance( 98 | recall, np.ndarray 99 | ): 100 | raise ValueError("precision and recall must be numpy array") 101 | if precision.dtype != np.float or recall.dtype != np.float: 102 | raise ValueError("input must be float numpy array.") 103 | if len(precision) != len(recall): 104 | raise ValueError("precision and recall must be of the same size.") 105 | if not precision.size: 106 | return 0.0 107 | if np.amin(precision) < 0 or np.amax(precision) > 1: 108 | raise ValueError("Precision must be in the range of [0, 1].") 109 | if np.amin(recall) < 0 or np.amax(recall) > 1: 110 | raise ValueError("recall must be in the range of [0, 1].") 111 | if not all(recall[i] <= recall[i + 1] for i in range(len(recall) - 1)): 112 | raise ValueError("recall must be a non-decreasing array") 113 | 114 | recall = np.concatenate([[0], recall, [1]]) 115 | precision = np.concatenate([[0], precision, [0]]) 116 | 117 | # Preprocess precision to be a non-decreasing array 118 | for i in range(len(precision) - 2, -1, -1): 119 | precision[i] = np.maximum(precision[i], precision[i + 1]) 120 | 121 | indices = np.where(recall[1:] != recall[:-1])[0] + 1 122 | average_precision = np.sum( 123 | (recall[indices] - recall[indices - 1]) * precision[indices] 124 | ) 125 | return average_precision 126 | 127 | 128 | def compute_cor_loc( 129 | num_gt_imgs_per_class, num_images_correctly_detected_per_class 130 | ): 131 | """Compute CorLoc according to the definition in the following paper. 132 | 133 | https://www.robots.ox.ac.uk/~vgg/rg/papers/deselaers-eccv10.pdf 134 | 135 | Returns nans if there are no ground truth images for a class. 
136 | 137 | Args: 138 | num_gt_imgs_per_class: 1D array, representing number of images containing 139 | at least one object instance of a particular class 140 | num_images_correctly_detected_per_class: 1D array, representing number of 141 | images that are correctly detected at least one object instance of a 142 | particular class 143 | 144 | Returns: 145 | corloc_per_class: A float numpy array represents the corloc score of each 146 | class 147 | """ 148 | # Divide by zero expected for classes with no gt examples. 149 | with np.errstate(divide="ignore", invalid="ignore"): 150 | return np.where( 151 | num_gt_imgs_per_class == 0, 152 | np.nan, 153 | num_images_correctly_detected_per_class / num_gt_imgs_per_class, 154 | ) 155 | -------------------------------------------------------------------------------- /slowfast/utils/ava_evaluation/np_box_list.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Numpy BoxList classes and functions.""" 17 | 18 | from __future__ import ( 19 | absolute_import, 20 | division, 21 | print_function, 22 | unicode_literals, 23 | ) 24 | import numpy as np 25 | 26 | 27 | class BoxList(object): 28 | """Box collection. 29 | 30 | BoxList represents a list of bounding boxes as numpy array, where each 31 | bounding box is represented as a row of 4 numbers, 32 | [y_min, x_min, y_max, x_max]. It is assumed that all bounding boxes within a 33 | given list correspond to a single image. 34 | 35 | Optionally, users can add additional related fields (such as 36 | objectness/classification scores). 37 | """ 38 | 39 | def __init__(self, data): 40 | """Constructs box collection. 41 | 42 | Args: 43 | data: a numpy array of shape [N, 4] representing box coordinates 44 | 45 | Raises: 46 | ValueError: if bbox data is not a numpy array 47 | ValueError: if invalid dimensions for bbox data 48 | """ 49 | if not isinstance(data, np.ndarray): 50 | raise ValueError("data must be a numpy array.") 51 | if len(data.shape) != 2 or data.shape[1] != 4: 52 | raise ValueError("Invalid dimensions for box data.") 53 | if data.dtype != np.float32 and data.dtype != np.float64: 54 | raise ValueError( 55 | "Invalid data type for box data: float is required." 56 | ) 57 | if not self._is_valid_boxes(data): 58 | raise ValueError( 59 | "Invalid box data. 
data must be a numpy array of " 60 | "N*[y_min, x_min, y_max, x_max]" 61 | ) 62 | self.data = {"boxes": data} 63 | 64 | def num_boxes(self): 65 | """Return number of boxes held in collections.""" 66 | return self.data["boxes"].shape[0] 67 | 68 | def get_extra_fields(self): 69 | """Return all non-box fields.""" 70 | return [k for k in self.data.keys() if k != "boxes"] 71 | 72 | def has_field(self, field): 73 | return field in self.data 74 | 75 | def add_field(self, field, field_data): 76 | """Add data to a specified field. 77 | 78 | Args: 79 | field: a string parameter used to speficy a related field to be accessed. 80 | field_data: a numpy array of [N, ...] representing the data associated 81 | with the field. 82 | Raises: 83 | ValueError: if the field is already exist or the dimension of the field 84 | data does not matches the number of boxes. 85 | """ 86 | if self.has_field(field): 87 | raise ValueError("Field " + field + "already exists") 88 | if len(field_data.shape) < 1 or field_data.shape[0] != self.num_boxes(): 89 | raise ValueError("Invalid dimensions for field data") 90 | self.data[field] = field_data 91 | 92 | def get(self): 93 | """Convenience function for accesssing box coordinates. 94 | 95 | Returns: 96 | a numpy array of shape [N, 4] representing box corners 97 | """ 98 | return self.get_field("boxes") 99 | 100 | def get_field(self, field): 101 | """Accesses data associated with the specified field in the box collection. 102 | 103 | Args: 104 | field: a string parameter used to speficy a related field to be accessed. 105 | 106 | Returns: 107 | a numpy 1-d array representing data of an associated field 108 | 109 | Raises: 110 | ValueError: if invalid field 111 | """ 112 | if not self.has_field(field): 113 | raise ValueError("field {} does not exist".format(field)) 114 | return self.data[field] 115 | 116 | def get_coordinates(self): 117 | """Get corner coordinates of boxes. 118 | 119 | Returns: 120 | a list of 4 1-d numpy arrays [y_min, x_min, y_max, x_max] 121 | """ 122 | box_coordinates = self.get() 123 | y_min = box_coordinates[:, 0] 124 | x_min = box_coordinates[:, 1] 125 | y_max = box_coordinates[:, 2] 126 | x_max = box_coordinates[:, 3] 127 | return [y_min, x_min, y_max, x_max] 128 | 129 | def _is_valid_boxes(self, data): 130 | """Check whether data fullfills the format of N*[ymin, xmin, ymax, xmin]. 131 | 132 | Args: 133 | data: a numpy array of shape [N, 4] representing box coordinates 134 | 135 | Returns: 136 | a boolean indicating whether all ymax of boxes are equal or greater than 137 | ymin, and all xmax of boxes are equal or greater than xmin. 138 | """ 139 | if data.shape[0] > 0: 140 | for i in range(data.shape[0]): 141 | if data[i, 0] > data[i, 2] or data[i, 1] > data[i, 3]: 142 | return False 143 | return True 144 | -------------------------------------------------------------------------------- /slowfast/utils/ava_evaluation/np_box_mask_list.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Numpy BoxMaskList classes and functions.""" 17 | 18 | from __future__ import ( 19 | absolute_import, 20 | division, 21 | print_function, 22 | unicode_literals, 23 | ) 24 | import numpy as np 25 | 26 | from . import np_box_list 27 | 28 | 29 | class BoxMaskList(np_box_list.BoxList): 30 | """Convenience wrapper for BoxList with masks. 31 | 32 | BoxMaskList extends the np_box_list.BoxList to contain masks as well. 33 | In particular, its constructor receives both boxes and masks. Note that the 34 | masks correspond to the full image. 35 | """ 36 | 37 | def __init__(self, box_data, mask_data): 38 | """Constructs box collection. 39 | 40 | Args: 41 | box_data: a numpy array of shape [N, 4] representing box coordinates 42 | mask_data: a numpy array of shape [N, height, width] representing masks 43 | with values are in {0,1}. The masks correspond to the full 44 | image. The height and the width will be equal to image height and width. 45 | 46 | Raises: 47 | ValueError: if bbox data is not a numpy array 48 | ValueError: if invalid dimensions for bbox data 49 | ValueError: if mask data is not a numpy array 50 | ValueError: if invalid dimension for mask data 51 | """ 52 | super(BoxMaskList, self).__init__(box_data) 53 | if not isinstance(mask_data, np.ndarray): 54 | raise ValueError("Mask data must be a numpy array.") 55 | if len(mask_data.shape) != 3: 56 | raise ValueError("Invalid dimensions for mask data.") 57 | if mask_data.dtype != np.uint8: 58 | raise ValueError( 59 | "Invalid data type for mask data: uint8 is required." 60 | ) 61 | if mask_data.shape[0] != box_data.shape[0]: 62 | raise ValueError( 63 | "There should be the same number of boxes and masks." 64 | ) 65 | self.data["masks"] = mask_data 66 | 67 | def get_masks(self): 68 | """Convenience function for accessing masks. 69 | 70 | Returns: 71 | a numpy array of shape [N, height, width] representing masks 72 | """ 73 | return self.get_field("masks") 74 | -------------------------------------------------------------------------------- /slowfast/utils/ava_evaluation/np_box_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Operations for [N, 4] numpy arrays representing bounding boxes. 
17 | 18 | Example box operations that are supported: 19 | * Areas: compute bounding box areas 20 | * IOU: pairwise intersection-over-union scores 21 | """ 22 | from __future__ import ( 23 | absolute_import, 24 | division, 25 | print_function, 26 | unicode_literals, 27 | ) 28 | import numpy as np 29 | 30 | 31 | def area(boxes): 32 | """Computes area of boxes. 33 | 34 | Args: 35 | boxes: Numpy array with shape [N, 4] holding N boxes 36 | 37 | Returns: 38 | a numpy array with shape [N*1] representing box areas 39 | """ 40 | return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) 41 | 42 | 43 | def intersection(boxes1, boxes2): 44 | """Compute pairwise intersection areas between boxes. 45 | 46 | Args: 47 | boxes1: a numpy array with shape [N, 4] holding N boxes 48 | boxes2: a numpy array with shape [M, 4] holding M boxes 49 | 50 | Returns: 51 | a numpy array with shape [N*M] representing pairwise intersection area 52 | """ 53 | [y_min1, x_min1, y_max1, x_max1] = np.split(boxes1, 4, axis=1) 54 | [y_min2, x_min2, y_max2, x_max2] = np.split(boxes2, 4, axis=1) 55 | 56 | all_pairs_min_ymax = np.minimum(y_max1, np.transpose(y_max2)) 57 | all_pairs_max_ymin = np.maximum(y_min1, np.transpose(y_min2)) 58 | intersect_heights = np.maximum( 59 | np.zeros(all_pairs_max_ymin.shape), 60 | all_pairs_min_ymax - all_pairs_max_ymin, 61 | ) 62 | all_pairs_min_xmax = np.minimum(x_max1, np.transpose(x_max2)) 63 | all_pairs_max_xmin = np.maximum(x_min1, np.transpose(x_min2)) 64 | intersect_widths = np.maximum( 65 | np.zeros(all_pairs_max_xmin.shape), 66 | all_pairs_min_xmax - all_pairs_max_xmin, 67 | ) 68 | return intersect_heights * intersect_widths 69 | 70 | 71 | def iou(boxes1, boxes2): 72 | """Computes pairwise intersection-over-union between box collections. 73 | 74 | Args: 75 | boxes1: a numpy array with shape [N, 4] holding N boxes. 76 | boxes2: a numpy array with shape [M, 4] holding N boxes. 77 | 78 | Returns: 79 | a numpy array with shape [N, M] representing pairwise iou scores. 80 | """ 81 | intersect = intersection(boxes1, boxes2) 82 | area1 = area(boxes1) 83 | area2 = area(boxes2) 84 | union = ( 85 | np.expand_dims(area1, axis=1) 86 | + np.expand_dims(area2, axis=0) 87 | - intersect 88 | ) 89 | return intersect / union 90 | 91 | 92 | def ioa(boxes1, boxes2): 93 | """Computes pairwise intersection-over-area between box collections. 94 | 95 | Intersection-over-area (ioa) between two boxes box1 and box2 is defined as 96 | their intersection area over box2's area. Note that ioa is not symmetric, 97 | that is, IOA(box1, box2) != IOA(box2, box1). 98 | 99 | Args: 100 | boxes1: a numpy array with shape [N, 4] holding N boxes. 101 | boxes2: a numpy array with shape [M, 4] holding N boxes. 102 | 103 | Returns: 104 | a numpy array with shape [N, M] representing pairwise ioa scores. 105 | """ 106 | intersect = intersection(boxes1, boxes2) 107 | areas = np.expand_dims(area(boxes2), axis=0) 108 | return intersect / areas 109 | -------------------------------------------------------------------------------- /slowfast/utils/ava_evaluation/np_mask_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Operations for [N, height, width] numpy arrays representing masks. 17 | 18 | Example mask operations that are supported: 19 | * Areas: compute mask areas 20 | * IOU: pairwise intersection-over-union scores 21 | """ 22 | from __future__ import ( 23 | absolute_import, 24 | division, 25 | print_function, 26 | unicode_literals, 27 | ) 28 | import numpy as np 29 | 30 | EPSILON = 1e-7 31 | 32 | 33 | def area(masks): 34 | """Computes area of masks. 35 | 36 | Args: 37 | masks: Numpy array with shape [N, height, width] holding N masks. Masks 38 | values are of type np.uint8 and values are in {0,1}. 39 | 40 | Returns: 41 | a numpy array with shape [N*1] representing mask areas. 42 | 43 | Raises: 44 | ValueError: If masks.dtype is not np.uint8 45 | """ 46 | if masks.dtype != np.uint8: 47 | raise ValueError("Masks type should be np.uint8") 48 | return np.sum(masks, axis=(1, 2), dtype=np.float32) 49 | 50 | 51 | def intersection(masks1, masks2): 52 | """Compute pairwise intersection areas between masks. 53 | 54 | Args: 55 | masks1: a numpy array with shape [N, height, width] holding N masks. Masks 56 | values are of type np.uint8 and values are in {0,1}. 57 | masks2: a numpy array with shape [M, height, width] holding M masks. Masks 58 | values are of type np.uint8 and values are in {0,1}. 59 | 60 | Returns: 61 | a numpy array with shape [N*M] representing pairwise intersection area. 62 | 63 | Raises: 64 | ValueError: If masks1 and masks2 are not of type np.uint8. 65 | """ 66 | if masks1.dtype != np.uint8 or masks2.dtype != np.uint8: 67 | raise ValueError("masks1 and masks2 should be of type np.uint8") 68 | n = masks1.shape[0] 69 | m = masks2.shape[0] 70 | answer = np.zeros([n, m], dtype=np.float32) 71 | for i in np.arange(n): 72 | for j in np.arange(m): 73 | answer[i, j] = np.sum( 74 | np.minimum(masks1[i], masks2[j]), dtype=np.float32 75 | ) 76 | return answer 77 | 78 | 79 | def iou(masks1, masks2): 80 | """Computes pairwise intersection-over-union between mask collections. 81 | 82 | Args: 83 | masks1: a numpy array with shape [N, height, width] holding N masks. Masks 84 | values are of type np.uint8 and values are in {0,1}. 85 | masks2: a numpy array with shape [M, height, width] holding N masks. Masks 86 | values are of type np.uint8 and values are in {0,1}. 87 | 88 | Returns: 89 | a numpy array with shape [N, M] representing pairwise iou scores. 90 | 91 | Raises: 92 | ValueError: If masks1 and masks2 are not of type np.uint8. 93 | """ 94 | if masks1.dtype != np.uint8 or masks2.dtype != np.uint8: 95 | raise ValueError("masks1 and masks2 should be of type np.uint8") 96 | intersect = intersection(masks1, masks2) 97 | area1 = area(masks1) 98 | area2 = area(masks2) 99 | union = ( 100 | np.expand_dims(area1, axis=1) 101 | + np.expand_dims(area2, axis=0) 102 | - intersect 103 | ) 104 | return intersect / np.maximum(union, EPSILON) 105 | 106 | 107 | def ioa(masks1, masks2): 108 | """Computes pairwise intersection-over-area between box collections. 
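The mask operations in this module mirror the box operations but work on dense uint8 {0,1} masks. A quick toy check (illustration only, values chosen by hand):
```
import numpy as np
from slowfast.utils.ava_evaluation import np_mask_ops

m1 = np.zeros((1, 4, 4), dtype=np.uint8)
m1[0, :2, :2] = 1              # 2x2 square, area 4
m2 = np.zeros((2, 4, 4), dtype=np.uint8)
m2[0, 1:3, 1:3] = 1            # overlaps m1 in exactly one pixel
m2[1, 3:, 3:] = 1              # disjoint single pixel

print(np_mask_ops.area(m1))    # [4.]
print(np_mask_ops.iou(m1, m2)) # [[0.1429 0.]]
print(np_mask_ops.ioa(m1, m2)) # [[0.25   0.]]  intersection / area of masks2
```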
109 | 110 | Intersection-over-area (ioa) between two masks, mask1 and mask2 is defined as 111 | their intersection area over mask2's area. Note that ioa is not symmetric, 112 | that is, IOA(mask1, mask2) != IOA(mask2, mask1). 113 | 114 | Args: 115 | masks1: a numpy array with shape [N, height, width] holding N masks. Masks 116 | values are of type np.uint8 and values are in {0,1}. 117 | masks2: a numpy array with shape [M, height, width] holding N masks. Masks 118 | values are of type np.uint8 and values are in {0,1}. 119 | 120 | Returns: 121 | a numpy array with shape [N, M] representing pairwise ioa scores. 122 | 123 | Raises: 124 | ValueError: If masks1 and masks2 are not of type np.uint8. 125 | """ 126 | if masks1.dtype != np.uint8 or masks2.dtype != np.uint8: 127 | raise ValueError("masks1 and masks2 should be of type np.uint8") 128 | intersect = intersection(masks1, masks2) 129 | areas = np.expand_dims(area(masks2), axis=0) 130 | return intersect / (areas + EPSILON) 131 | -------------------------------------------------------------------------------- /slowfast/utils/ava_evaluation/standard_fields.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Contains classes specifying naming conventions used for object detection. 17 | 18 | 19 | Specifies: 20 | InputDataFields: standard fields used by reader/preprocessor/batcher. 21 | DetectionResultFields: standard fields returned by object detector. 22 | BoxListFields: standard field used by BoxList 23 | TfExampleFields: standard fields for tf-example data format (go/tf-example). 24 | """ 25 | 26 | 27 | from __future__ import ( 28 | absolute_import, 29 | division, 30 | print_function, 31 | unicode_literals, 32 | ) 33 | 34 | 35 | class InputDataFields(object): 36 | """Names for the input tensors. 37 | 38 | Holds the standard data field names to use for identifying input tensors. This 39 | should be used by the decoder to identify keys for the returned tensor_dict 40 | containing input tensors. And it should be used by the model to identify the 41 | tensors it needs. 42 | 43 | Attributes: 44 | image: image. 45 | original_image: image in the original input size. 46 | key: unique key corresponding to image. 47 | source_id: source of the original image. 48 | filename: original filename of the dataset (without common path). 49 | groundtruth_image_classes: image-level class labels. 50 | groundtruth_boxes: coordinates of the ground truth boxes in the image. 51 | groundtruth_classes: box-level class labels. 52 | groundtruth_label_types: box-level label types (e.g. explicit negative). 53 | groundtruth_is_crowd: [DEPRECATED, use groundtruth_group_of instead] 54 | is the groundtruth a single object or a crowd. 55 | groundtruth_area: area of a groundtruth segment. 
56 | groundtruth_difficult: is a `difficult` object 57 | groundtruth_group_of: is a `group_of` objects, e.g. multiple objects of the 58 | same class, forming a connected group, where instances are heavily 59 | occluding each other. 60 | proposal_boxes: coordinates of object proposal boxes. 61 | proposal_objectness: objectness score of each proposal. 62 | groundtruth_instance_masks: ground truth instance masks. 63 | groundtruth_instance_boundaries: ground truth instance boundaries. 64 | groundtruth_instance_classes: instance mask-level class labels. 65 | groundtruth_keypoints: ground truth keypoints. 66 | groundtruth_keypoint_visibilities: ground truth keypoint visibilities. 67 | groundtruth_label_scores: groundtruth label scores. 68 | groundtruth_weights: groundtruth weight factor for bounding boxes. 69 | num_groundtruth_boxes: number of groundtruth boxes. 70 | true_image_shapes: true shapes of images in the resized images, as resized 71 | images can be padded with zeros. 72 | """ 73 | 74 | image = "image" 75 | original_image = "original_image" 76 | key = "key" 77 | source_id = "source_id" 78 | filename = "filename" 79 | groundtruth_image_classes = "groundtruth_image_classes" 80 | groundtruth_boxes = "groundtruth_boxes" 81 | groundtruth_classes = "groundtruth_classes" 82 | groundtruth_label_types = "groundtruth_label_types" 83 | groundtruth_is_crowd = "groundtruth_is_crowd" 84 | groundtruth_area = "groundtruth_area" 85 | groundtruth_difficult = "groundtruth_difficult" 86 | groundtruth_group_of = "groundtruth_group_of" 87 | proposal_boxes = "proposal_boxes" 88 | proposal_objectness = "proposal_objectness" 89 | groundtruth_instance_masks = "groundtruth_instance_masks" 90 | groundtruth_instance_boundaries = "groundtruth_instance_boundaries" 91 | groundtruth_instance_classes = "groundtruth_instance_classes" 92 | groundtruth_keypoints = "groundtruth_keypoints" 93 | groundtruth_keypoint_visibilities = "groundtruth_keypoint_visibilities" 94 | groundtruth_label_scores = "groundtruth_label_scores" 95 | groundtruth_weights = "groundtruth_weights" 96 | num_groundtruth_boxes = "num_groundtruth_boxes" 97 | true_image_shape = "true_image_shape" 98 | 99 | 100 | class DetectionResultFields(object): 101 | """Naming conventions for storing the output of the detector. 102 | 103 | Attributes: 104 | source_id: source of the original image. 105 | key: unique key corresponding to image. 106 | detection_boxes: coordinates of the detection boxes in the image. 107 | detection_scores: detection scores for the detection boxes in the image. 108 | detection_classes: detection-level class labels. 109 | detection_masks: contains a segmentation mask for each detection box. 110 | detection_boundaries: contains an object boundary for each detection box. 111 | detection_keypoints: contains detection keypoints for each detection box. 112 | num_detections: number of detections in the batch. 113 | """ 114 | 115 | source_id = "source_id" 116 | key = "key" 117 | detection_boxes = "detection_boxes" 118 | detection_scores = "detection_scores" 119 | detection_classes = "detection_classes" 120 | detection_masks = "detection_masks" 121 | detection_boundaries = "detection_boundaries" 122 | detection_keypoints = "detection_keypoints" 123 | num_detections = "num_detections" 124 | 125 | 126 | class BoxListFields(object): 127 | """Naming conventions for BoxLists. 128 | 129 | Attributes: 130 | boxes: bounding box coordinates. 131 | classes: classes per bounding box. 132 | scores: scores per bounding box. 
133 | weights: sample weights per bounding box. 134 | objectness: objectness score per bounding box. 135 | masks: masks per bounding box. 136 | boundaries: boundaries per bounding box. 137 | keypoints: keypoints per bounding box. 138 | keypoint_heatmaps: keypoint heatmaps per bounding box. 139 | """ 140 | 141 | boxes = "boxes" 142 | classes = "classes" 143 | scores = "scores" 144 | weights = "weights" 145 | objectness = "objectness" 146 | masks = "masks" 147 | boundaries = "boundaries" 148 | keypoints = "keypoints" 149 | keypoint_heatmaps = "keypoint_heatmaps" 150 | 151 | 152 | class TfExampleFields(object): 153 | """TF-example proto feature names for object detection. 154 | 155 | Holds the standard feature names to load from an Example proto for object 156 | detection. 157 | 158 | Attributes: 159 | image_encoded: JPEG encoded string 160 | image_format: image format, e.g. "JPEG" 161 | filename: filename 162 | channels: number of channels of image 163 | colorspace: colorspace, e.g. "RGB" 164 | height: height of image in pixels, e.g. 462 165 | width: width of image in pixels, e.g. 581 166 | source_id: original source of the image 167 | object_class_text: labels in text format, e.g. ["person", "cat"] 168 | object_class_label: labels in numbers, e.g. [16, 8] 169 | object_bbox_xmin: xmin coordinates of groundtruth box, e.g. 10, 30 170 | object_bbox_xmax: xmax coordinates of groundtruth box, e.g. 50, 40 171 | object_bbox_ymin: ymin coordinates of groundtruth box, e.g. 40, 50 172 | object_bbox_ymax: ymax coordinates of groundtruth box, e.g. 80, 70 173 | object_view: viewpoint of object, e.g. ["frontal", "left"] 174 | object_truncated: is object truncated, e.g. [true, false] 175 | object_occluded: is object occluded, e.g. [true, false] 176 | object_difficult: is object difficult, e.g. [true, false] 177 | object_group_of: is object a single object or a group of objects 178 | object_depiction: is object a depiction 179 | object_is_crowd: [DEPRECATED, use object_group_of instead] 180 | is the object a single object or a crowd 181 | object_segment_area: the area of the segment. 182 | object_weight: a weight factor for the object's bounding box. 183 | instance_masks: instance segmentation masks. 184 | instance_boundaries: instance boundaries. 185 | instance_classes: Classes for each instance segmentation mask. 186 | detection_class_label: class label in numbers. 187 | detection_bbox_ymin: ymin coordinates of a detection box. 188 | detection_bbox_xmin: xmin coordinates of a detection box. 189 | detection_bbox_ymax: ymax coordinates of a detection box. 190 | detection_bbox_xmax: xmax coordinates of a detection box. 191 | detection_score: detection score for the class label and box. 
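In practice these classes serve as shared dictionary keys between detector output and the evaluation code; a hypothetical detections dict, with made-up values, might be assembled like this:
```
import numpy as np
from slowfast.utils.ava_evaluation import standard_fields

fields = standard_fields.DetectionResultFields
detections = {
    fields.detection_boxes: np.array([[0.1, 0.1, 0.9, 0.9]], dtype=np.float32),
    fields.detection_scores: np.array([0.8], dtype=np.float32),
    fields.detection_classes: np.array([12], dtype=np.int64),
}
print(sorted(detections))  # ['detection_boxes', 'detection_classes', 'detection_scores']
```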
192 | """ 193 | 194 | image_encoded = "image/encoded" 195 | image_format = "image/format" # format is reserved keyword 196 | filename = "image/filename" 197 | channels = "image/channels" 198 | colorspace = "image/colorspace" 199 | height = "image/height" 200 | width = "image/width" 201 | source_id = "image/source_id" 202 | object_class_text = "image/object/class/text" 203 | object_class_label = "image/object/class/label" 204 | object_bbox_ymin = "image/object/bbox/ymin" 205 | object_bbox_xmin = "image/object/bbox/xmin" 206 | object_bbox_ymax = "image/object/bbox/ymax" 207 | object_bbox_xmax = "image/object/bbox/xmax" 208 | object_view = "image/object/view" 209 | object_truncated = "image/object/truncated" 210 | object_occluded = "image/object/occluded" 211 | object_difficult = "image/object/difficult" 212 | object_group_of = "image/object/group_of" 213 | object_depiction = "image/object/depiction" 214 | object_is_crowd = "image/object/is_crowd" 215 | object_segment_area = "image/object/segment/area" 216 | object_weight = "image/object/weight" 217 | instance_masks = "image/segmentation/object" 218 | instance_boundaries = "image/boundaries/object" 219 | instance_classes = "image/segmentation/object/class" 220 | detection_class_label = "image/detection/label" 221 | detection_bbox_ymin = "image/detection/bbox/ymin" 222 | detection_bbox_xmin = "image/detection/bbox/xmin" 223 | detection_bbox_ymax = "image/detection/bbox/ymax" 224 | detection_bbox_xmax = "image/detection/bbox/xmax" 225 | detection_score = "image/detection/score" 226 | -------------------------------------------------------------------------------- /slowfast/utils/benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Functions for benchmarks. 4 | """ 5 | 6 | import numpy as np 7 | import pprint 8 | import torch 9 | import tqdm 10 | from fvcore.common.timer import Timer 11 | 12 | import slowfast.utils.logging as logging 13 | import slowfast.utils.misc as misc 14 | from slowfast.datasets import loader 15 | from slowfast.utils.env import setup_environment 16 | 17 | logger = logging.get_logger(__name__) 18 | 19 | 20 | def benchmark_data_loading(cfg): 21 | """ 22 | Benchmark the speed of data loading in PySlowFast. 23 | Args: 24 | 25 | cfg (CfgNode): configs. Details can be found in 26 | slowfast/config/defaults.py 27 | """ 28 | # Set up environment. 29 | setup_environment() 30 | # Set random seed from configs. 31 | np.random.seed(cfg.RNG_SEED) 32 | torch.manual_seed(cfg.RNG_SEED) 33 | 34 | # Setup logging format. 35 | logging.setup_logging(cfg.OUTPUT_DIR) 36 | 37 | # Print config. 38 | logger.info("Benchmark data loading with config:") 39 | logger.info(pprint.pformat(cfg)) 40 | 41 | timer = Timer() 42 | dataloader = loader.construct_loader(cfg, "train") 43 | logger.info( 44 | "Initialize loader using {:.2f} seconds.".format(timer.seconds()) 45 | ) 46 | # Total batch size across different machines. 47 | batch_size = cfg.TRAIN.BATCH_SIZE * cfg.NUM_SHARDS 48 | log_period = cfg.BENCHMARK.LOG_PERIOD 49 | epoch_times = [] 50 | # Test for a few epochs. 
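A rough sketch of how this benchmark might be invoked from a small driver script; the script name, config path, and BENCHMARK values below are illustrative, not repository defaults:
```
from slowfast.utils.benchmark import benchmark_data_loading
from slowfast.utils.parser import load_config, parse_args

# e.g. python my_benchmark.py --cfg path/to/config.yaml
args = parse_args()
cfg = load_config(args)
cfg.BENCHMARK.NUM_EPOCHS = 2    # keep the benchmark short
cfg.BENCHMARK.LOG_PERIOD = 50   # report timing and RAM every 50 iterations
cfg.BENCHMARK.SHUFFLE = True
benchmark_data_loading(cfg)
```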
51 | for cur_epoch in range(cfg.BENCHMARK.NUM_EPOCHS): 52 | timer = Timer() 53 | timer_epoch = Timer() 54 | iter_times = [] 55 | if cfg.BENCHMARK.SHUFFLE: 56 | loader.shuffle_dataset(dataloader, cur_epoch) 57 | for cur_iter, _ in enumerate(tqdm.tqdm(dataloader)): 58 | if cur_iter > 0 and cur_iter % log_period == 0: 59 | iter_times.append(timer.seconds()) 60 | ram_usage, ram_total = misc.cpu_mem_usage() 61 | logger.info( 62 | "Epoch {}: {} iters ({} videos) in {:.2f} seconds. " 63 | "RAM Usage: {:.2f}/{:.2f} GB.".format( 64 | cur_epoch, 65 | log_period, 66 | log_period * batch_size, 67 | iter_times[-1], 68 | ram_usage, 69 | ram_total, 70 | ) 71 | ) 72 | timer.reset() 73 | epoch_times.append(timer_epoch.seconds()) 74 | ram_usage, ram_total = misc.cpu_mem_usage() 75 | logger.info( 76 | "Epoch {}: in total {} iters ({} videos) in {:.2f} seconds. " 77 | "RAM Usage: {:.2f}/{:.2f} GB.".format( 78 | cur_epoch, 79 | len(dataloader), 80 | len(dataloader) * batch_size, 81 | epoch_times[-1], 82 | ram_usage, 83 | ram_total, 84 | ) 85 | ) 86 | logger.info( 87 | "Epoch {}: on average every {} iters ({} videos) take {:.2f}/{:.2f} " 88 | "(avg/std) seconds.".format( 89 | cur_epoch, 90 | log_period, 91 | log_period * batch_size, 92 | np.mean(iter_times), 93 | np.std(iter_times), 94 | ) 95 | ) 96 | logger.info( 97 | "On average every epoch ({} videos) takes {:.2f}/{:.2f} " 98 | "(avg/std) seconds.".format( 99 | len(dataloader) * batch_size, 100 | np.mean(epoch_times), 101 | np.std(epoch_times), 102 | ) 103 | ) 104 | -------------------------------------------------------------------------------- /slowfast/utils/bn_helper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """bn helper.""" 5 | 6 | import itertools 7 | import torch 8 | 9 | 10 | @torch.no_grad() 11 | def compute_and_update_bn_stats(model, data_loader, num_batches=200): 12 | """ 13 | Compute and update the batch norm stats to make it more precise. During 14 | training both bn stats and the weight are changing after every iteration, 15 | so the bn can not precisely reflect the latest stats of the current model. 16 | Here the bn stats is recomputed without change of weights, to make the 17 | running mean and running var more precise. 18 | Args: 19 | model (model): the model using to compute and update the bn stats. 20 | data_loader (dataloader): dataloader using to provide inputs. 21 | num_batches (int): running iterations using to compute the stats. 22 | """ 23 | 24 | # Prepares all the bn layers. 25 | bn_layers = [ 26 | m 27 | for m in model.modules() 28 | if any( 29 | ( 30 | isinstance(m, bn_type) 31 | for bn_type in ( 32 | torch.nn.BatchNorm1d, 33 | torch.nn.BatchNorm2d, 34 | torch.nn.BatchNorm3d, 35 | ) 36 | ) 37 | ) 38 | ] 39 | 40 | # In order to make the running stats only reflect the current batch, the 41 | # momentum is disabled. 42 | # bn.running_mean = (1 - momentum) * bn.running_mean + momentum * batch_mean 43 | # Setting the momentum to 1.0 to compute the stats without momentum. 44 | momentum_actual = [bn.momentum for bn in bn_layers] 45 | for bn in bn_layers: 46 | bn.momentum = 1.0 47 | 48 | # Calculates the running iterations for precise stats computation. 
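The accumulators initialized below implement a streaming average of the per-batch mean and of E(x^2). A tiny standalone check of the two identities relied on (illustration only, with made-up per-batch means):
```
import numpy as np

xs = np.array([2.0, 4.0, 9.0])             # pretend per-batch means
m = 0.0
for k, x in enumerate(xs, start=1):
    m += (x - m) / k                        # same update rule as running_mean below
assert np.isclose(m, xs.mean())             # streaming update equals the arithmetic mean (5.0)

# Var(x) = E(x^2) - E(x)^2, which is how running_var is recovered afterwards.
assert np.isclose((xs ** 2).mean() - xs.mean() ** 2, xs.var())
```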
49 | running_mean = [torch.zeros_like(bn.running_mean) for bn in bn_layers] 50 | running_square_mean = [torch.zeros_like(bn.running_var) for bn in bn_layers] 51 | 52 | for ind, (inputs, _, _) in enumerate( 53 | itertools.islice(data_loader, num_batches) 54 | ): 55 | # Forwards the model to update the bn stats. 56 | if isinstance(inputs, (list,)): 57 | for i in range(len(inputs)): 58 | inputs[i] = inputs[i].float().cuda(non_blocking=True) 59 | else: 60 | inputs = inputs.cuda(non_blocking=True) 61 | model(inputs) 62 | 63 | for i, bn in enumerate(bn_layers): 64 | # Accumulates the bn stats. 65 | running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1) 66 | # $E(x^2) = Var(x) + E(x)^2$. 67 | cur_square_mean = bn.running_var + bn.running_mean ** 2 68 | running_square_mean[i] += ( 69 | cur_square_mean - running_square_mean[i] 70 | ) / (ind + 1) 71 | 72 | for i, bn in enumerate(bn_layers): 73 | bn.running_mean = running_mean[i] 74 | # Var(x) = $E(x^2) - E(x)^2$. 75 | bn.running_var = running_square_mean[i] - bn.running_mean ** 2 76 | # Sets the precise bn stats. 77 | bn.momentum = momentum_actual[i] 78 | -------------------------------------------------------------------------------- /slowfast/utils/custom_platform.py: -------------------------------------------------------------------------------- 1 | # Modified by AWS AI Labs on 07/15/2022 2 | 3 | # The codes below partially refer to the PySceneDetect. According 4 | # to its BSD 3-Clause License, we keep the following. 5 | # 6 | # PySceneDetect: Python-Based Video Scene Detector 7 | # --------------------------------------------------------------- 8 | # [ Site: http://www.bcastell.com/projects/PySceneDetect/ ] 9 | # [ Github: https://github.com/Breakthrough/PySceneDetect/ ] 10 | # [ Documentation: http://pyscenedetect.readthedocs.org/ ] 11 | # 12 | # Copyright (C) 2014-2020 Brandon Castellano . 13 | # 14 | # PySceneDetect is licensed under the BSD 3-Clause License; see the included 15 | # LICENSE file, or visit one of the following pages for details: 16 | # - https://github.com/Breakthrough/PySceneDetect/ 17 | # - http://www.bcastell.com/projects/PySceneDetect/ 18 | # 19 | # This software uses Numpy, OpenCV, click, tqdm, simpletable, and pytest. 20 | # See the included LICENSE files or one of the above URLs for more information. 21 | # 22 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 23 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 24 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 25 | # AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN 26 | # ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 27 | # WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 28 | 29 | from __future__ import print_function 30 | import csv 31 | import os 32 | import platform 33 | import struct 34 | 35 | import cv2 36 | 37 | # tqdm Library 38 | try: 39 | from tqdm import tqdm 40 | except ImportError: 41 | tqdm = None 42 | 43 | # click/Command-Line Interface String Type 44 | # String type (used to allow FrameTimecode object to take both unicode 45 | # and native string objects when being constructed via 46 | # scenedetect.platform.STRING_TYPE). 
47 | # pylint: disable=invalid-name, undefined-variable 48 | STRING_TYPE = str 49 | # pylint: enable=invalid-name, undefined-variable 50 | 51 | # OpenCV 2.x Compatibility Fix 52 | # Compatibility fix for OpenCV v2.x (copies CAP_PROP_* properties from the 53 | # cv2.cv namespace to the cv2 namespace, as the cv2.cv namespace was removed 54 | # with the release of OpenCV 3.0). 55 | # pylint: disable=c-extension-no-member 56 | if cv2.__version__[0] == '2' or not (cv2.__version__[0].isdigit() 57 | and int(cv2.__version__[0]) >= 3): 58 | cv2.CAP_PROP_FRAME_WIDTH = cv2.cv.CV_CAP_PROP_FRAME_WIDTH 59 | cv2.CAP_PROP_FRAME_HEIGHT = cv2.cv.CV_CAP_PROP_FRAME_HEIGHT 60 | cv2.CAP_PROP_FPS = cv2.cv.CV_CAP_PROP_FPS 61 | cv2.CAP_PROP_POS_MSEC = cv2.cv.CV_CAP_PROP_POS_MSEC 62 | cv2.CAP_PROP_POS_FRAMES = cv2.cv.CV_CAP_PROP_POS_FRAMES 63 | cv2.CAP_PROP_FRAME_COUNT = cv2.cv.CV_CAP_PROP_FRAME_COUNT 64 | # pylint: enable=c-extension-no-member 65 | 66 | # OpenCV DLL Check Function (Windows Only) 67 | 68 | 69 | def check_opencv_ffmpeg_dll(): 70 | # type: () -> bool 71 | """Check OpenCV FFmpeg DLL: Checks if OpenCV video I/O support is 72 | available, on Windows only, by checking for the appropriate 73 | opencv_ffmpeg*.dll file. 74 | 75 | On non-Windows systems always returns True, or for OpenCV versions that do 76 | not follow the X.Y.Z version numbering pattern. Thus there may be false 77 | positives (True) with this function, but not false negatives (False). 78 | In those cases, PySceneDetect will report that it could not open the 79 | video file, and for Windows users, also gives an additional warning message 80 | that the error may be due to the missing DLL file. 81 | 82 | Returns: 83 | (bool) True if OpenCV video support is detected (e.g. the appropriate 84 | opencv_ffmpegXYZ.dll file is in PATH), False otherwise. 85 | """ 86 | if platform.system() == 'Windows' and (cv2.__version__[0].isdigit() 87 | and cv2.__version__.find('.') > 0): 88 | is_64_bit_str = '_64' if struct.calcsize('P') == 8 else '' 89 | dll_filename = 'opencv_ffmpeg{OPENCV_VERSION}{IS_64_BIT}.dll'.format( 90 | OPENCV_VERSION=cv2.__version__.replace('.', ''), 91 | IS_64_BIT=is_64_bit_str) 92 | return any([ 93 | os.path.exists(os.path.join(path_path, dll_filename)) 94 | for path_path in os.environ['PATH'].split(';') 95 | ]), dll_filename 96 | return True 97 | 98 | 99 | # OpenCV imwrite Supported Image Types & Quality/Compression Parameters 100 | 101 | 102 | def _get_cv2_param(param_name): 103 | ''' 104 | Args: 105 | param_name (str) 106 | 107 | Returns: 108 | Union[int, None] 109 | ''' 110 | if param_name.startswith('CV_'): 111 | param_name = param_name[3:] 112 | try: 113 | return getattr(cv2, param_name) 114 | except AttributeError: 115 | return None 116 | 117 | 118 | def get_cv2_imwrite_params(): 119 | """Get OpenCV imwrite Params: Returns a dict of supported image formats and 120 | their associated quality/compression parameter. 121 | 122 | Returns: 123 | (Dict[str, int]) Dictionary of image formats/extensions ('jpg', 124 | 'png', etc...) mapped to the respective OpenCV quality or 125 | compression parameter (e.g. 'jpg' -> cv2.IMWRITE_JPEG_QUALITY, 126 | 'png' -> cv2.IMWRITE_PNG_COMPRESSION).. 
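A small usage sketch for this lookup (the output path is illustrative; a parameter id may be None if the local OpenCV build lacks that codec):
```
import cv2
import numpy as np
from slowfast.utils.custom_platform import get_cv2_imwrite_params

params = get_cv2_imwrite_params()
frame = np.zeros((8, 8, 3), dtype=np.uint8)
if params['jpg'] is not None:
    cv2.imwrite('/tmp/frame.jpg', frame, [params['jpg'], 95])  # JPEG quality 95
```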
127 | """ 128 | return { 129 | 'jpg': _get_cv2_param('IMWRITE_JPEG_QUALITY'), 130 | 'png': _get_cv2_param('IMWRITE_PNG_COMPRESSION'), 131 | 'webp': _get_cv2_param('IMWRITE_WEBP_QUALITY') 132 | } 133 | 134 | 135 | # Python csv Module Wrapper (for StatsManager, and 136 | # CliContext/list-scenes command) 137 | 138 | 139 | def get_csv_reader(file_handle): 140 | """Returns a csv.reader object using the passed file handle. 141 | 142 | Args: 143 | file_handle (File) 144 | 145 | Returns: 146 | csv.reader 147 | """ 148 | return csv.reader(file_handle, lineterminator='\n') 149 | 150 | 151 | def get_csv_writer(file_handle): 152 | """Returns a csv.writer object using the passed file handle. 153 | Args: 154 | file_handle (File) 155 | 156 | Returns: 157 | csv.reader 158 | """ 159 | return csv.writer(file_handle, lineterminator='\n') 160 | -------------------------------------------------------------------------------- /slowfast/utils/env.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """Set up Environment.""" 5 | 6 | import slowfast.utils.logging as logging 7 | 8 | _ENV_SETUP_DONE = False 9 | 10 | 11 | def setup_environment(): 12 | global _ENV_SETUP_DONE 13 | if _ENV_SETUP_DONE: 14 | return 15 | _ENV_SETUP_DONE = True 16 | -------------------------------------------------------------------------------- /slowfast/utils/logging.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """Logging.""" 5 | 6 | import builtins 7 | import decimal 8 | import functools 9 | import logging 10 | import os 11 | import sys 12 | import simplejson 13 | from fvcore.common.file_io import PathManager 14 | 15 | import slowfast.utils.distributed as du 16 | 17 | 18 | def _suppress_print(): 19 | """ 20 | Suppresses printing from the current process. 21 | """ 22 | 23 | def print_pass(*objects, sep=" ", end="\n", file=sys.stdout, flush=False): 24 | pass 25 | 26 | builtins.print = print_pass 27 | 28 | 29 | @functools.lru_cache(maxsize=None) 30 | def _cached_log_stream(filename): 31 | return PathManager.open(filename, "a") 32 | 33 | 34 | def setup_logging(output_dir=None): 35 | """ 36 | Sets up the logging for multiple processes. Only enable the logging for the 37 | master process, and suppress logging for the non-master processes. 38 | """ 39 | # Set up logging format. 40 | _FORMAT = "[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s" 41 | 42 | if du.is_master_proc(): 43 | # Enable logging for the master process. 44 | logging.root.handlers = [] 45 | logging.basicConfig( 46 | level=logging.INFO, format=_FORMAT, stream=sys.stdout 47 | ) 48 | else: 49 | # Suppress logging for non-master processes. 
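Typical single-process usage of this module looks roughly like the sketch below (no output_dir, so only the console handler is attached; the stats values are made up):
```
import slowfast.utils.logging as logging

logging.setup_logging()
logger = logging.get_logger(__name__)
logger.info("starting job")
logging.log_json_stats({"epoch": 1, "top1_acc": 83.2})
```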
50 | _suppress_print() 51 | 52 | logger = logging.getLogger() 53 | logger.setLevel(logging.DEBUG) 54 | logger.propagate = False 55 | plain_formatter = logging.Formatter( 56 | "[%(asctime)s][%(levelname)s] %(name)s: %(lineno)4d: %(message)s", 57 | datefmt="%m/%d %H:%M:%S", 58 | ) 59 | 60 | if du.is_master_proc(): 61 | ch = logging.StreamHandler(stream=sys.stdout) 62 | ch.setLevel(logging.DEBUG) 63 | ch.setFormatter(plain_formatter) 64 | logger.addHandler(ch) 65 | 66 | if output_dir is not None and du.is_master_proc(du.get_world_size()): 67 | filename = os.path.join(output_dir, "stdout.log") 68 | fh = logging.StreamHandler(_cached_log_stream(filename)) 69 | fh.setLevel(logging.DEBUG) 70 | fh.setFormatter(plain_formatter) 71 | logger.addHandler(fh) 72 | 73 | 74 | def get_logger(name): 75 | """ 76 | Retrieve the logger with the specified name or, if name is None, return a 77 | logger which is the root logger of the hierarchy. 78 | Args: 79 | name (string): name of the logger. 80 | """ 81 | return logging.getLogger(name) 82 | 83 | 84 | def log_json_stats(stats): 85 | """ 86 | Logs json stats. 87 | Args: 88 | stats (dict): a dictionary of statistical information to log. 89 | """ 90 | stats = { 91 | k: decimal.Decimal("{:.6f}".format(v)) if isinstance(v, float) else v 92 | for k, v in stats.items() 93 | } 94 | json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True) 95 | logger = get_logger(__name__) 96 | logger.info("json_stats: {:s}".format(json_stats)) 97 | -------------------------------------------------------------------------------- /slowfast/utils/lr_policy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """Learning rate policy.""" 5 | 6 | import math 7 | 8 | 9 | def get_lr_at_epoch(cfg, cur_epoch): 10 | """ 11 | Retrieve the learning rate of the current epoch with the option to perform 12 | warm up in the beginning of the training stage. 13 | Args: 14 | cfg (CfgNode): configs. Details can be found in 15 | slowfast/config/defaults.py 16 | cur_epoch (float): the number of epoch of the current training stage. 17 | """ 18 | lr = get_lr_func(cfg.SOLVER.LR_POLICY)(cfg, cur_epoch) 19 | # Perform warm up. 20 | if cur_epoch < cfg.SOLVER.WARMUP_EPOCHS: 21 | lr_start = cfg.SOLVER.WARMUP_START_LR 22 | lr_end = get_lr_func(cfg.SOLVER.LR_POLICY)( 23 | cfg, cfg.SOLVER.WARMUP_EPOCHS 24 | ) 25 | alpha = (lr_end - lr_start) / cfg.SOLVER.WARMUP_EPOCHS 26 | lr = cur_epoch * alpha + lr_start 27 | return lr 28 | 29 | 30 | def lr_func_cosine(cfg, cur_epoch): 31 | """ 32 | Retrieve the learning rate to specified values at specified epoch with the 33 | cosine learning rate schedule. Details can be found in: 34 | Ilya Loshchilov, and Frank Hutter 35 | SGDR: Stochastic Gradient Descent With Warm Restarts. 36 | Args: 37 | cfg (CfgNode): configs. Details can be found in 38 | slowfast/config/defaults.py 39 | cur_epoch (float): the number of epoch of the current training stage. 
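As a quick numeric illustration of the warm-up logic above combined with the cosine policy defined here, using toy SOLVER values and a SimpleNamespace in place of a real CfgNode:
```
from types import SimpleNamespace
from slowfast.utils import lr_policy

cfg = SimpleNamespace(SOLVER=SimpleNamespace(
    LR_POLICY="cosine", BASE_LR=0.1, MAX_EPOCH=100,
    WARMUP_EPOCHS=5, WARMUP_START_LR=0.01))

for epoch in (0, 5, 50, 100):
    print(epoch, round(lr_policy.get_lr_at_epoch(cfg, epoch), 4))
# 0 -> 0.01 (warm-up start), 5 -> ~0.0994, 50 -> 0.05, 100 -> 0.0
```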
40 | """ 41 | return ( 42 | cfg.SOLVER.BASE_LR 43 | * (math.cos(math.pi * cur_epoch / cfg.SOLVER.MAX_EPOCH) + 1.0) 44 | * 0.5 45 | ) 46 | 47 | 48 | def lr_func_linear(cfg, cur_epoch): 49 | return ( 50 | cfg.SOLVER.BASE_LR 51 | * min(max(1.0 - cur_epoch / cfg.SOLVER.MAX_EPOCH, 0.0), 1.0) 52 | ) 53 | 54 | 55 | def lr_func_steps_with_relative_lrs(cfg, cur_epoch): 56 | """ 57 | Retrieve the learning rate to specified values at specified epoch with the 58 | steps with relative learning rate schedule. 59 | Args: 60 | cfg (CfgNode): configs. Details can be found in 61 | slowfast/config/defaults.py 62 | cur_epoch (float): the number of epoch of the current training stage. 63 | """ 64 | ind = get_step_index(cfg, cur_epoch) 65 | return cfg.SOLVER.LRS[ind] * cfg.SOLVER.BASE_LR 66 | 67 | 68 | def get_step_index(cfg, cur_epoch): 69 | """ 70 | Retrieves the lr step index for the given epoch. 71 | Args: 72 | cfg (CfgNode): configs. Details can be found in 73 | slowfast/config/defaults.py 74 | cur_epoch (float): the number of epoch of the current training stage. 75 | """ 76 | steps = cfg.SOLVER.STEPS + [cfg.SOLVER.MAX_EPOCH] 77 | for ind, step in enumerate(steps): # NoQA 78 | if cur_epoch < step: 79 | break 80 | return ind - 1 81 | 82 | 83 | def get_lr_func(lr_policy): 84 | """ 85 | Given the configs, retrieve the specified lr policy function. 86 | Args: 87 | lr_policy (string): the learning rate policy to use for the job. 88 | """ 89 | policy = "lr_func_" + lr_policy 90 | if policy not in globals(): 91 | raise NotImplementedError("Unknown LR policy: {}".format(lr_policy)) 92 | else: 93 | return globals()[policy] 94 | -------------------------------------------------------------------------------- /slowfast/utils/metrics.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Modified by AWS AI Labs on 07/15/2022 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 4 | 5 | """Functions for computing metrics.""" 6 | 7 | import torch 8 | 9 | 10 | def topks_correct(preds, labels, ks): 11 | """ 12 | Given the predictions, labels, and a list of top-k values, compute the 13 | number of correct predictions for each top-k value. 14 | 15 | Args: 16 | preds (array): array of predictions. Dimension is batchsize 17 | N x ClassNum. 18 | labels (array): array of labels. Dimension is batchsize N. 19 | ks (list): list of top-k values. For example, ks = [1, 5] correspods 20 | to top-1 and top-5. 21 | 22 | Returns: 23 | topks_correct (list): list of numbers, where the `i`-th entry 24 | corresponds to the number of top-`ks[i]` correct predictions. 25 | """ 26 | assert preds.size(0) == labels.size( 27 | 0 28 | ), "Batch dim of predictions and labels must match" 29 | # Find the top max_k predictions for each sample 30 | _top_max_k_vals, top_max_k_inds = torch.topk( 31 | preds, max(ks), dim=1, largest=True, sorted=True 32 | ) 33 | # (batch_size, max_k) -> (max_k, batch_size). 34 | top_max_k_inds = top_max_k_inds.t() 35 | top_max_k_inds = top_max_k_inds.contiguous() 36 | # (batch_size, ) -> (max_k, batch_size). 37 | rep_max_k_labels = labels.view(1, -1).expand_as(top_max_k_inds) 38 | # (i, j) = 1 if top i-th prediction for the j-th sample is correct. 39 | top_max_k_correct = top_max_k_inds.eq(rep_max_k_labels) 40 | # Compute the number of topk correct predictions for each k. 
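A quick toy check of the top-k computation in this module (two samples, three classes, values chosen for illustration):
```
import torch
from slowfast.utils import metrics

preds = torch.tensor([[0.1, 0.7, 0.2],   # predicted class 1
                      [0.6, 0.3, 0.1]])  # predicted class 0
labels = torch.tensor([1, 2])
top1, top2 = metrics.topk_accuracies(preds, labels, ks=[1, 2])
print(float(top1), float(top2))  # 50.0 50.0: the first sample is a top-1 hit, the second is missed even at k=2
```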
41 | topks_correct = [ 42 | top_max_k_correct[:k, :].view(-1).float().sum() for k in ks 43 | ] 44 | return topks_correct 45 | 46 | 47 | def topk_errors(preds, labels, ks): 48 | """ 49 | Computes the top-k error for each k. 50 | Args: 51 | preds (array): array of predictions. Dimension is N. 52 | labels (array): array of labels. Dimension is N. 53 | ks (list): list of ks to calculate the top accuracies. 54 | """ 55 | num_topks_correct = topks_correct(preds, labels, ks) 56 | return [(1.0 - x / preds.size(0)) * 100.0 for x in num_topks_correct] 57 | 58 | 59 | def topk_accuracies(preds, labels, ks): 60 | """ 61 | Computes the top-k accuracy for each k. 62 | Args: 63 | preds (array): array of predictions. Dimension is N. 64 | labels (array): array of labels. Dimension is N. 65 | ks (list): list of ks to calculate the top accuracies. 66 | """ 67 | num_topks_correct = topks_correct(preds, labels, ks) 68 | return [(x / preds.size(0)) * 100.0 for x in num_topks_correct] 69 | -------------------------------------------------------------------------------- /slowfast/utils/multigrid.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """Helper functions for multigrid training.""" 5 | 6 | import numpy as np 7 | 8 | import slowfast.utils.logging as logging 9 | 10 | logger = logging.get_logger(__name__) 11 | 12 | 13 | class MultigridSchedule(object): 14 | """ 15 | This class defines multigrid training schedule and update cfg accordingly. 16 | """ 17 | 18 | def init_multigrid(self, cfg): 19 | """ 20 | Update cfg based on multigrid settings. 21 | Args: 22 | cfg (configs): configs that contains training and multigrid specific 23 | hyperparameters. Details can be seen in 24 | slowfast/config/defaults.py. 25 | Returns: 26 | cfg (configs): the updated cfg. 27 | """ 28 | self.schedule = None 29 | # We may modify cfg.TRAIN.BATCH_SIZE, cfg.DATA.NUM_FRAMES, and 30 | # cfg.DATA.TRAIN_CROP_SIZE during training, so we store their original 31 | # value in cfg and use them as global variables. 32 | cfg.MULTIGRID.DEFAULT_B = cfg.TRAIN.BATCH_SIZE 33 | cfg.MULTIGRID.DEFAULT_T = cfg.DATA.NUM_FRAMES 34 | cfg.MULTIGRID.DEFAULT_S = cfg.DATA.TRAIN_CROP_SIZE 35 | 36 | if cfg.MULTIGRID.LONG_CYCLE: 37 | self.schedule = self.get_long_cycle_schedule(cfg) 38 | cfg.SOLVER.STEPS = [0] + [s[-1] for s in self.schedule] 39 | # Fine-tuning phase. 40 | cfg.SOLVER.STEPS[-1] = ( 41 | cfg.SOLVER.STEPS[-2] + cfg.SOLVER.STEPS[-1] 42 | ) // 2 43 | cfg.SOLVER.LRS = [ 44 | cfg.SOLVER.GAMMA ** s[0] * s[1][0] for s in self.schedule 45 | ] 46 | # Fine-tuning phase. 47 | cfg.SOLVER.LRS = cfg.SOLVER.LRS[:-1] + [ 48 | cfg.SOLVER.LRS[-2], 49 | cfg.SOLVER.LRS[-1], 50 | ] 51 | 52 | cfg.SOLVER.MAX_EPOCH = self.schedule[-1][-1] 53 | 54 | elif cfg.MULTIGRID.SHORT_CYCLE: 55 | cfg.SOLVER.STEPS = [ 56 | int(s * cfg.MULTIGRID.EPOCH_FACTOR) for s in cfg.SOLVER.STEPS 57 | ] 58 | cfg.SOLVER.MAX_EPOCH = int( 59 | cfg.SOLVER.MAX_EPOCH * cfg.MULTIGRID.EPOCH_FACTOR 60 | ) 61 | return cfg 62 | 63 | def update_long_cycle(self, cfg, cur_epoch): 64 | """ 65 | Before every epoch, check if long cycle shape should change. If it 66 | should, update cfg accordingly. 67 | Args: 68 | cfg (configs): configs that contains training and multigrid specific 69 | hyperparameters. Details can be seen in 70 | slowfast/config/defaults.py. 71 | cur_epoch (int): current epoch index. 72 | Returns: 73 | cfg (configs): the updated cfg. 
74 | changed (bool): do we change long cycle shape at this epoch? 75 | """ 76 | base_b, base_t, base_s = get_current_long_cycle_shape( 77 | self.schedule, cur_epoch 78 | ) 79 | if base_s != cfg.DATA.TRAIN_CROP_SIZE or base_t != cfg.DATA.NUM_FRAMES: 80 | 81 | cfg.DATA.NUM_FRAMES = base_t 82 | cfg.DATA.TRAIN_CROP_SIZE = base_s 83 | cfg.TRAIN.BATCH_SIZE = base_b * cfg.MULTIGRID.DEFAULT_B 84 | 85 | bs_factor = ( 86 | float(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS) 87 | / cfg.MULTIGRID.BN_BASE_SIZE 88 | ) 89 | 90 | if bs_factor < 1: 91 | cfg.BN.NORM_TYPE = "sync_batchnorm" 92 | cfg.BN.NUM_SYNC_DEVICES = int(1.0 / bs_factor) 93 | elif bs_factor > 1: 94 | cfg.BN.NORM_TYPE = "sub_batchnorm" 95 | cfg.BN.NUM_SPLITS = int(bs_factor) 96 | else: 97 | cfg.BN.NORM_TYPE = "batchnorm" 98 | 99 | cfg.MULTIGRID.LONG_CYCLE_SAMPLING_RATE = cfg.DATA.SAMPLING_RATE * ( 100 | cfg.MULTIGRID.DEFAULT_T // cfg.DATA.NUM_FRAMES 101 | ) 102 | logger.info("Long cycle updates:") 103 | logger.info("\tBN.NORM_TYPE: {}".format(cfg.BN.NORM_TYPE)) 104 | if cfg.BN.NORM_TYPE == "sync_batchnorm": 105 | logger.info( 106 | "\tBN.NUM_SYNC_DEVICES: {}".format(cfg.BN.NUM_SYNC_DEVICES) 107 | ) 108 | elif cfg.BN.NORM_TYPE == "sub_batchnorm": 109 | logger.info("\tBN.NUM_SPLITS: {}".format(cfg.BN.NUM_SPLITS)) 110 | logger.info("\tTRAIN.BATCH_SIZE: {}".format(cfg.TRAIN.BATCH_SIZE)) 111 | logger.info( 112 | "\tDATA.NUM_FRAMES x LONG_CYCLE_SAMPLING_RATE: {}x{}".format( 113 | cfg.DATA.NUM_FRAMES, cfg.MULTIGRID.LONG_CYCLE_SAMPLING_RATE 114 | ) 115 | ) 116 | logger.info( 117 | "\tDATA.TRAIN_CROP_SIZE: {}".format(cfg.DATA.TRAIN_CROP_SIZE) 118 | ) 119 | return cfg, True 120 | else: 121 | return cfg, False 122 | 123 | def get_long_cycle_schedule(self, cfg): 124 | """ 125 | Based on multigrid hyperparameters, define the schedule of a long cycle. 126 | Args: 127 | cfg (configs): configs that contains training and multigrid specific 128 | hyperparameters. Details can be seen in 129 | slowfast/config/defaults.py. 130 | Returns: 131 | schedule (list): Specifies a list long cycle base shapes and their 132 | corresponding training epochs. 133 | """ 134 | 135 | steps = cfg.SOLVER.STEPS 136 | 137 | default_size = float( 138 | cfg.DATA.NUM_FRAMES * cfg.DATA.TRAIN_CROP_SIZE ** 2 139 | ) 140 | default_iters = steps[-1] 141 | 142 | # Get shapes and average batch size for each long cycle shape. 143 | avg_bs = [] 144 | all_shapes = [] 145 | for t_factor, s_factor in cfg.MULTIGRID.LONG_CYCLE_FACTORS: 146 | base_t = int(round(cfg.DATA.NUM_FRAMES * t_factor)) 147 | base_s = int(round(cfg.DATA.TRAIN_CROP_SIZE * s_factor)) 148 | if cfg.MULTIGRID.SHORT_CYCLE: 149 | shapes = [ 150 | [ 151 | base_t, 152 | cfg.MULTIGRID.DEFAULT_S 153 | * cfg.MULTIGRID.SHORT_CYCLE_FACTORS[0], 154 | ], 155 | [ 156 | base_t, 157 | cfg.MULTIGRID.DEFAULT_S 158 | * cfg.MULTIGRID.SHORT_CYCLE_FACTORS[1], 159 | ], 160 | [base_t, base_s], 161 | ] 162 | else: 163 | shapes = [[base_t, base_s]] 164 | 165 | # (T, S) -> (B, T, S) 166 | shapes = [ 167 | [int(round(default_size / (s[0] * s[1] * s[1]))), s[0], s[1]] 168 | for s in shapes 169 | ] 170 | avg_bs.append(np.mean([s[0] for s in shapes])) 171 | all_shapes.append(shapes) 172 | 173 | # Get schedule regardless of cfg.MULTIGRID.EPOCH_FACTOR. 
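To make the shape arithmetic above concrete, here is the (T, S) -> (B, T, S) conversion evaluated for one representative factor pair; the 8-frame, 224px defaults and the (0.5, 0.707) factors are illustrative:
```
default_t, default_s = 8, 224
default_size = default_t * default_s ** 2
base_t = int(round(default_t * 0.5))           # 4 frames
base_s = int(round(default_s * 0.5 ** 0.5))    # 158 px
rel_batch = int(round(default_size / (base_t * base_s * base_s)))
print(base_t, base_s, rel_batch)               # 4 158 4: shorter, smaller clips traded for ~4x relative batch size
```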
174 | total_iters = 0 175 | schedule = [] 176 | for step_index in range(len(steps) - 1): 177 | step_epochs = steps[step_index + 1] - steps[step_index] 178 | 179 | for long_cycle_index, shapes in enumerate(all_shapes): 180 | cur_epochs = ( 181 | step_epochs * avg_bs[long_cycle_index] / sum(avg_bs) 182 | ) 183 | 184 | cur_iters = cur_epochs / avg_bs[long_cycle_index] 185 | total_iters += cur_iters 186 | schedule.append((step_index, shapes[-1], cur_epochs)) 187 | 188 | iter_saving = default_iters / total_iters 189 | 190 | final_step_epochs = cfg.SOLVER.MAX_EPOCH - steps[-1] 191 | 192 | # We define the fine-tuning phase to have the same amount of iteration 193 | # saving as the rest of the training. 194 | ft_epochs = final_step_epochs / iter_saving * avg_bs[-1] 195 | 196 | schedule.append((step_index + 1, all_shapes[-1][2], ft_epochs)) 197 | 198 | # Obtrain final schedule given desired cfg.MULTIGRID.EPOCH_FACTOR. 199 | x = ( 200 | cfg.SOLVER.MAX_EPOCH 201 | * cfg.MULTIGRID.EPOCH_FACTOR 202 | / sum(s[-1] for s in schedule) 203 | ) 204 | 205 | final_schedule = [] 206 | total_epochs = 0 207 | for s in schedule: 208 | epochs = s[2] * x 209 | total_epochs += epochs 210 | final_schedule.append((s[0], s[1], int(round(total_epochs)))) 211 | print_schedule(final_schedule) 212 | return final_schedule 213 | 214 | 215 | def print_schedule(schedule): 216 | """ 217 | Log schedule. 218 | """ 219 | logger.info("Long cycle index\tBase shape\tEpochs") 220 | for s in schedule: 221 | logger.info("{}\t{}\t{}".format(s[0], s[1], s[2])) 222 | 223 | 224 | def get_current_long_cycle_shape(schedule, epoch): 225 | """ 226 | Given a schedule and epoch index, return the long cycle base shape. 227 | Args: 228 | schedule (configs): configs that contains training and multigrid specific 229 | hyperparameters. Details can be seen in 230 | slowfast/config/defaults.py. 231 | cur_epoch (int): current epoch index. 232 | Returns: 233 | shapes (list): A list describing the base shape in a long cycle: 234 | [batch size relative to default, 235 | number of frames, spatial dimension]. 236 | """ 237 | for s in schedule: 238 | if epoch < s[-1]: 239 | return s[1] 240 | return schedule[-1][1] 241 | -------------------------------------------------------------------------------- /slowfast/utils/multiprocessing.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """Multiprocessing helpers.""" 5 | 6 | import torch 7 | 8 | 9 | def run( 10 | local_rank, num_proc, func, init_method, shard_id, num_shards, backend, cfg 11 | ): 12 | """ 13 | Runs a function from a child process. 14 | Args: 15 | local_rank (int): rank of the current process on the current machine. 16 | num_proc (int): number of processes per machine. 17 | func (function): function to execute on each of the process. 18 | init_method (string): method to initialize the distributed training. 19 | TCP initialization: equiring a network address reachable from all 20 | processes followed by the port. 21 | Shared file-system initialization: makes use of a file system that 22 | is shared and visible from all machines. The URL should start with 23 | file:// and contain a path to a non-existent file on a shared file 24 | system. 25 | shard_id (int): the rank of the current machine. 26 | num_shards (int): number of overall machines for the distributed 27 | training job. 
28 | backend (string): three distributed backends ('nccl', 'gloo', 'mpi') are 29 | supports, each with different capabilities. Details can be found 30 | here: 31 | https://pytorch.org/docs/stable/distributed.html 32 | cfg (CfgNode): configs. Details can be found in 33 | slowfast/config/defaults.py 34 | """ 35 | # Initialize the process group. 36 | world_size = num_proc * num_shards 37 | rank = shard_id * num_proc + local_rank 38 | 39 | try: 40 | torch.distributed.init_process_group( 41 | backend=backend, 42 | init_method=init_method, 43 | world_size=world_size, 44 | rank=rank, 45 | ) 46 | except Exception as e: 47 | raise e 48 | 49 | torch.cuda.set_device(local_rank) 50 | func(cfg) 51 | -------------------------------------------------------------------------------- /slowfast/utils/parser.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """Argument parser functions.""" 5 | 6 | import argparse 7 | import sys 8 | 9 | import slowfast.utils.checkpoint as cu 10 | from slowfast.config.defaults import get_cfg 11 | 12 | 13 | def parse_args(): 14 | """ 15 | Parse the following arguments for a default parser for PySlowFast users. 16 | Args: 17 | shard_id (int): shard id for the current machine. Starts from 0 to 18 | num_shards - 1. If single machine is used, then set shard id to 0. 19 | num_shards (int): number of shards using by the job. 20 | init_method (str): initialization method to launch the job with multiple 21 | devices. Options includes TCP or shared file-system for 22 | initialization. details can be find in 23 | https://pytorch.org/docs/stable/distributed.html#tcp-initialization 24 | cfg (str): path to the config file. 25 | opts (argument): provide addtional options from the command line, it 26 | overwrites the config loaded from file. 27 | """ 28 | parser = argparse.ArgumentParser( 29 | description="Provide SlowFast video training and testing pipeline." 30 | ) 31 | parser.add_argument( 32 | "--shard_id", 33 | help="The shard id of current node, Starts from 0 to num_shards - 1", 34 | default=0, 35 | type=int, 36 | ) 37 | parser.add_argument( 38 | "--num_shards", 39 | help="Number of shards using by the job", 40 | default=1, 41 | type=int, 42 | ) 43 | parser.add_argument( 44 | "--init_method", 45 | help="Initialization method, includes TCP or shared file-system", 46 | default="tcp://localhost:9999", 47 | type=str, 48 | ) 49 | parser.add_argument( 50 | "--cfg", 51 | dest="cfg_file", 52 | help="Path to the config file", 53 | default="configs/Kinetics/SLOWFAST_4x16_R50.yaml", 54 | type=str, 55 | ) 56 | parser.add_argument( 57 | "--output_dir", 58 | dest="output_dir", 59 | help="Path to the outputs", 60 | default=".", 61 | type=str, 62 | ) 63 | parser.add_argument( 64 | "opts", 65 | help="See slowfast/config/defaults.py for all options", 66 | default=None, 67 | nargs=argparse.REMAINDER, 68 | ) 69 | if len(sys.argv) == 1: 70 | parser.print_help() 71 | return parser.parse_args() 72 | 73 | 74 | def load_config(args, mkdir=True): 75 | """ 76 | Given the arguemnts, load and initialize the configs. 77 | Args: 78 | args (argument): arguments includes `shard_id`, `num_shards`, 79 | `init_method`, `cfg_file`, and `opts`. 80 | """ 81 | # Setup cfg. 82 | cfg = get_cfg() 83 | # Load config from cfg. 84 | if args.cfg_file is not None: 85 | cfg.merge_from_file(args.cfg_file) 86 | # Load config from command line, overwrite config from opts. 
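A rough sketch of how an entry point typically wires these two helpers together; the script name, config path, and the trailing command-line override are illustrative:
```
from slowfast.utils.parser import load_config, parse_args

# e.g. python some_tool.py --cfg path/to/config.yaml --output_dir /tmp/out DATA.NUM_FRAMES 32
args = parse_args()
cfg = load_config(args)
print(cfg.OUTPUT_DIR, cfg.NUM_SHARDS, cfg.DATA.NUM_FRAMES)
```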
87 | if args.opts is not None: 88 | cfg.merge_from_list(args.opts) 89 | 90 | # Inherit parameters from args. 91 | if hasattr(args, "num_shards") and hasattr(args, "shard_id"): 92 | cfg.NUM_SHARDS = args.num_shards 93 | cfg.SHARD_ID = args.shard_id 94 | if hasattr(args, "rng_seed"): 95 | cfg.RNG_SEED = args.rng_seed 96 | if hasattr(args, "output_dir"): 97 | cfg.OUTPUT_DIR = args.output_dir 98 | 99 | # Create the checkpoint dir. 100 | if mkdir: 101 | cu.make_checkpoint_dir(cfg.OUTPUT_DIR) 102 | 103 | return cfg 104 | -------------------------------------------------------------------------------- /slowfast/utils/video_splitter.py: -------------------------------------------------------------------------------- 1 | # The codes below partially refer to the PySceneDetect. According 2 | # to its BSD 3-Clause License, we keep the following. 3 | # 4 | # PySceneDetect: Python-Based Video Scene Detector 5 | # --------------------------------------------------------------- 6 | # [ Site: http://www.bcastell.com/projects/PySceneDetect/ ] 7 | # [ Github: https://github.com/Breakthrough/PySceneDetect/ ] 8 | # [ Documentation: http://pyscenedetect.readthedocs.org/ ] 9 | # 10 | # Copyright (C) 2014-2020 Brandon Castellano . 11 | # 12 | # PySceneDetect is licensed under the BSD 3-Clause License; see the included 13 | # LICENSE file, or visit one of the following pages for details: 14 | # - https://github.com/Breakthrough/PySceneDetect/ 15 | # - http://www.bcastell.com/projects/PySceneDetect/ 16 | # 17 | # This software uses Numpy, OpenCV, click, tqdm, simpletable, and pytest. 18 | # See the included LICENSE files or one of the above URLs for more information. 19 | # 20 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | # AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN 24 | # ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 25 | # WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 26 | 27 | import logging 28 | import math 29 | import os 30 | import subprocess 31 | import time 32 | from string import Template 33 | from .custom_platform import tqdm 34 | 35 | 36 | def is_mkvmerge_available(): 37 | """Is mkvmerge Available: Gracefully checks if mkvmerge command is 38 | available. 39 | 40 | Returns: 41 | (bool) True if the mkvmerge command is available, False otherwise. 42 | """ 43 | ret_val = None 44 | try: 45 | ret_val = subprocess.call(['mkvmerge', '--quiet']) 46 | except OSError: 47 | return False 48 | if ret_val is not None and ret_val != 2: 49 | return False 50 | return True 51 | 52 | 53 | def is_ffmpeg_available(): 54 | """Is ffmpeg Available: Gracefully checks if ffmpeg command is available. 55 | 56 | Returns: 57 | (bool) True if the ffmpeg command is available, False otherwise. 58 | """ 59 | ret_val = None 60 | try: 61 | ret_val = subprocess.call(['ffmpeg', '-v', 'quiet']) 62 | except OSError: 63 | return False 64 | if ret_val is not None and ret_val != 1: 65 | return False 66 | return True 67 | 68 | 69 | def split_video_mkvmerge(input_video_paths, 70 | shot_list, 71 | output_file_prefix, 72 | video_name, 73 | suppress_output=False): 74 | """Split video. 75 | 76 | Calls the mkvmerge command on the input video(s), splitting it at 77 | the passed timecodes, where each shot is written in sequence from 78 | 001. 
79 | """ 80 | 81 | if not input_video_paths or not shot_list: 82 | return 83 | 84 | logging.info( 85 | 'Splitting input video%s using mkvmerge, output path template:\n %s', 86 | 's' if len(input_video_paths) > 1 else '', output_file_prefix) 87 | 88 | ret_val = None 89 | # mkvmerge automatically appends '-$SHOT_NUMBER'. 90 | output_file_name = output_file_prefix.replace('-${SHOT_NUMBER}', '') 91 | output_file_name = output_file_prefix.replace('-$SHOT_NUMBER', '') 92 | output_file_template = Template(output_file_name) 93 | output_file_name = output_file_template.safe_substitute( 94 | VIDEO_NAME=video_name, SHOT_NUMBER='') 95 | 96 | try: 97 | call_list = ['mkvmerge'] 98 | if suppress_output: 99 | call_list.append('--quiet') 100 | call_list += [ 101 | '-o', output_file_name, '--split', 102 | 'parts:%s' % ','.join([ 103 | '%s-%s' % (start_time.get_timecode(), end_time.get_timecode()) 104 | for start_time, end_time in shot_list 105 | ]), ' +'.join(input_video_paths) 106 | ] 107 | total_frames = shot_list[-1][1].get_frames( 108 | ) - shot_list[0][0].get_frames() 109 | processing_start_time = time.time() 110 | ret_val = subprocess.call(call_list) 111 | if not suppress_output: 112 | print('') 113 | logging.info( 114 | 'Average processing speed %.2f frames/sec.', 115 | float(total_frames) / (time.time() - processing_start_time)) 116 | except OSError: 117 | logging.error( 118 | 'mkvmerge could not be found on the system.' 119 | ' Please install mkvmerge to enable video output support.') 120 | raise 121 | if ret_val is not None and ret_val != 0: 122 | logging.error('Error splitting video (mkvmerge returned %d).', ret_val) 123 | 124 | 125 | def split_video_ffmpeg(input_video_paths, 126 | shot_list, 127 | output_dir, 128 | output_file_template=('${OUTPUT_DIR}/' 129 | 'shot_${SHOT_NUMBER}.mp4'), 130 | arg_override='-crf 21', 131 | hide_progress=False, 132 | suppress_output=False, 133 | output_names=None): 134 | """Calls the ffmpeg command on the input video(s), generating a new video 135 | for each shot based on the start/end timecodes. 136 | 137 | Args: 138 | input_video_paths (List[str]) 139 | shot_list (List[Tuple[FrameTimecode, FrameTimecode]],) 140 | """ 141 | 142 | os.makedirs(output_dir, exist_ok=True) 143 | if not input_video_paths or not shot_list: 144 | return 145 | 146 | logging.info( 147 | 'Splitting input video%s using ffmpeg, output path template:\n %s', 148 | 's' if len(input_video_paths) > 1 else '', output_file_template) 149 | if len(input_video_paths) > 1: 150 | # TODO: Add support for splitting multiple/appended input videos. 151 | # https://trac.ffmpeg.org/wiki/Concatenate#samecodec 152 | # Requires generating a temporary file list for ffmpeg. 153 | logging.error( 154 | 'Sorry, splitting multiple appended/concatenated input videos with' 155 | ' ffmpeg is not supported yet. This feature will be added to a' 156 | ' future version of ShotDetect. 
In the meantime, you can try using' 157 | ' the -c / --copy option with the split-video to use mkvmerge,' 158 | ' which generates less accurate output,' 159 | ' but supports multiple input videos.') 160 | raise NotImplementedError() 161 | 162 | arg_override = arg_override.replace('\\"', '"') 163 | 164 | ret_val = None 165 | if len(arg_override) > 0: 166 | arg_override = arg_override.split(' ') 167 | else: 168 | arg_override = [] 169 | filename_template = Template(output_file_template) 170 | shot_num_format = '%0' 171 | shot_num_format += str( 172 | max(4, 173 | math.floor(math.log(len(shot_list), 10)) + 1)) + 'd' 174 | try: 175 | progress_bar = None 176 | total_frames = shot_list[-1][1].get_frames( 177 | ) - shot_list[0][0].get_frames() 178 | if tqdm and not hide_progress: 179 | progress_bar = tqdm( 180 | total=total_frames, 181 | unit='frame', 182 | miniters=1, 183 | desc='Split Video') 184 | processing_start_time = time.time() 185 | for i, (start_time, end_time) in enumerate(shot_list): 186 | end_time = end_time.__sub__(1) 187 | # Fix the last frame of a shot to be 1 less than the first frame 188 | # of the next shot 189 | duration = (end_time - start_time) 190 | # an alternative way to do it 191 | # duration = (end_time.get_frames()-1)/end_time.framerate - 192 | # (start_time.get_frames())/start_time.framerate 193 | # duration_frame = end_time.get_frames() - 1 - \ 194 | # start_time.get_frames() 195 | call_list = ['ffmpeg'] 196 | if suppress_output: 197 | call_list += ['-v', 'quiet'] 198 | elif i > 0: 199 | # Only show ffmpeg output for the first call, which will 200 | # display any errors if it fails, and then break the loop. 201 | # We only show error messages for the remaining calls. 202 | call_list += ['-v', 'error'] 203 | call_list += [ 204 | '-y', '-ss', 205 | start_time.get_timecode(), '-i', input_video_paths[0] 206 | ] 207 | call_list += arg_override # compress 208 | call_list += ['-map_chapters', '-1'] # remove meta stream 209 | if output_names is None: 210 | output_name = filename_template.safe_substitute(OUTPUT_DIR=output_dir, SHOT_NUMBER=shot_num_format % (i)) 211 | else: 212 | output_name = output_names[i] 213 | call_list += [ 214 | '-strict', '-2', '-t', 215 | duration.get_timecode(), '-sn', 216 | output_name 217 | ] 218 | ret_val = subprocess.call(call_list) 219 | if not suppress_output and i == 0 and len(shot_list) > 1: 220 | logging.info( 221 | 'Output from ffmpeg for shot 1 shown above, splitting \ 222 | remaining shots...') 223 | if ret_val != 0: 224 | break 225 | if progress_bar: 226 | progress_bar.update( 227 | duration.get_frames() + 228 | 1) # to compensate the missing one frame caused above 229 | if progress_bar: 230 | print('') 231 | logging.info( 232 | 'Average processing speed %.2f frames/sec.', 233 | float(total_frames) / (time.time() - processing_start_time)) 234 | except OSError: 235 | logging.error('ffmpeg could not be found on the system.' 
236 | ' Please install ffmpeg to enable video output support.') 237 | if ret_val is not None and ret_val != 0: 238 | logging.error('Error splitting video (ffmpeg returned %d).', ret_val) 239 | -------------------------------------------------------------------------------- /slowfast/utils/viterbi.py: -------------------------------------------------------------------------------- 1 | # Modified by AWS AI Labs on 07/15/2022 2 | 3 | ''' 4 | Adapted from https://ben.bolte.cc/viterbi 5 | ''' 6 | 7 | import numpy as np 8 | from typing import List, Tuple 9 | 10 | 11 | def step(mu_prev: np.ndarray, 12 | unary: np.ndarray, 13 | binary: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 14 | pre_max = mu_prev + binary.T 15 | max_prev_states = np.argmax(pre_max, axis=1) 16 | max_vals = pre_max[np.arange(len(max_prev_states)), max_prev_states] 17 | mu_new = max_vals + unary 18 | 19 | return mu_new, max_prev_states 20 | 21 | 22 | def viterbi(unary: np.ndarray, 23 | binary: np.ndarray) -> Tuple[List[int], float]: 24 | # Runs the forward pass, storing the most likely previous state. 25 | mu = unary[:, 0] 26 | all_prev_states = [] 27 | for step_idx in range(1, unary.shape[1]): 28 | mu, prevs = step(mu, unary[:, step_idx], binary) 29 | all_prev_states.append(prevs) 30 | 31 | # Traces backwards to get the maximum likelihood sequence. 32 | state = np.argmax(mu) 33 | sequence_reward = mu[state] 34 | state_sequence = [state] 35 | for prev_states in all_prev_states[::-1]: 36 | state = prev_states[state] 37 | state_sequence.append(state) 38 | 39 | return state_sequence[::-1], sequence_reward 40 | 41 | 42 | def main(): 43 | # Setup a toy example. 44 | num_states = 3 45 | num_time_steps = 4 46 | 47 | # Initialize unary and binary terms for viterbi decoding. 48 | np.random.seed(777) 49 | unary = np.random.rand(num_states, num_time_steps) 50 | unary[1, 0] = -10.0 51 | # binary = np.array([ 52 | # [0.1, 0.2, 0.7], 53 | # [0.1, 0.1, 0.8], 54 | # [0.5, 0.4, 0.1], 55 | # ]) 56 | binary = np.diag(np.ones(num_states)) * 100.0 57 | assert binary.shape == (num_states, num_states) 58 | 59 | # Placeholder defining how we'll call the Viterbi algorithm. 60 | max_seq, seq_reward = viterbi(unary, binary) 61 | 62 | print(max_seq) 63 | print(seq_reward) 64 | 65 | 66 | if __name__ == '__main__': 67 | main() -------------------------------------------------------------------------------- /slowfast/utils/weight_init_helper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """Utility function for weight initialization""" 5 | 6 | import torch.nn as nn 7 | from fvcore.nn.weight_init import c2_msra_fill 8 | 9 | 10 | def init_weights(model, fc_init_std=0.01, zero_init_final_bn=True): 11 | """ 12 | Performs ResNet style weight initialization. 13 | Args: 14 | fc_init_std (float): the expected standard deviation for fc layer. 15 | zero_init_final_bn (bool): if True, zero initialize the final bn for 16 | every bottleneck. 17 | """ 18 | for m in model.modules(): 19 | if isinstance(m, nn.Conv3d): 20 | """ 21 | Follow the initialization method proposed in: 22 | {He, Kaiming, et al. 23 | "Delving deep into rectifiers: Surpassing human-level 24 | performance on imagenet classification." 
25 | arXiv preprint arXiv:1502.01852 (2015)} 26 | """ 27 | c2_msra_fill(m) 28 | elif isinstance(m, nn.BatchNorm3d): 29 | if ( 30 | hasattr(m, "transform_final_bn") 31 | and m.transform_final_bn 32 | and zero_init_final_bn 33 | ): 34 | batchnorm_weight = 0.0 35 | else: 36 | batchnorm_weight = 1.0 37 | if m.weight is not None: 38 | m.weight.data.fill_(batchnorm_weight) 39 | if m.bias is not None: 40 | m.bias.data.zero_() 41 | if isinstance(m, nn.Linear): 42 | m.weight.data.normal_(mean=0.0, std=fc_init_std) 43 | m.bias.data.zero_() 44 | -------------------------------------------------------------------------------- /slowfast/visualization/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | -------------------------------------------------------------------------------- /slowfast/visualization/async_predictor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | import atexit 5 | import numpy as np 6 | import torch 7 | import torch.multiprocessing as mp 8 | 9 | import slowfast.utils.logging as logging 10 | from slowfast.datasets import cv2_transform 11 | from slowfast.visualization.predictor import Predictor 12 | 13 | logger = logging.get_logger(__name__) 14 | 15 | 16 | class AsycnActionPredictor: 17 | class _Predictor(mp.Process): 18 | def __init__(self, cfg, task_queue, result_queue, gpu_id=None): 19 | """ 20 | Predict Worker for Detectron2. 21 | Args: 22 | cfg (CfgNode): configs. Details can be found in 23 | slowfast/config/defaults.py 24 | task_queue (mp.Queue): a shared queue for incoming task. 25 | result_queue (mp.Queue): a shared queue for predicted results. 26 | gpu_id (int): index of the GPU device for the current child process. 27 | """ 28 | super().__init__() 29 | self.cfg = cfg 30 | self.task_queue = task_queue 31 | self.result_queue = result_queue 32 | self.gpu_id = gpu_id 33 | 34 | self.device = ( 35 | torch.device("cuda:{}".format(self.gpu_id)) 36 | if self.cfg.NUM_GPUS 37 | else "cpu" 38 | ) 39 | 40 | def run(self): 41 | """ 42 | Run prediction asynchronously. 43 | """ 44 | # Build the video model and print model statistics. 45 | model = Predictor(self.cfg, gpu_id=self.gpu_id) 46 | while True: 47 | task = self.task_queue.get() 48 | if isinstance(task, _StopToken): 49 | break 50 | task = model(task) 51 | self.result_queue.put(task) 52 | 53 | def __init__(self, cfg, result_queue=None): 54 | num_workers = cfg.NUM_GPUS 55 | 56 | self.task_queue = mp.Queue() 57 | self.result_queue = mp.Queue() if result_queue is None else result_queue 58 | 59 | self.get_idx = -1 60 | self.put_idx = -1 61 | self.procs = [] 62 | cfg = cfg.clone() 63 | cfg.defrost() 64 | cfg.NUM_GPUS = 1 65 | for gpu_id in range(num_workers): 66 | self.procs.append( 67 | AsycnActionPredictor._Predictor( 68 | cfg, self.task_queue, self.result_queue, gpu_id 69 | ) 70 | ) 71 | 72 | self.result_data = {} 73 | for p in self.procs: 74 | p.start() 75 | atexit.register(self.shutdown) 76 | 77 | def put(self, task): 78 | """ 79 | Add the new task to task queue. 80 | """ 81 | self.put_idx += 1 82 | self.task_queue.put(task) 83 | 84 | def get(self): 85 | """ 86 | Return a task object in the correct order based on task id if 87 | result(s) is available. Otherwise, raise queue.Empty exception. 
88 | """ 89 | if self.result_data.get(self.get_idx + 1) is not None: 90 | self.get_idx += 1 91 | res = self.result_data[self.get_idx] 92 | del self.result_data[self.get_idx] 93 | return res 94 | while True: 95 | res = self.result_queue.get(block=False) 96 | idx = res.id 97 | if idx == self.get_idx + 1: 98 | self.get_idx += 1 99 | return res 100 | self.result_data[idx] = res 101 | 102 | def __call__(self, task): 103 | self.put(task) 104 | return self.get() 105 | 106 | def shutdown(self): 107 | for _ in self.procs: 108 | self.task_queue.put(_StopToken()) 109 | 110 | @property 111 | def result_available(self): 112 | """ 113 | How many results are ready to be returned. 114 | """ 115 | return self.result_queue.qsize() + len(self.result_data) 116 | 117 | @property 118 | def default_buffer_size(self): 119 | return len(self.procs) * 5 120 | 121 | 122 | class AsyncVis: 123 | class _VisWorker(mp.Process): 124 | def __init__(self, video_vis, task_queue, result_queue): 125 | """ 126 | Visualization Worker for AsyncVis. 127 | Args: 128 | video_vis (VideoVisualizer object): object with tools for visualization. 129 | task_queue (mp.Queue): a shared queue for incoming task for visualization. 130 | result_queue (mp.Queue): a shared queue for visualized results. 131 | """ 132 | self.video_vis = video_vis 133 | self.task_queue = task_queue 134 | self.result_queue = result_queue 135 | super().__init__() 136 | 137 | def run(self): 138 | """ 139 | Run visualization asynchronously. 140 | """ 141 | while True: 142 | task = self.task_queue.get() 143 | if isinstance(task, _StopToken): 144 | break 145 | 146 | frames = draw_predictions(task, self.video_vis) 147 | frames = np.array(frames) 148 | self.result_queue.put((task.id, frames)) 149 | 150 | def __init__(self, video_vis, n_workers=None): 151 | """ 152 | Args: 153 | cfg (CfgNode): configs. Details can be found in 154 | slowfast/config/defaults.py 155 | n_workers (Optional[int]): number of CPUs for running video visualizer. 156 | If not given, use all CPUs. 157 | """ 158 | 159 | num_workers = mp.cpu_count() if n_workers is None else n_workers 160 | 161 | self.task_queue = mp.Queue() 162 | self.result_queue = mp.Queue() 163 | 164 | self.get_idx = -1 165 | self.procs = [] 166 | self.result_data = {} 167 | self.put_id = -1 168 | for _ in range(max(num_workers, 1)): 169 | self.procs.append( 170 | AsyncVis._VisWorker( 171 | video_vis, self.task_queue, self.result_queue 172 | ) 173 | ) 174 | 175 | for p in self.procs: 176 | p.start() 177 | 178 | atexit.register(self.shutdown) 179 | 180 | def put(self, task): 181 | """ 182 | Add the new task to task queue. 183 | """ 184 | self.put_id += 1 185 | self.task_queue.put(task) 186 | 187 | def get(self): 188 | """ 189 | Return visualized frames/clips in the correct order based on task id if 190 | result(s) is available. Otherwise, raise queue.Empty exception. 191 | """ 192 | if self.result_data.get(self.get_idx + 1) is not None: 193 | self.get_idx += 1 194 | res = self.result_data[self.get_idx] 195 | del self.result_data[self.get_idx] 196 | return res 197 | 198 | while True: 199 | idx, res = self.result_queue.get(block=False) 200 | if idx == self.get_idx + 1: 201 | self.get_idx += 1 202 | return res 203 | self.result_data[idx] = res 204 | 205 | def __call__(self, task): 206 | """ 207 | How many results are ready to be returned. 
208 | """ 209 | self.put(task) 210 | return self.get() 211 | 212 | def shutdown(self): 213 | for _ in self.procs: 214 | self.task_queue.put(_StopToken()) 215 | 216 | @property 217 | def result_available(self): 218 | return self.result_queue.qsize() + len(self.result_data) 219 | 220 | @property 221 | def default_buffer_size(self): 222 | return len(self.procs) * 5 223 | 224 | 225 | class _StopToken: 226 | pass 227 | 228 | 229 | def draw_predictions(task, video_vis): 230 | """ 231 | Draw prediction for the given task. 232 | Args: 233 | task (TaskInfo object): task object that contain 234 | the necessary information for visualization. (e.g. frames, preds) 235 | All attributes must lie on CPU devices. 236 | video_vis (VideoVisualizer object): the video visualizer object. 237 | """ 238 | boxes = task.bboxes 239 | frames = task.frames 240 | preds = task.action_preds 241 | if boxes is not None: 242 | img_width = task.img_width 243 | img_height = task.img_height 244 | if boxes.device != torch.device("cpu"): 245 | boxes = boxes.cpu() 246 | boxes = cv2_transform.revert_scaled_boxes( 247 | task.crop_size, boxes, img_height, img_width 248 | ) 249 | 250 | keyframe_idx = len(frames) // 2 - task.num_buffer_frames 251 | draw_range = [ 252 | keyframe_idx - task.clip_vis_size, 253 | keyframe_idx + task.clip_vis_size, 254 | ] 255 | frames = frames[task.num_buffer_frames :] 256 | if boxes is not None: 257 | if len(boxes) != 0: 258 | frames = video_vis.draw_clip_range( 259 | frames, 260 | preds, 261 | boxes, 262 | keyframe_idx=keyframe_idx, 263 | draw_range=draw_range, 264 | ) 265 | else: 266 | frames = video_vis.draw_clip_range( 267 | frames, preds, keyframe_idx=keyframe_idx, draw_range=draw_range 268 | ) 269 | del task 270 | 271 | return frames 272 | -------------------------------------------------------------------------------- /slowfast/visualization/demo_loader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | import cv2 5 | 6 | from slowfast.visualization.utils import TaskInfo 7 | 8 | 9 | class VideoReader: 10 | """ 11 | VideoReader object for getting frames from video source for real-time inference. 12 | """ 13 | 14 | def __init__(self, cfg): 15 | """ 16 | Args: 17 | cfg (CfgNode): configs. Details can be found in 18 | slowfast/config/defaults.py 19 | """ 20 | assert ( 21 | cfg.DEMO.WEBCAM > -1 or cfg.DEMO.INPUT_VIDEO != "" 22 | ), "Must specify a data source as input." 
23 | 24 | self.source = ( 25 | cfg.DEMO.WEBCAM if cfg.DEMO.WEBCAM > -1 else cfg.DEMO.INPUT_VIDEO 26 | ) 27 | 28 | self.display_width = cfg.DEMO.DISPLAY_WIDTH 29 | self.display_height = cfg.DEMO.DISPLAY_HEIGHT 30 | 31 | self.cap = cv2.VideoCapture(self.source) 32 | 33 | if self.display_width > 0 and self.display_height > 0: 34 | self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, self.display_width) 35 | self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, self.display_height) 36 | else: 37 | self.display_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 38 | self.display_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 39 | 40 | if not self.cap.isOpened(): 41 | raise IOError("Video {} cannot be opened".format(self.source)) 42 | 43 | self.output_file = None 44 | if cfg.DEMO.OUTPUT_FILE != "": 45 | if cfg.DEMO.OUTPUT_FPS == -1: 46 | output_fps = self.cap.get(cv2.CAP_PROP_FPS) 47 | else: 48 | output_fps = cfg.DEMO.OUTPUT_FPS 49 | self.output_file = self.get_output_file( 50 | cfg.DEMO.OUTPUT_FILE, fps=output_fps 51 | ) 52 | self.id = -1 53 | self.buffer = [] 54 | self.buffer_size = cfg.DEMO.BUFFER_SIZE 55 | self.seq_length = cfg.DATA.NUM_FRAMES * cfg.DATA.SAMPLING_RATE 56 | self.test_crop_size = cfg.DATA.TEST_CROP_SIZE 57 | self.clip_vis_size = cfg.DEMO.CLIP_VIS_SIZE 58 | 59 | def __iter__(self): 60 | return self 61 | 62 | def __next__(self): 63 | """ 64 | Read and return the required number of frames for 1 clip. 65 | Returns: 66 | was_read (bool): False if not enough frames to return. 67 | task (TaskInfo object): object contains metadata for the current clips. 68 | """ 69 | self.id += 1 70 | task = TaskInfo() 71 | 72 | task.img_height = self.display_height 73 | task.img_width = self.display_width 74 | task.crop_size = self.test_crop_size 75 | task.clip_vis_size = self.clip_vis_size 76 | 77 | frames = [] 78 | if len(self.buffer) != 0: 79 | frames = self.buffer 80 | was_read = True 81 | while was_read and len(frames) < self.seq_length: 82 | was_read, frame = self.cap.read() 83 | frames.append(frame) 84 | if was_read and self.buffer_size != 0: 85 | self.buffer = frames[-self.buffer_size :] 86 | 87 | task.add_frames(self.id, frames) 88 | task.num_buffer_frames = 0 if self.id == 0 else self.buffer_size 89 | 90 | return was_read, task 91 | 92 | def get_output_file(self, path, fps=30): 93 | """ 94 | Return a video writer object. 95 | Args: 96 | path (str): path to the output video file. 97 | fps (int or float): frames per second. 98 | """ 99 | return cv2.VideoWriter( 100 | filename=path, 101 | fourcc=cv2.VideoWriter_fourcc(*"mp4v"), 102 | fps=float(fps), 103 | frameSize=(self.display_width, self.display_height), 104 | isColor=True, 105 | ) 106 | 107 | def display(self, frame): 108 | """ 109 | Either display a single frame (BGR image) to a window or write to 110 | an output file if output path is provided. 111 | """ 112 | if self.output_file is None: 113 | cv2.imshow("SlowFast", frame) 114 | else: 115 | self.output_file.write(frame) 116 | 117 | def clean(self): 118 | """ 119 | Clean up open video files and windows. 120 | """ 121 | self.cap.release() 122 | if self.output_file is None: 123 | cv2.destroyAllWindows() 124 | else: 125 | self.output_file.release() 126 | -------------------------------------------------------------------------------- /slowfast/visualization/gradcam_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
3 | 4 | import matplotlib.pyplot as plt 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | import slowfast.datasets.utils as data_utils 9 | from slowfast.visualization.utils import get_layer 10 | 11 | 12 | class GradCAM: 13 | """ 14 | GradCAM class helps create localization maps using the Grad-CAM method for input videos 15 | and overlap the maps over the input videos as heatmaps. 16 | https://arxiv.org/pdf/1610.02391.pdf 17 | """ 18 | 19 | def __init__( 20 | self, model, target_layers, data_mean, data_std, colormap="viridis" 21 | ): 22 | """ 23 | Args: 24 | model (model): the model to be used. 25 | target_layers (list of str(s)): name of convolutional layer to be used to get 26 | gradients and feature maps from for creating localization maps. 27 | data_mean (tensor or list): mean value to add to input videos. 28 | data_std (tensor or list): std to multiply for input videos. 29 | colormap (Optional[str]): matplotlib colormap used to create heatmap. 30 | See https://matplotlib.org/3.3.0/tutorials/colors/colormaps.html 31 | """ 32 | 33 | self.model = model 34 | # Run in eval mode. 35 | self.model.eval() 36 | self.target_layers = target_layers 37 | 38 | self.gradients = {} 39 | self.activations = {} 40 | self.colormap = plt.get_cmap(colormap) 41 | self.data_mean = data_mean 42 | self.data_std = data_std 43 | self._register_hooks() 44 | 45 | def _register_single_hook(self, layer_name): 46 | """ 47 | Register forward and backward hook to a layer, given layer_name, 48 | to obtain gradients and activations. 49 | Args: 50 | layer_name (str): name of the layer. 51 | """ 52 | 53 | def get_gradients(module, grad_input, grad_output): 54 | self.gradients[layer_name] = grad_output[0].detach() 55 | 56 | def get_activations(module, input, output): 57 | self.activations[layer_name] = output.clone().detach() 58 | 59 | target_layer = get_layer(self.model, layer_name=layer_name) 60 | target_layer.register_forward_hook(get_activations) 61 | target_layer.register_backward_hook(get_gradients) 62 | 63 | def _register_hooks(self): 64 | """ 65 | Register hooks to layers in `self.target_layers`. 66 | """ 67 | for layer_name in self.target_layers: 68 | self._register_single_hook(layer_name=layer_name) 69 | 70 | def _calculate_localization_map(self, inputs, labels=None): 71 | """ 72 | Calculate localization map for all inputs with Grad-CAM. 73 | Args: 74 | inputs (list of tensor(s)): the input clips. 75 | labels (Optional[tensor]): labels of the current input clips. 76 | Returns: 77 | localization_maps (list of ndarray(s)): the localization map for 78 | each corresponding input. 79 | preds (tensor): shape (n_instances, n_class). Model predictions for `inputs`. 80 | """ 81 | assert len(inputs) == len( 82 | self.target_layers 83 | ), "Must register the same number of target layers as the number of input pathways." 
84 | input_clone = [inp.clone() for inp in inputs] 85 | preds = self.model(input_clone) 86 | 87 | if labels is None: 88 | score = torch.max(preds, dim=-1)[0] 89 | else: 90 | if labels.ndim == 1: 91 | labels = labels.unsqueeze(-1) 92 | score = torch.gather(preds, dim=1, index=labels) 93 | 94 | self.model.zero_grad() 95 | score = torch.sum(score) 96 | score.backward() 97 | localization_maps = [] 98 | for i, inp in enumerate(inputs): 99 | _, _, T, H, W = inp.size() 100 | 101 | gradients = self.gradients[self.target_layers[i]] 102 | activations = self.activations[self.target_layers[i]] 103 | B, C, Tg, _, _ = gradients.size() 104 | 105 | weights = torch.mean(gradients.view(B, C, Tg, -1), dim=3) 106 | 107 | weights = weights.view(B, C, Tg, 1, 1) 108 | localization_map = torch.sum( 109 | weights * activations, dim=1, keepdim=True 110 | ) 111 | localization_map = F.relu(localization_map) 112 | localization_map = F.interpolate( 113 | localization_map, 114 | size=(T, H, W), 115 | mode="trilinear", 116 | align_corners=False, 117 | ) 118 | localization_map_min, localization_map_max = ( 119 | torch.min(localization_map.view(B, -1), dim=-1, keepdim=True)[ 120 | 0 121 | ], 122 | torch.max(localization_map.view(B, -1), dim=-1, keepdim=True)[ 123 | 0 124 | ], 125 | ) 126 | localization_map_min = torch.reshape( 127 | localization_map_min, shape=(B, 1, 1, 1, 1) 128 | ) 129 | localization_map_max = torch.reshape( 130 | localization_map_max, shape=(B, 1, 1, 1, 1) 131 | ) 132 | # Normalize the localization map. 133 | localization_map = (localization_map - localization_map_min) / ( 134 | localization_map_max - localization_map_min + 1e-6 135 | ) 136 | localization_map = localization_map.data 137 | 138 | localization_maps.append(localization_map) 139 | 140 | return localization_maps, preds 141 | 142 | def __call__(self, inputs, labels=None, alpha=0.5): 143 | """ 144 | Visualize the localization maps on their corresponding inputs as heatmap, 145 | using Grad-CAM. 146 | Args: 147 | inputs (list of tensor(s)): the input clips. 148 | labels (Optional[tensor]): labels of the current input clips. 149 | alpha (float): transparency level of the heatmap, in the range [0, 1]. 150 | Returns: 151 | result_ls (list of tensor(s)): the visualized inputs. 152 | preds (tensor): shape (n_instances, n_class). Model predictions for `inputs`. 
153 | """ 154 | result_ls = [] 155 | localization_maps, preds = self._calculate_localization_map( 156 | inputs, labels=labels 157 | ) 158 | for i, localization_map in enumerate(localization_maps): 159 | # Convert (B, 1, T, H, W) to (B, T, H, W) 160 | localization_map = localization_map.squeeze(dim=1) 161 | if localization_map.device != torch.device("cpu"): 162 | localization_map = localization_map.cpu() 163 | heatmap = self.colormap(localization_map) 164 | heatmap = heatmap[:, :, :, :, :3] 165 | # Permute input from (B, C, T, H, W) to (B, T, H, W, C) 166 | curr_inp = inputs[i].permute(0, 2, 3, 4, 1) 167 | if curr_inp.device != torch.device("cpu"): 168 | curr_inp = curr_inp.cpu() 169 | curr_inp = data_utils.revert_tensor_normalize( 170 | curr_inp, self.data_mean, self.data_std 171 | ) 172 | heatmap = torch.from_numpy(heatmap) 173 | curr_inp = alpha * heatmap + (1 - alpha) * curr_inp 174 | # Permute inp to (B, T, C, H, W) 175 | curr_inp = curr_inp.permute(0, 1, 4, 2, 3) 176 | result_ls.append(curr_inp) 177 | 178 | return result_ls, preds 179 | -------------------------------------------------------------------------------- /slowfast/visualization/predictor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | import cv2 5 | import torch 6 | from detectron2 import model_zoo 7 | from detectron2.config import get_cfg 8 | from detectron2.engine import DefaultPredictor 9 | 10 | import slowfast.utils.checkpoint as cu 11 | from slowfast.datasets import cv2_transform 12 | from slowfast.models import build_model 13 | from slowfast.utils import logging 14 | from slowfast.visualization.utils import process_cv2_inputs 15 | 16 | logger = logging.get_logger(__name__) 17 | 18 | 19 | class Predictor: 20 | """ 21 | Action Predictor for action recognition. 22 | """ 23 | 24 | def __init__(self, cfg, gpu_id=None): 25 | """ 26 | Args: 27 | cfg (CfgNode): configs. Details can be found in 28 | slowfast/config/defaults.py 29 | gpu_id (Optional[int]): GPU id. 30 | """ 31 | if cfg.NUM_GPUS: 32 | self.gpu_id = torch.cuda.current_device() if gpu_id is None else gpu_id 33 | 34 | # Build the video model and print model statistics. 35 | self.model = build_model(cfg, gpu_id=gpu_id) 36 | self.model.eval() 37 | self.cfg = cfg 38 | 39 | if cfg.DETECTION.ENABLE: 40 | self.object_detector = Detectron2Predictor(cfg, gpu_id=self.gpu_id) 41 | 42 | logger.info("Start loading model weights.") 43 | cu.load_test_checkpoint(cfg, self.model) 44 | logger.info("Finish loading model weights") 45 | 46 | def __call__(self, task): 47 | """ 48 | Returns the prediction results for the current task. 49 | Args: 50 | task (TaskInfo object): task object that contain 51 | the necessary information for action prediction. (e.g. frames, boxes) 52 | Returns: 53 | task (TaskInfo object): the same task info object but filled with 54 | prediction values (a tensor) and the corresponding boxes for 55 | action detection task. 
56 | """ 57 | if self.cfg.DETECTION.ENABLE: 58 | task = self.object_detector(task) 59 | 60 | frames, bboxes = task.frames, task.bboxes 61 | if bboxes is not None: 62 | bboxes = cv2_transform.scale_boxes( 63 | self.cfg.DATA.TEST_CROP_SIZE, 64 | bboxes, 65 | task.img_height, 66 | task.img_width, 67 | ) 68 | if self.cfg.DEMO.INPUT_FORMAT == "BGR": 69 | frames = [ 70 | cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in frames 71 | ] 72 | 73 | frames = [ 74 | cv2_transform.scale(self.cfg.DATA.TEST_CROP_SIZE, frame) 75 | for frame in frames 76 | ] 77 | inputs = process_cv2_inputs(frames, self.cfg) 78 | if bboxes is not None: 79 | index_pad = torch.full( 80 | size=(bboxes.shape[0], 1), 81 | fill_value=float(0), 82 | device=bboxes.device, 83 | ) 84 | 85 | # Pad frame index for each box. 86 | bboxes = torch.cat([index_pad, bboxes], axis=1) 87 | if self.cfg.NUM_GPUS > 0: 88 | # Transfer the data to the current GPU device. 89 | if isinstance(inputs, (list,)): 90 | for i in range(len(inputs)): 91 | inputs[i] = inputs[i].cuda( 92 | device=torch.device(self.gpu_id), non_blocking=True 93 | ) 94 | else: 95 | inputs = inputs.cuda( 96 | device=torch.device(self.gpu_id), non_blocking=True 97 | ) 98 | if self.cfg.DETECTION.ENABLE and not bboxes.shape[0]: 99 | preds = torch.tensor([]) 100 | else: 101 | preds = self.model(inputs, bboxes) 102 | 103 | if self.cfg.NUM_GPUS: 104 | preds = preds.cpu() 105 | if bboxes is not None: 106 | bboxes = bboxes.detach().cpu() 107 | 108 | preds = preds.detach() 109 | task.add_action_preds(preds) 110 | if bboxes is not None: 111 | task.add_bboxes(bboxes[:, 1:]) 112 | 113 | return task 114 | 115 | 116 | class ActionPredictor: 117 | """ 118 | Synchronous Action Prediction and Visualization pipeline with AsyncVis. 119 | """ 120 | def __init__(self, cfg, async_vis=None, gpu_id=None): 121 | """ 122 | Args: 123 | cfg (CfgNode): configs. Details can be found in 124 | slowfast/config/defaults.py 125 | async_vis (AsyncVis object): asynchronous visualizer. 126 | gpu_id (Optional[int]): GPU id. 127 | """ 128 | self.predictor = Predictor(cfg=cfg, gpu_id=gpu_id) 129 | self.async_vis = async_vis 130 | 131 | def put(self, task): 132 | """ 133 | Make prediction and put the results in `async_vis` task queue. 134 | Args: 135 | task (TaskInfo object): task object that contain 136 | the necessary information for action prediction. (e.g. frames, boxes) 137 | """ 138 | task = self.predictor(task) 139 | self.async_vis.put(task) 140 | 141 | 142 | class Detectron2Predictor: 143 | """ 144 | Wrapper around Detectron2 to return the required predicted bounding boxes 145 | as a ndarray. 146 | """ 147 | 148 | def __init__(self, cfg, gpu_id=None): 149 | """ 150 | Args: 151 | cfg (CfgNode): configs. Details can be found in 152 | slowfast/config/defaults.py 153 | gpu_id (Optional[int]): GPU id. 
154 | """ 155 | 156 | self.cfg = get_cfg() 157 | self.cfg.merge_from_file( 158 | model_zoo.get_config_file(cfg.DEMO.DETECTRON2_CFG) 159 | ) 160 | self.cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = cfg.DEMO.DETECTRON2_THRESH 161 | self.cfg.MODEL.WEIGHTS = cfg.DEMO.DETECTRON2_WEIGHTS 162 | self.cfg.INPUT.FORMAT = cfg.DEMO.INPUT_FORMAT 163 | if cfg.NUM_GPUS and gpu_id is None: 164 | gpu_id = torch.cuda.current_device() 165 | self.cfg.MODEL.DEVICE = ( 166 | "cuda:{}".format(gpu_id) if cfg.NUM_GPUS > 0 else "cpu" 167 | ) 168 | 169 | logger.info("Initialized Detectron2 Object Detection Model.") 170 | 171 | self.predictor = DefaultPredictor(self.cfg) 172 | 173 | def __call__(self, task): 174 | """ 175 | Return bounding boxes predictions as a tensor. 176 | Args: 177 | task (TaskInfo object): task object that contain 178 | the necessary information for action prediction. (e.g. frames, boxes) 179 | Returns: 180 | task (TaskInfo object): the same task info object but filled with 181 | prediction values (a tensor) and the corresponding boxes for 182 | action detection task. 183 | """ 184 | middle_frame = task.frames[len(task.frames) // 2] 185 | outputs = self.predictor(middle_frame) 186 | # Get only human instances 187 | mask = outputs["instances"].pred_classes == 0 188 | pred_boxes = outputs["instances"].pred_boxes.tensor[mask] 189 | task.add_bboxes(pred_boxes) 190 | 191 | return task 192 | -------------------------------------------------------------------------------- /tools/__pycache__/test_net.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/self-supervised-maclr/8a92ef0586109ad3110376e61be7e97f61f08b0d/tools/__pycache__/test_net.cpython-38.pyc -------------------------------------------------------------------------------- /tools/__pycache__/train_net.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/self-supervised-maclr/8a92ef0586109ad3110376e61be7e97f61f08b0d/tools/__pycache__/train_net.cpython-38.pyc -------------------------------------------------------------------------------- /tools/run_net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Modified by AWS AI Labs on 07/15/2022 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 4 | 5 | """Wrapper to train and test a video classification model.""" 6 | import os.path as osp 7 | 8 | from slowfast.utils.misc import launch_job 9 | from slowfast.utils.parser import load_config, parse_args 10 | import slowfast.utils.checkpoint as cu 11 | 12 | from test_net import test 13 | from train_net import train 14 | 15 | 16 | def main(): 17 | """ 18 | Main function to spawn the train and test process. 19 | """ 20 | args = parse_args() 21 | cfg = None 22 | 23 | # Create the checkpoint dir. 24 | if cfg is None: cfg = load_config(args, mkdir=False) 25 | cu.make_checkpoint_dir(cfg.OUTPUT_DIR) 26 | 27 | # Perform training. 28 | if cfg.TRAIN.ENABLE: 29 | launch_job(cfg=cfg, init_method=args.init_method, func=train) 30 | 31 | # Perform multi-clip testing. 32 | if cfg.TEST.ENABLE: 33 | launch_job(cfg=cfg, init_method=args.init_method, func=test) 34 | 35 | # Perform full conv testing for every frame. 
36 | if cfg.TEST.ENABLE_FULL_CONV_TEST: 37 | cfg.OUTPUT_DIR = osp.join(cfg.OUTPUT_DIR, 'full_test') 38 | cfg.TEST.BATCH_SIZE = 8 39 | cfg.MODEL.FULL_CONV_TEST = True 40 | cfg.DATA.FULL_CONV_NUM_FRAMES = 480 41 | cfg.DATA.FULL_CONV_AUDIO_FRAME_NUM = 1000 42 | cfg.LOG_MODEL_INFO = False 43 | launch_job(cfg=cfg, init_method=args.init_method, func=test) 44 | 45 | # Perform model visualization. 46 | if cfg.TENSORBOARD.ENABLE and cfg.TENSORBOARD.MODEL_VIS.ENABLE: 47 | from visualization import visualize 48 | launch_job(cfg=cfg, init_method=args.init_method, func=visualize) 49 | 50 | # Run demo. 51 | if cfg.DEMO.ENABLE: 52 | from demo_net import demo 53 | demo(cfg) 54 | 55 | 56 | if __name__ == "__main__": 57 | # torch.multiprocessing.set_start_method("forkserver") 58 | main() -------------------------------------------------------------------------------- /tools/test_net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Modified by AWS AI Labs on 07/15/2022 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 4 | 5 | """Multi-view test a video classification model.""" 6 | 7 | import numpy as np 8 | import torch 9 | 10 | import slowfast.utils.checkpoint as cu 11 | import slowfast.utils.distributed as du 12 | import slowfast.utils.logging as logging 13 | import slowfast.utils.misc as misc 14 | import slowfast.visualization.tensorboard_vis as tb 15 | from slowfast.datasets import loader 16 | from slowfast.models import build_model 17 | from slowfast.utils.meters import AVAMeter, TestMeter, SimpleTestMeter 18 | 19 | logger = logging.get_logger(__name__) 20 | 21 | 22 | @torch.no_grad() 23 | def perform_test(test_loader, model, test_meter, cfg): 24 | """ 25 | For classification: 26 | Perform mutli-view testing that uniformly samples N clips from a video along 27 | its temporal axis. For each clip, it takes 3 crops to cover the spatial 28 | dimension, followed by averaging the softmax scores across all Nx3 views to 29 | form a video-level prediction. All video predictions are compared to 30 | ground-truth labels and the final testing performance is logged. 31 | For detection: 32 | Perform fully-convolutional testing on the full frames without crop. 33 | Args: 34 | test_loader (loader): video testing loader. 35 | model (model): the pretrained video model to test. 36 | test_meter (TestMeter): testing meters to log and ensemble the testing 37 | results. 38 | cfg (CfgNode): configs. Details can be found in 39 | slowfast/config/defaults.py 40 | """ 41 | # Enable eval mode. 42 | model.eval() 43 | test_meter.iter_tic() 44 | 45 | for cur_iter, (inputs, labels, video_idx, meta) in enumerate(test_loader): 46 | if cfg.NUM_GPUS: 47 | # Transfer the data to the current GPU device. 48 | if isinstance(inputs, (list,)): 49 | for i in range(len(inputs)): 50 | inputs[i] = inputs[i].cuda(non_blocking=True) 51 | else: 52 | inputs = inputs.cuda(non_blocking=True) 53 | 54 | # Transfer the data to the current GPU device. 55 | labels = labels.cuda() 56 | video_idx = video_idx.cuda() 57 | for key, val in meta.items(): 58 | if isinstance(val, (list,)): 59 | for i in range(len(val)): 60 | val[i] = val[i].cuda(non_blocking=True) 61 | else: 62 | meta[key] = val.cuda(non_blocking=True) 63 | 64 | if cfg.DETECTION.ENABLE: 65 | # Compute the predictions. 
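# In the detection branch the model receives the clip together with the
# proposal boxes; predictions, original boxes and metadata are then moved
# to CPU and gathered across GPUs before the AVA meter is updated.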
66 | preds = model(inputs, meta["boxes"]) 67 | ori_boxes = meta["ori_boxes"] 68 | metadata = meta["metadata"] 69 | 70 | preds = preds.detach().cpu() if cfg.NUM_GPUS else preds.detach() 71 | ori_boxes = ( 72 | ori_boxes.detach().cpu() if cfg.NUM_GPUS else ori_boxes.detach() 73 | ) 74 | metadata = ( 75 | metadata.detach().cpu() if cfg.NUM_GPUS else metadata.detach() 76 | ) 77 | 78 | if cfg.NUM_GPUS > 1: 79 | preds = torch.cat(du.all_gather_unaligned(preds), dim=0) 80 | ori_boxes = torch.cat(du.all_gather_unaligned(ori_boxes), dim=0) 81 | metadata = torch.cat(du.all_gather_unaligned(metadata), dim=0) 82 | 83 | test_meter.iter_toc() 84 | # Update and log stats. 85 | test_meter.update_stats(preds, ori_boxes, metadata) 86 | test_meter.log_iter_stats(None, cur_iter) 87 | else: 88 | # Perform the forward pass. 89 | preds = model(inputs)['pred'] 90 | 91 | # Gather all the predictions across all the devices to perform ensemble. 92 | if cfg.NUM_GPUS > 1: 93 | preds, labels, video_idx = du.all_gather( 94 | [preds, labels, video_idx] 95 | ) 96 | for k, v in meta.items(): 97 | meta[k] = du.all_gather([v])[0] 98 | if cfg.NUM_GPUS: 99 | preds = preds.cpu() 100 | labels = labels.cpu() 101 | video_idx = video_idx.cpu() 102 | for k, v in meta.items(): 103 | meta[k] = v.cpu() 104 | test_meter.iter_toc() 105 | # Update and log stats. 106 | test_meter.update_stats( 107 | preds.detach(), labels.detach(), video_idx.detach() 108 | ) 109 | test_meter.update_meta(meta, video_idx.detach()) 110 | test_meter.log_iter_stats(cur_iter) 111 | test_meter.iter_tic() 112 | 113 | test_meter.finalize_metrics() 114 | test_meter.reset() 115 | 116 | 117 | def test(cfg): 118 | """ 119 | Perform multi-view testing on the pretrained video model. 120 | Args: 121 | cfg (CfgNode): configs. Details can be found in 122 | slowfast/config/defaults.py 123 | """ 124 | # Set up environment. 125 | du.init_distributed_training(cfg) 126 | # Set random seed from configs. 127 | np.random.seed(cfg.RNG_SEED) 128 | torch.manual_seed(cfg.RNG_SEED) 129 | 130 | # Setup logging format. 131 | logging.setup_logging(cfg.OUTPUT_DIR) 132 | 133 | # Print config. 134 | logger.info("Test with config:") 135 | logger.info(cfg) 136 | 137 | # Build the video model and print model statistics. 138 | model = build_model(cfg) 139 | if du.is_master_proc() and cfg.LOG_MODEL_INFO: 140 | misc.log_model_info(model, cfg, use_train_input=False) 141 | 142 | # Load test checkpoint. 143 | cu.load_test_checkpoint(cfg, model) 144 | 145 | # Create video testing loaders. 146 | test_loader = loader.construct_loader(cfg, cfg.TEST.SPLIT, mode="test") 147 | logger.info("Testing model for {} iterations".format(len(test_loader))) 148 | 149 | if cfg.DETECTION.ENABLE: 150 | assert cfg.NUM_GPUS == cfg.TEST.BATCH_SIZE or cfg.NUM_GPUS == 0 151 | test_meter = AVAMeter(len(test_loader), cfg, mode="test") 152 | else: 153 | # Create meters for multi-view testing. 154 | if cfg.MODEL.FULL_CONV_TEST: 155 | test_meter = SimpleTestMeter(len(test_loader), cfg.OUTPUT_DIR, cfg) 156 | else: 157 | test_meter = TestMeter( 158 | len(test_loader.dataset) 159 | // (cfg.TEST.NUM_ENSEMBLE_VIEWS * cfg.TEST.NUM_SPATIAL_CROPS), 160 | cfg.TEST.NUM_ENSEMBLE_VIEWS * cfg.TEST.NUM_SPATIAL_CROPS, 161 | cfg.MODEL.NUM_CLASSES, 162 | len(test_loader), 163 | cfg.DATA.MULTI_LABEL, 164 | cfg.TEST.EVAL_METRIC, 165 | cfg.DATA.ENSEMBLE_METHOD, 166 | cfg.OUTPUT_DIR, 167 | cfg, 168 | ) 169 | 170 | # # Perform multi-view test on the entire dataset. 
171 | perform_test(test_loader, model, test_meter, cfg) 172 | 173 | --------------------------------------------------------------------------------