├── .github └── intro.png ├── models ├── __init__.py ├── configs │ └── Kinetics │ │ ├── TimeSformer_spaceOnly_8x32_224.yaml │ │ ├── TimeSformer_divST_16x16_448.yaml │ │ ├── TimeSformer_divST_8x32_224.yaml │ │ ├── TimeSformer_divST_96x4_224.yaml │ │ ├── TimeSformer_jointST_8x32_224.yaml │ │ ├── TimeSformer_divST_8x32_224_4gpus.yaml │ │ ├── TimeSformer_divST_8x32_224_TEST.yaml │ │ ├── SLOWFAST_4x16_R50.yaml │ │ ├── SLOWFAST_8x8_R50.yaml │ │ └── SLOWFAST_8x8_R101.yaml ├── vit_utils.py ├── s3d.py └── helpers.py ├── datasets ├── __init__.py ├── preprocessing │ ├── kinetics_mini.py │ ├── verify_file_list.py │ ├── create_lists.py │ ├── check_corrupt_videos.py │ ├── downsample_kinetics.py │ ├── resize_videos.py │ └── flow_vis.py ├── video_container.py ├── build.py ├── DATASET.md ├── rand_conv.py ├── multigrid_helper.py ├── loader.py ├── ssv2.py ├── ucf101.py ├── hmdb51.py ├── data_utils.py └── kinetics.py ├── requirements.txt ├── scripts ├── eval_knn.sh ├── train.sh └── eval_linear.sh ├── utils ├── logging.py ├── parser.py ├── metrics.py └── meters.py ├── LICENSE ├── .gitignore ├── README.md ├── eval_knn.py ├── vision_transformer.py └── eval_linear.py /.github/intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kahnchana/svt/HEAD/.github/intro.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .timesformer import get_vit_base_patch16_224, get_aux_token_vit 2 | from .swin_transformer import SwinTransformer3D 3 | from .s3d import S3D 4 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
2 | 3 | from .kinetics import Kinetics # noqa 4 | from .ucf101 import UCF101 5 | from .hmdb51 import HMDB51 6 | # from .ssv2 import Ssv2 # noqa 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.18.4 2 | torch>=1.7.1 3 | torchvision>=0.2.1 4 | pillow>=5.4.1 5 | fvcore>=0.1.5 6 | sklearn>=0.0 7 | scikit-learn>=0.23.2 8 | simplejson>=3.17.5 9 | einops>=0.3.0 10 | timm>=0.4.12 11 | tqdm>=4.29.1 12 | kornia>=0.5.8 13 | opencv-python>=3.4.3.18 14 | pandas>=0.23.4 15 | joblib>=0.16.0 16 | matplotlib>=3.1.1 17 | requests>=2.25.1 18 | scikit-image>=0.15.0 19 | av>=9.2.0 -------------------------------------------------------------------------------- /scripts/eval_knn.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | PROJECT_PATH="$HOME/repo/svt" 4 | CHECKPOINT="path/to/checkpoint.pth" 5 | DATASET="ucf101" 6 | DATA_PATH="${HOME}/repo/mmaction2/data/${DATASET}" 7 | 8 | cd "$PROJECT_PATH" || exit 9 | 10 | export CUDA_VISIBLE_DEVICES=0 11 | python -m torch.distributed.launch \ 12 | --nproc_per_node=1 \ 13 | --master_port="$RANDOM" \ 14 | eval_knn.py \ 15 | --arch "vit_base" \ 16 | --pretrained_weights "$CHECKPOINT" \ 17 | --batch_size_per_gpu 128 \ 18 | --nb_knn 5 \ 19 | --temperature 0.07 \ 20 | --num_workers 4 \ 21 | --dataset "$DATASET" \ 22 | --opts \ 23 | DATA.PATH_TO_DATA_DIR "${DATA_PATH}/knn_splits" \ 24 | DATA.PATH_PREFIX "${DATA_PATH}/videos" 25 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | PROJECT_PATH="$HOME/repo/svt" 4 | DATA_PATH="$HOME/data/kinetics/400/annotations" 5 | EXP_NAME="svt_test" 6 | 7 | cd "$PROJECT_PATH" || exit 8 | 9 | if [ ! -d "checkpoints/$EXP_NAME" ]; then 10 | mkdir "checkpoints/$EXP_NAME" 11 | fi 12 | 13 | export CUDA_VISIBLE_DEVICES=0,1,2,3 14 | 15 | python -m torch.distributed.launch \ 16 | --nproc_per_node=4 \ 17 | --master_port="$RANDOM" \ 18 | train_ssl.py \ 19 | --arch "timesformer" \ 20 | --batch_size_per_gpu 8 \ 21 | --data_path "${DATA_PATH}" \ 22 | --output_dir "checkpoints/$EXP_NAME" \ 23 | --opts \ 24 | MODEL.TWO_STREAM False \ 25 | MODEL.TWO_TOKEN False \ 26 | DATA.NO_FLOW_AUG False \ 27 | DATA.USE_FLOW False \ 28 | DATA.RAND_CONV False \ 29 | DATA.NO_SPATIAL False 30 | 31 | -------------------------------------------------------------------------------- /scripts/eval_linear.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | PROJECT_PATH="$HOME/repo/svt" 4 | EXP_NAME="le_001" 5 | DATASET="ucf101" 6 | DATA_PATH="${HOME}/repo/mmaction2/data/${DATASET}" 7 | CHECKPOINT="path/to/checkpoint.pth" 8 | 9 | cd "$PROJECT_PATH" || exit 10 | 11 | if [ !
-d "checkpoints/$EXP_NAME" ]; then 12 | mkdir "checkpoints/$EXP_NAME" 13 | fi 14 | 15 | export CUDA_VISIBLE_DEVICES=0 16 | python -m torch.distributed.launch \ 17 | --nproc_per_node=1 \ 18 | --master_port="$RANDOM" \ 19 | eval_linear.py \ 20 | --n_last_blocks 1 \ 21 | --arch "vit_base" \ 22 | --pretrained_weights "$CHECKPOINT" \ 23 | --epochs 20 \ 24 | --lr 0.001 \ 25 | --batch_size_per_gpu 16 \ 26 | --num_workers 4 \ 27 | --num_labels 101 \ 28 | --dataset "$DATASET" \ 29 | --output_dir "checkpoints/eval/$EXP_NAME" \ 30 | --opts \ 31 | DATA.PATH_TO_DATA_DIR "${DATA_PATH}/splits" \ 32 | DATA.PATH_PREFIX "${DATA_PATH}/videos" \ 33 | DATA.USE_FLOW False 34 | -------------------------------------------------------------------------------- /utils/logging.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | """Logging.""" 4 | 5 | import decimal 6 | 7 | import simplejson 8 | 9 | import logging 10 | 11 | 12 | def get_logger(name): 13 | """ 14 | Retrieve the logger with the specified name or, if name is None, return a 15 | logger which is the root logger of the hierarchy. 16 | Args: 17 | name (string): name of the logger. 18 | """ 19 | return logging.getLogger(name) 20 | 21 | 22 | def log_json_stats(stats): 23 | """ 24 | Logs json stats. 25 | Args: 26 | stats (dict): a dictionary of statistical information to log. 27 | """ 28 | stats = { 29 | k: decimal.Decimal("{:.5f}".format(v)) if isinstance(v, float) else v 30 | for k, v in stats.items() 31 | } 32 | json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True) 33 | logger = get_logger(__name__) 34 | logger.info("json_stats: {:s}".format(json_stats)) 35 | -------------------------------------------------------------------------------- /datasets/preprocessing/kinetics_mini.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | # root_path = "/home/kanchanaranasinghe/data/kinetics400" 4 | # anno_folder = "k400-mini" 5 | # 6 | # for split in ["train", "val",]: 7 | # 8 | # save_path = f"{root_path}/{anno_folder}/{split}.csv" 9 | # file_list = pd.read_csv(save_path, sep=" ") 10 | # 11 | # if split == "train": 12 | # file_list = file_list.sample(n=60000, replace=False) 13 | # new_save_path = f"{root_path}/{anno_folder}/{split}_60k.csv" 14 | # file_list.to_csv(new_save_path, index=False, header=False, sep=" ") 15 | # elif split == "val": 16 | # file_list = file_list.sample(n=10000, replace=False) 17 | # new_save_path = f"{root_path}/{anno_folder}/{split}_10k.csv" 18 | # file_list.to_csv(new_save_path, index=False, header=False, sep=" ") 19 | 20 | 21 | file_path = "/home/kanchanaranasinghe/data/kinetics400/k400-mini/train_60k.csv" 22 | df = pd.read_csv(file_path, header=None, sep=" ") 23 | print(f"{len(df[1].unique())} unique classes") 24 | df[1].hist(bins=len(df[1].unique())) 25 | -------------------------------------------------------------------------------- /models/configs/Kinetics/TimeSformer_spaceOnly_8x32_224.yaml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | ENABLE: True 3 | DATASET: kinetics 4 | BATCH_SIZE: 8 5 | EVAL_PERIOD: 5 6 | CHECKPOINT_PERIOD: 5 7 | AUTO_RESUME: True 8 | DATA: 9 | PATH_TO_DATA_DIR: /path/to/kinetics/ 10 | NUM_FRAMES: 8 11 | SAMPLING_RATE: 32 12 | TRAIN_JITTER_SCALES: [256, 320] 13 | TRAIN_CROP_SIZE: 224 14 | TEST_CROP_SIZE: 224 15 | INPUT_CHANNEL_NUM: [3] 16 | TIMESFORMER: 17 |
ATTENTION_TYPE: 'space_only' 18 | SOLVER: 19 | BASE_LR: 0.005 20 | LR_POLICY: steps_with_relative_lrs 21 | STEPS: [0, 11, 14] 22 | LRS: [1, 0.1, 0.01] 23 | MAX_EPOCH: 15 24 | MOMENTUM: 0.9 25 | WEIGHT_DECAY: 1e-4 26 | OPTIMIZING_METHOD: sgd 27 | MODEL: 28 | MODEL_NAME: vit_base_patch16_224 29 | NUM_CLASSES: 400 30 | ARCH: vit 31 | LOSS_FUNC: cross_entropy 32 | DROPOUT_RATE: 0.5 33 | TEST: 34 | ENABLE: True 35 | DATASET: kinetics 36 | BATCH_SIZE: 8 37 | NUM_ENSEMBLE_VIEWS: 1 38 | NUM_SPATIAL_CROPS: 3 39 | DATA_LOADER: 40 | NUM_WORKERS: 8 41 | PIN_MEMORY: True 42 | NUM_GPUS: 8 43 | NUM_SHARDS: 1 44 | RNG_SEED: 0 45 | OUTPUT_DIR: . 46 | -------------------------------------------------------------------------------- /models/configs/Kinetics/TimeSformer_divST_16x16_448.yaml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | ENABLE: True 3 | DATASET: kinetics 4 | BATCH_SIZE: 8 5 | EVAL_PERIOD: 5 6 | CHECKPOINT_PERIOD: 5 7 | AUTO_RESUME: True 8 | DATA: 9 | PATH_TO_DATA_DIR: /path/to/kinetics/ 10 | NUM_FRAMES: 16 11 | SAMPLING_RATE: 16 12 | TRAIN_JITTER_SCALES: [448, 512] 13 | TRAIN_CROP_SIZE: 448 14 | TEST_CROP_SIZE: 448 15 | INPUT_CHANNEL_NUM: [3] 16 | TIMESFORMER: 17 | ATTENTION_TYPE: 'divided_space_time' 18 | SOLVER: 19 | BASE_LR: 0.005 20 | LR_POLICY: steps_with_relative_lrs 21 | STEPS: [0, 11, 14] 22 | LRS: [1, 0.1, 0.01] 23 | MAX_EPOCH: 15 24 | MOMENTUM: 0.9 25 | WEIGHT_DECAY: 1e-4 26 | OPTIMIZING_METHOD: sgd 27 | MODEL: 28 | MODEL_NAME: vit_base_patch16_224 29 | NUM_CLASSES: 400 30 | ARCH: vit 31 | LOSS_FUNC: cross_entropy 32 | DROPOUT_RATE: 0.5 33 | TEST: 34 | ENABLE: True 35 | DATASET: kinetics 36 | BATCH_SIZE: 8 37 | NUM_ENSEMBLE_VIEWS: 1 38 | NUM_SPATIAL_CROPS: 3 39 | DATA_LOADER: 40 | NUM_WORKERS: 8 41 | PIN_MEMORY: True 42 | NUM_GPUS: 8 43 | NUM_SHARDS: 1 44 | RNG_SEED: 0 45 | OUTPUT_DIR: . 46 | -------------------------------------------------------------------------------- /models/configs/Kinetics/TimeSformer_divST_8x32_224.yaml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | ENABLE: True 3 | DATASET: kinetics 4 | BATCH_SIZE: 8 5 | EVAL_PERIOD: 5 6 | CHECKPOINT_PERIOD: 5 7 | AUTO_RESUME: True 8 | DATA: 9 | PATH_TO_DATA_DIR: /path/to/kinetics/ 10 | NUM_FRAMES: 8 11 | SAMPLING_RATE: 32 12 | TRAIN_JITTER_SCALES: [256, 320] 13 | TRAIN_CROP_SIZE: 224 14 | TEST_CROP_SIZE: 224 15 | INPUT_CHANNEL_NUM: [3] 16 | TIMESFORMER: 17 | ATTENTION_TYPE: 'divided_space_time' 18 | SOLVER: 19 | BASE_LR: 0.005 20 | LR_POLICY: steps_with_relative_lrs 21 | STEPS: [0, 11, 14] 22 | LRS: [1, 0.1, 0.01] 23 | MAX_EPOCH: 15 24 | MOMENTUM: 0.9 25 | WEIGHT_DECAY: 1e-4 26 | OPTIMIZING_METHOD: sgd 27 | MODEL: 28 | MODEL_NAME: vit_base_patch16_224 29 | NUM_CLASSES: 400 30 | ARCH: vit 31 | LOSS_FUNC: cross_entropy 32 | DROPOUT_RATE: 0.5 33 | TEST: 34 | ENABLE: True 35 | DATASET: kinetics 36 | BATCH_SIZE: 8 37 | NUM_ENSEMBLE_VIEWS: 1 38 | NUM_SPATIAL_CROPS: 3 39 | DATA_LOADER: 40 | NUM_WORKERS: 8 41 | PIN_MEMORY: True 42 | NUM_GPUS: 8 43 | NUM_SHARDS: 1 44 | RNG_SEED: 0 45 | OUTPUT_DIR: . 
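Across these TimeSformer configs, the file name encodes the clip sampling scheme: `8x32` corresponds to `NUM_FRAMES: 8` sampled every `SAMPLING_RATE: 32` frames, `16x16` to 16 frames at stride 16, and similarly for the remaining configs. A quick back-of-the-envelope check of the temporal span one clip covers (the 30 fps figure is an assumption about typical Kinetics videos, not a value taken from these configs):

```python
# Raw frames spanned by a single training clip, and its rough duration.
num_frames, sampling_rate, fps = 8, 32, 30  # NUM_FRAMES and SAMPLING_RATE from the 8x32 config
span = num_frames * sampling_rate           # 256 raw frames
print(f"one clip spans {span} frames, roughly {span / fps:.1f} s at {fps} fps")  # ~8.5 s
```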
46 | -------------------------------------------------------------------------------- /models/configs/Kinetics/TimeSformer_divST_96x4_224.yaml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | ENABLE: True 3 | DATASET: kinetics 4 | BATCH_SIZE: 8 5 | EVAL_PERIOD: 5 6 | CHECKPOINT_PERIOD: 5 7 | AUTO_RESUME: True 8 | DATA: 9 | PATH_TO_DATA_DIR: /path/to/kinetics/ 10 | NUM_FRAMES: 96 11 | SAMPLING_RATE: 4 12 | TRAIN_JITTER_SCALES: [256, 320] 13 | TRAIN_CROP_SIZE: 224 14 | TEST_CROP_SIZE: 224 15 | INPUT_CHANNEL_NUM: [3] 16 | TIMESFORMER: 17 | ATTENTION_TYPE: 'divided_space_time' 18 | SOLVER: 19 | BASE_LR: 0.005 20 | LR_POLICY: steps_with_relative_lrs 21 | STEPS: [0, 11, 14] 22 | LRS: [1, 0.1, 0.01] 23 | MAX_EPOCH: 15 24 | MOMENTUM: 0.9 25 | WEIGHT_DECAY: 1e-4 26 | OPTIMIZING_METHOD: sgd 27 | MODEL: 28 | MODEL_NAME: vit_base_patch16_224 29 | NUM_CLASSES: 400 30 | ARCH: vit 31 | LOSS_FUNC: cross_entropy 32 | DROPOUT_RATE: 0.5 33 | TEST: 34 | ENABLE: True 35 | DATASET: kinetics 36 | BATCH_SIZE: 8 37 | NUM_ENSEMBLE_VIEWS: 1 38 | NUM_SPATIAL_CROPS: 3 39 | DATA_LOADER: 40 | NUM_WORKERS: 8 41 | PIN_MEMORY: True 42 | NUM_GPUS: 8 43 | NUM_SHARDS: 1 44 | RNG_SEED: 0 45 | OUTPUT_DIR: . 46 | -------------------------------------------------------------------------------- /models/configs/Kinetics/TimeSformer_jointST_8x32_224.yaml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | ENABLE: True 3 | DATASET: kinetics 4 | BATCH_SIZE: 8 5 | EVAL_PERIOD: 5 6 | CHECKPOINT_PERIOD: 5 7 | AUTO_RESUME: True 8 | DATA: 9 | PATH_TO_DATA_DIR: /path/to/kinetics/ 10 | NUM_FRAMES: 8 11 | SAMPLING_RATE: 32 12 | TRAIN_JITTER_SCALES: [256, 320] 13 | TRAIN_CROP_SIZE: 224 14 | TEST_CROP_SIZE: 224 15 | INPUT_CHANNEL_NUM: [3] 16 | TIMESFORMER: 17 | ATTENTION_TYPE: 'joint_space_time' 18 | SOLVER: 19 | BASE_LR: 0.005 20 | LR_POLICY: steps_with_relative_lrs 21 | STEPS: [0, 11, 14] 22 | LRS: [1, 0.1, 0.01] 23 | MAX_EPOCH: 15 24 | MOMENTUM: 0.9 25 | WEIGHT_DECAY: 1e-4 26 | OPTIMIZING_METHOD: sgd 27 | MODEL: 28 | MODEL_NAME: vit_base_patch16_224 29 | NUM_CLASSES: 400 30 | ARCH: vit 31 | LOSS_FUNC: cross_entropy 32 | DROPOUT_RATE: 0.5 33 | TEST: 34 | ENABLE: True 35 | DATASET: kinetics 36 | BATCH_SIZE: 8 37 | NUM_ENSEMBLE_VIEWS: 1 38 | NUM_SPATIAL_CROPS: 3 39 | DATA_LOADER: 40 | NUM_WORKERS: 8 41 | PIN_MEMORY: True 42 | NUM_GPUS: 8 43 | NUM_SHARDS: 1 44 | RNG_SEED: 0 45 | OUTPUT_DIR: . 
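The three TimeSformer variants seen so far differ mainly in `ATTENTION_TYPE` (`space_only`, `divided_space_time`, `joint_space_time`), which trades temporal modelling against attention cost. A rough sketch of the per-block token-pair counts each pattern implies, following the complexity argument in the TimeSformer paper (illustrative arithmetic only, ignoring the class token; not code from this repository):

```python
# T frames x S spatial patches per frame for a 224x224 input with 16x16 patches.
T, S = 8, (224 // 16) ** 2         # S = 196
joint = (T * S) ** 2               # every token attends over all T*S tokens
divided = T * S * (T + S)          # temporal attention over T, then spatial over S
space_only = T * S * S             # spatial attention within each frame only
print(joint, divided, space_only)  # 2458624 319872 307328
```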
46 | -------------------------------------------------------------------------------- /models/configs/Kinetics/TimeSformer_divST_8x32_224_4gpus.yaml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | ENABLE: True 3 | DATASET: kinetics 4 | BATCH_SIZE: 4 5 | EVAL_PERIOD: 5 6 | CHECKPOINT_PERIOD: 5 7 | AUTO_RESUME: True 8 | DATA: 9 | PATH_TO_DATA_DIR: /path/to/kinetics/ 10 | NUM_FRAMES: 8 11 | SAMPLING_RATE: 32 12 | TRAIN_JITTER_SCALES: [256, 320] 13 | TRAIN_CROP_SIZE: 224 14 | TEST_CROP_SIZE: 224 15 | INPUT_CHANNEL_NUM: [3] 16 | TIMESFORMER: 17 | ATTENTION_TYPE: 'divided_space_time' 18 | SOLVER: 19 | BASE_LR: 0.005 20 | LR_POLICY: steps_with_relative_lrs 21 | STEPS: [0, 11, 14] 22 | LRS: [1, 0.1, 0.01] 23 | MAX_EPOCH: 15 24 | MOMENTUM: 0.9 25 | WEIGHT_DECAY: 1e-4 26 | OPTIMIZING_METHOD: sgd 27 | MODEL: 28 | MODEL_NAME: vit_base_patch16_224 29 | NUM_CLASSES: 400 30 | ARCH: vit 31 | LOSS_FUNC: cross_entropy 32 | DROPOUT_RATE: 0.5 33 | TEST: 34 | ENABLE: True 35 | DATASET: kinetics 36 | BATCH_SIZE: 4 37 | NUM_ENSEMBLE_VIEWS: 1 38 | NUM_SPATIAL_CROPS: 3 39 | DATA_LOADER: 40 | NUM_WORKERS: 4 41 | PIN_MEMORY: True 42 | NUM_GPUS: 4 43 | NUM_SHARDS: 1 44 | RNG_SEED: 0 45 | OUTPUT_DIR: . 46 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Kanchana Ranasinghe 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /datasets/video_container.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | import av 4 | 5 | 6 | def get_video_container(path_to_vid, multi_thread_decode=False, backend="pyav"): 7 | """ 8 | Given the path to the video, return the pyav video container. 9 | Args: 10 | path_to_vid (str): path to the video. 11 | multi_thread_decode (bool): if True, perform multi-thread decoding. 12 | backend (str): decoder backend, options include `pyav` and 13 | `torchvision`, default is `pyav`. 14 | Returns: 15 | container (container): video container. 
16 | """ 17 | if backend == "torchvision": 18 | with open(path_to_vid, "rb") as fp: 19 | container = fp.read() 20 | return container 21 | elif backend == "pyav": 22 | #try: 23 | container = av.open(path_to_vid) 24 | if multi_thread_decode: 25 | # Enable multiple threads for decoding. 26 | container.streams.video[0].thread_type = "AUTO" 27 | #except: 28 | # container = None 29 | return container 30 | else: 31 | raise NotImplementedError("Unknown backend {}".format(backend)) 32 | -------------------------------------------------------------------------------- /datasets/build.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | from fvcore.common.registry import Registry 4 | 5 | DATASET_REGISTRY = Registry("DATASET") 6 | DATASET_REGISTRY.__doc__ = """ 7 | Registry for dataset. 8 | 9 | The registered object will be called with `obj(cfg, split)`. 10 | The call should return a `torch.utils.data.Dataset` object. 11 | """ 12 | 13 | 14 | def build_dataset(dataset_name, cfg, split): 15 | """ 16 | Build a dataset, defined by `dataset_name`. 17 | Args: 18 | dataset_name (str): the name of the dataset to be constructed. 19 | cfg (CfgNode): configs. Details can be found in 20 | slowfast/config/defaults.py 21 | split (str): the split of the data loader. Options include `train`, 22 | `val`, and `test`. 23 | Returns: 24 | Dataset: a constructed dataset specified by dataset_name. 25 | """ 26 | # Capitalize the the first letter of the dataset_name since the dataset_name 27 | # in configs may be in lowercase but the name of dataset class should always 28 | # start with an uppercase letter. 29 | name = dataset_name.capitalize() 30 | return DATASET_REGISTRY.get(name)(cfg, split) 31 | -------------------------------------------------------------------------------- /models/configs/Kinetics/TimeSformer_divST_8x32_224_TEST.yaml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | ENABLE: False 3 | DATASET: kinetics 4 | BATCH_SIZE: 8 5 | EVAL_PERIOD: 5 6 | CHECKPOINT_PERIOD: 5 7 | AUTO_RESUME: True 8 | DATA: 9 | PATH_TO_DATA_DIR: /path/to/kinetics/ 10 | NUM_FRAMES: 8 11 | SAMPLING_RATE: 32 12 | TRAIN_JITTER_SCALES: [256, 320] 13 | TRAIN_CROP_SIZE: 224 14 | TEST_CROP_SIZE: 224 15 | INPUT_CHANNEL_NUM: [3] 16 | TIMESFORMER: 17 | ATTENTION_TYPE: 'divided_space_time' 18 | SOLVER: 19 | BASE_LR: 0.005 20 | LR_POLICY: steps_with_relative_lrs 21 | STEPS: [0, 11, 14] 22 | LRS: [1, 0.1, 0.01] 23 | MAX_EPOCH: 15 24 | MOMENTUM: 0.9 25 | WEIGHT_DECAY: 1e-4 26 | OPTIMIZING_METHOD: sgd 27 | MODEL: 28 | MODEL_NAME: vit_base_patch16_224 29 | NUM_CLASSES: 400 30 | ARCH: vit 31 | LOSS_FUNC: cross_entropy 32 | DROPOUT_RATE: 0.5 33 | TEST: 34 | ENABLE: True 35 | DATASET: kinetics 36 | BATCH_SIZE: 8 37 | NUM_ENSEMBLE_VIEWS: 1 38 | NUM_SPATIAL_CROPS: 3 39 | CHECKPOINT_FILE_PATH: '/checkpoint/gedas/jobs/timesformer/kinetics_400/TimeSformer_divST_8x32_224/checkpoints/checkpoint_epoch_00025.pyth' 40 | DATA_LOADER: 41 | NUM_WORKERS: 8 42 | PIN_MEMORY: True 43 | NUM_GPUS: 8 44 | NUM_SHARDS: 1 45 | RNG_SEED: 0 46 | OUTPUT_DIR: . 
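`datasets/build.py` above defines `DATASET_REGISTRY` and resolves datasets by their capitalized name in `build_dataset`. A minimal sketch of how a dataset class plugs into that registry via fvcore's decorator (the class below is a hypothetical stand-in, not the repository's actual `Kinetics` implementation):

```python
import torch.utils.data

from datasets.build import DATASET_REGISTRY, build_dataset


@DATASET_REGISTRY.register()
class Toyvideos(torch.utils.data.Dataset):
    """Hypothetical dataset, registered under its class name ("Toyvideos")."""

    def __init__(self, cfg, split):
        assert split in ["train", "val", "test"]
        self.cfg, self.split = cfg, split
        self._items = []  # normally populated from cfg.DATA.PATH_TO_DATA_DIR

    def __getitem__(self, index):
        return self._items[index]

    def __len__(self):
        return len(self._items)


# build_dataset() capitalizes the requested name, so "toyvideos" resolves to Toyvideos:
# dataset = build_dataset("toyvideos", cfg, "train")
```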
47 | -------------------------------------------------------------------------------- /datasets/preprocessing/verify_file_list.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from tqdm import tqdm 3 | 4 | from datasets.decoder import decode 5 | from datasets.video_container import get_video_container 6 | 7 | root_path = "/home/kanchanaranasinghe/data/kinetics400" 8 | 9 | for split in ["train", "val", "test"]: 10 | anno_folder = "new_annotations" 11 | 12 | if split == "train": 13 | save_path = f"{root_path}/{anno_folder}/{split}_60k.csv" 14 | file_list = pd.read_csv(save_path, sep=" ") 15 | elif split in ["val", "test"]: 16 | save_path = f"{root_path}/{anno_folder}/{split}_10k.csv" 17 | file_list = pd.read_csv(save_path, sep=" ") 18 | else: 19 | raise NotImplementedError("invalid split") 20 | 21 | good_count = 0 22 | frames = "0" 23 | bad_list = [] 24 | print(f"Processing files from: {save_path}") 25 | for file in tqdm(file_list.values[:, 0]): 26 | try: 27 | container = get_video_container(file, True) 28 | frames = decode(container, 32, 8) 29 | assert frames is not None, "frames is None" 30 | except Exception as e: 31 | print(e, file) 32 | bad_list.append(file) 33 | else: 34 | good_count += 1 35 | 36 | print(f"{len(file_list) - good_count} files bad. {good_count} / {len(file_list)} files are good.") 37 | -------------------------------------------------------------------------------- /datasets/DATASET.md: -------------------------------------------------------------------------------- 1 | # Dataset Preparation 2 | 3 | ## Kinetics 4 | 5 | The Kinetics dataset can be downloaded from the following [link](https://github.com/cvdfoundation/kinetics-dataset). 6 | 7 | After downloading all the videos, resize them so that the short edge is 256 pixels, then prepare the csv files for the training, validation, and testing sets as `train.csv`, `val.csv`, and `test.csv`. The format of each csv file is: 8 | 9 | ``` 10 | path_to_video_1 label_1 11 | path_to_video_2 label_2 12 | path_to_video_3 label_3 13 | ... 14 | path_to_video_N label_N 15 | ``` 16 | 17 | ## Something-Something V2 18 | 1. Please download the dataset and annotations from the [dataset provider](https://20bn.com/datasets/something-something). 19 | 20 | 2. Download the *frame list* from the following links: ([train](https://dl.fbaipublicfiles.com/pyslowfast/dataset/ssv2/frame_lists/train.csv), [val](https://dl.fbaipublicfiles.com/pyslowfast/dataset/ssv2/frame_lists/val.csv)). 21 | 22 | 3. Extract the frames at 30 FPS. (We used ffmpeg-4.1.3 with the command 23 | `ffmpeg -i "${video}" -r 30 -q:v 1 "${out_name}"` 24 | in our experiments.) Please put the frames in a structure consistent with the frame lists. 25 | 26 | Put all annotation json files and the frame lists in the same folder, set `DATA.PATH_TO_DATA_DIR` to that path, and set `DATA.PATH_PREFIX` to the folder containing the extracted frames.
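Complementing the decode check in `verify_file_list.py` above, the space-separated `path label` format can be validated with a few lines before launching training (a sketch; the example paths are placeholders, and it assumes integer labels and paths that resolve after joining with `DATA.PATH_PREFIX`):

```python
import csv
import os


def check_file_list(csv_path, path_prefix=""):
    """Verify every row is `path_to_video label` and that each video exists."""
    with open(csv_path) as fo:
        for line_no, row in enumerate(csv.reader(fo, delimiter=" "), start=1):
            assert len(row) == 2, f"line {line_no}: expected 'path label', got {row}"
            path, label = row
            assert label.isdigit(), f"line {line_no}: label is not an integer: {label!r}"
            full_path = os.path.join(path_prefix, path)
            assert os.path.exists(full_path), f"line {line_no}: missing video {full_path}"


# check_file_list("/path/to/kinetics/train.csv", path_prefix="/path/to/kinetics/videos")
```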
27 | -------------------------------------------------------------------------------- /datasets/preprocessing/create_lists.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | from tqdm import tqdm 4 | import json 5 | 6 | root_path = "/home/kanchanaranasinghe/data/kinetics400" 7 | split = "val" 8 | anno_folder = "new_annotations" 9 | 10 | video_list = set(os.listdir(f"{root_path}/{split}_d256")) 11 | save_path = f"{root_path}/{anno_folder}/{split}.csv" 12 | os.makedirs(f"{root_path}/{anno_folder}", exist_ok=True) 13 | 14 | labels = pd.read_csv(f"{root_path}/annotations/{split}.csv") 15 | label_dict = {y: x for x, y in enumerate(sorted(labels.label.unique().tolist()))} 16 | json.dump(label_dict, open(f"{root_path}/new_annotations/{split}_label_dict.json", "w")) 17 | 18 | with open(f"{root_path}/bad_files_{split}.txt", "r") as fo: 19 | bad_videos = fo.readlines() 20 | bad_videos = [x.strip() for x in bad_videos] 21 | 22 | video_label = [] 23 | for idx, row in tqdm(labels.iterrows()): 24 | video_name = f"{row.youtube_id}_{row.time_start:06d}_{row.time_end:06d}.mp4" 25 | if video_name in bad_videos: 26 | continue 27 | if video_name not in video_list: 28 | continue 29 | assert os.path.exists(f"{root_path}/{split}_d256/{video_name}"), "video not found" 30 | video_label.append((video_name, label_dict[row.label])) 31 | 32 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 33 | pd.DataFrame(video_label).to_csv(save_path, index=False, header=False, sep=" ") 34 | -------------------------------------------------------------------------------- /models/configs/Kinetics/SLOWFAST_4x16_R50.yaml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | ENABLE: True 3 | DATASET: kinetics 4 | BATCH_SIZE: 64 5 | EVAL_PERIOD: 10 6 | CHECKPOINT_PERIOD: 5 7 | AUTO_RESUME: True 8 | DATA: 9 | PATH_TO_DATA_DIR: /path/to/kinetics/ 10 | NUM_FRAMES: 32 11 | SAMPLING_RATE: 2 12 | TRAIN_JITTER_SCALES: [256, 320] 13 | TRAIN_CROP_SIZE: 224 14 | TEST_CROP_SIZE: 256 15 | INPUT_CHANNEL_NUM: [3, 3] 16 | SLOWFAST: 17 | ALPHA: 8 18 | BETA_INV: 8 19 | FUSION_CONV_CHANNEL_RATIO: 2 20 | FUSION_KERNEL_SZ: 5 21 | RESNET: 22 | ZERO_INIT_FINAL_BN: True 23 | WIDTH_PER_GROUP: 64 24 | NUM_GROUPS: 1 25 | DEPTH: 50 26 | TRANS_FUNC: bottleneck_transform 27 | STRIDE_1X1: False 28 | NUM_BLOCK_TEMP_KERNEL: [[3, 3], [4, 4], [6, 6], [3, 3]] 29 | SPATIAL_STRIDES: [[1, 1], [2, 2], [2, 2], [2, 2]] 30 | SPATIAL_DILATIONS: [[1, 1], [1, 1], [1, 1], [1, 1]] 31 | NONLOCAL: 32 | LOCATION: [[[], []], [[], []], [[], []], [[], []]] 33 | GROUP: [[1, 1], [1, 1], [1, 1], [1, 1]] 34 | INSTANTIATION: dot_product 35 | BN: 36 | USE_PRECISE_STATS: True 37 | NUM_BATCHES_PRECISE: 200 38 | SOLVER: 39 | BASE_LR: 0.8 40 | LR_POLICY: cosine 41 | MAX_EPOCH: 196 42 | MOMENTUM: 0.9 43 | WEIGHT_DECAY: 1e-4 44 | WARMUP_EPOCHS: 34.0 45 | WARMUP_START_LR: 0.01 46 | OPTIMIZING_METHOD: sgd 47 | MODEL: 48 | NUM_CLASSES: 400 49 | ARCH: slowfast 50 | MODEL_NAME: SlowFast 51 | LOSS_FUNC: cross_entropy 52 | DROPOUT_RATE: 0.5 53 | TEST: 54 | ENABLE: True 55 | DATASET: kinetics 56 | BATCH_SIZE: 64 57 | DATA_LOADER: 58 | NUM_WORKERS: 8 59 | PIN_MEMORY: True 60 | NUM_GPUS: 8 61 | NUM_SHARDS: 1 62 | RNG_SEED: 0 63 | OUTPUT_DIR: . 
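For the SlowFast baselines, the `4x16` / `8x8` part of the config name describes the slow pathway, while `DATA.NUM_FRAMES` and `DATA.SAMPLING_RATE` describe the fast pathway and `SLOWFAST.ALPHA` is the frame-rate ratio between the two. A small sketch of how the numbers in the config above relate (my reading of the SlowFast naming convention, not code from this repository):

```python
# SLOWFAST_4x16_R50: the fast pathway samples 32 frames at stride 2;
# the slow pathway keeps every ALPHA-th of those frames.
num_frames, sampling_rate, alpha = 32, 2, 8
slow_frames = num_frames // alpha      # 4
slow_stride = sampling_rate * alpha    # 16  -> hence "4x16"
print(slow_frames, slow_stride)
```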
64 | -------------------------------------------------------------------------------- /models/configs/Kinetics/SLOWFAST_8x8_R50.yaml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | ENABLE: True 3 | DATASET: kinetics 4 | BATCH_SIZE: 64 5 | EVAL_PERIOD: 10 6 | CHECKPOINT_PERIOD: 5 7 | AUTO_RESUME: True 8 | DATA: 9 | PATH_TO_DATA_DIR: /path/to/kinetics/ 10 | NUM_FRAMES: 32 11 | SAMPLING_RATE: 2 12 | TRAIN_JITTER_SCALES: [256, 320] 13 | TRAIN_CROP_SIZE: 224 14 | TEST_CROP_SIZE: 256 15 | INPUT_CHANNEL_NUM: [3, 3] 16 | SLOWFAST: 17 | ALPHA: 4 18 | BETA_INV: 8 19 | FUSION_CONV_CHANNEL_RATIO: 2 20 | FUSION_KERNEL_SZ: 7 21 | RESNET: 22 | ZERO_INIT_FINAL_BN: True 23 | WIDTH_PER_GROUP: 64 24 | NUM_GROUPS: 1 25 | DEPTH: 50 26 | TRANS_FUNC: bottleneck_transform 27 | STRIDE_1X1: False 28 | NUM_BLOCK_TEMP_KERNEL: [[3, 3], [4, 4], [6, 6], [3, 3]] 29 | SPATIAL_STRIDES: [[1, 1], [2, 2], [2, 2], [2, 2]] 30 | SPATIAL_DILATIONS: [[1, 1], [1, 1], [1, 1], [1, 1]] 31 | NONLOCAL: 32 | LOCATION: [[[], []], [[], []], [[], []], [[], []]] 33 | GROUP: [[1, 1], [1, 1], [1, 1], [1, 1]] 34 | INSTANTIATION: dot_product 35 | BN: 36 | USE_PRECISE_STATS: True 37 | NUM_BATCHES_PRECISE: 200 38 | SOLVER: 39 | BASE_LR: 0.8 40 | LR_POLICY: cosine 41 | MAX_EPOCH: 196 42 | MOMENTUM: 0.9 43 | WEIGHT_DECAY: 1e-4 44 | WARMUP_EPOCHS: 34.0 45 | WARMUP_START_LR: 0.01 46 | OPTIMIZING_METHOD: sgd 47 | MODEL: 48 | NUM_CLASSES: 400 49 | ARCH: slowfast 50 | MODEL_NAME: SlowFast 51 | LOSS_FUNC: cross_entropy 52 | DROPOUT_RATE: 0.5 53 | TEST: 54 | ENABLE: True 55 | DATASET: kinetics 56 | BATCH_SIZE: 64 57 | DATA_LOADER: 58 | NUM_WORKERS: 8 59 | PIN_MEMORY: True 60 | NUM_GPUS: 8 61 | NUM_SHARDS: 1 62 | RNG_SEED: 0 63 | OUTPUT_DIR: . 64 | -------------------------------------------------------------------------------- /models/configs/Kinetics/SLOWFAST_8x8_R101.yaml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | ENABLE: True 3 | DATASET: kinetics 4 | BATCH_SIZE: 64 5 | EVAL_PERIOD: 10 6 | CHECKPOINT_PERIOD: 5 7 | AUTO_RESUME: True 8 | DATA: 9 | PATH_TO_DATA_DIR: /path/to/kinetics/ 10 | NUM_FRAMES: 32 11 | SAMPLING_RATE: 2 12 | TRAIN_JITTER_SCALES: [256, 340] 13 | TRAIN_CROP_SIZE: 224 14 | TEST_CROP_SIZE: 256 15 | INPUT_CHANNEL_NUM: [3, 3] 16 | SLOWFAST: 17 | ALPHA: 4 18 | BETA_INV: 8 19 | FUSION_CONV_CHANNEL_RATIO: 2 20 | FUSION_KERNEL_SZ: 5 21 | RESNET: 22 | ZERO_INIT_FINAL_BN: True 23 | WIDTH_PER_GROUP: 64 24 | NUM_GROUPS: 1 25 | DEPTH: 101 26 | TRANS_FUNC: bottleneck_transform 27 | STRIDE_1X1: False 28 | NUM_BLOCK_TEMP_KERNEL: [[3, 3], [4, 4], [6, 6], [3, 3]] 29 | SPATIAL_STRIDES: [[1, 1], [2, 2], [2, 2], [2, 2]] 30 | SPATIAL_DILATIONS: [[1, 1], [1, 1], [1, 1], [1, 1]] 31 | NONLOCAL: 32 | LOCATION: [[[], []], [[], []], [[], []], [[], []]] 33 | GROUP: [[1, 1], [1, 1], [1, 1], [1, 1]] 34 | INSTANTIATION: dot_product 35 | BN: 36 | USE_PRECISE_STATS: True 37 | NUM_BATCHES_PRECISE: 200 38 | SOLVER: 39 | BASE_LR: 0.8 ## 8 nodes 40 | LR_POLICY: cosine 41 | MAX_EPOCH: 196 42 | MOMENTUM: 0.9 43 | WEIGHT_DECAY: 1e-4 44 | WARMUP_EPOCHS: 34.0 45 | WARMUP_START_LR: 0.01 46 | OPTIMIZING_METHOD: sgd 47 | MODEL: 48 | NUM_CLASSES: 400 49 | ARCH: slowfast 50 | MODEL_NAME: SlowFast 51 | LOSS_FUNC: cross_entropy 52 | DROPOUT_RATE: 0.5 53 | TEST: 54 | ENABLE: True 55 | DATASET: kinetics 56 | BATCH_SIZE: 64 57 | DATA_LOADER: 58 | NUM_WORKERS: 8 59 | PIN_MEMORY: True 60 | NUM_GPUS: 8 61 | NUM_SHARDS: 1 62 | RNG_SEED: 0 63 | OUTPUT_DIR: . 
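The `BASE_LR: 0.8 ## 8 nodes` comment in the R101 config hints that the quoted learning rate assumes the full multi-node batch. Under the common linear LR scaling rule, the base LR would be reduced proportionally for a smaller effective batch; a hedged sketch of that convention (the 8-node reference batch is an assumption read off the comment, not a value taken from this repository's code):

```python
# Linear LR scaling: the learning rate scales with the effective batch size.
reference_lr = 0.8
reference_batch = 64 * 8  # BATCH_SIZE of 64 per shard across 8 nodes (assumed)
actual_batch = 64 * 1     # e.g. training on a single node
scaled_lr = reference_lr * actual_batch / reference_batch
print(scaled_lr)          # 0.1
```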
64 | -------------------------------------------------------------------------------- /datasets/preprocessing/check_corrupt_videos.py: -------------------------------------------------------------------------------- 1 | import glob 2 | from datasets.video_container import get_video_container 3 | from datasets.decoder import decode 4 | from tqdm import tqdm 5 | import json 6 | 7 | root_path = "/home/kanchanaranasinghe/data/kinetics400" 8 | root_path = "/home/kanchanaranasinghe/repo/mmaction2/data/hmdb51/videos" 9 | split = "hmdb" 10 | 11 | file_list = glob.glob(f"{root_path}/{split}_256/*.mp4") 12 | file_list = glob.glob(f"{root_path}/*/*.avi") 13 | good_count = 0 14 | frames = "0" 15 | bad_list = [] 16 | for file in tqdm(file_list): 17 | try: 18 | container = get_video_container(file, True) 19 | frames = decode(container, 32, 8) 20 | assert frames is not None, "frames is None" 21 | except Exception as e: 22 | print(e, file) 23 | bad_list.append(file) 24 | else: 25 | good_count += 1 26 | 27 | print(f"{len(file_list) - good_count} files bad. {good_count} / {len(file_list)} files are good.") 28 | json.dump({"bad": bad_list}, open(f"{root_path}/{split}_256_bad.json", "w")) 29 | 30 | # import os 31 | # import json 32 | # import shutil 33 | # import pandas as pd 34 | # 35 | # root_path = "/home/kanchanaranasinghe/repo/mmaction2/data/hmdb51/splits" 36 | # txt_file_name = "/home/kanchanaranasinghe/repo/mmaction2/data/hmdb51/splits/hmdb51_val_split_1_videos.txt" 37 | # files = json.load(open(f"{root_path}/{split}_256_bad.json", "r")) 38 | # bad_names = [x[59:] for x in files['bad']] 39 | # df = pd.read_csv(f"{txt_file_name}", sep=" ", 40 | # header=None) 41 | # df = df[df[0].isin(bad_names) == False] 42 | # df.to_csv(f"{txt_file_name}", sep=" ", 43 | # header=None, index=None) 44 | 45 | # for file in files["bad"]: 46 | # shutil.move(file, file.replace(f"/{split}_256/", f"/{split}_256_bad/")) 47 | -------------------------------------------------------------------------------- /datasets/preprocessing/downsample_kinetics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | from tqdm import tqdm 4 | from joblib import Parallel 5 | from joblib import delayed 6 | 7 | 8 | def downscale_clip(inname, outname): 9 | inname = '"%s"' % inname 10 | outname = '"%s"' % outname 11 | # command = "ffmpeg -loglevel panic -i {} -filter:v scale=\"trunc(oh*a/2)*2:256\" -q:v 1 -c:a copy {}".format( 12 | # inname, outname) 13 | command = f"ffmpeg -i {inname} -filter:v scale=\"trunc(oh*a/2)*2:256\" -q:v 1 -c:a copy {outname}" 14 | try: 15 | output = subprocess.check_output(command, shell=True, stderr=subprocess.STDOUT) 16 | except subprocess.CalledProcessError as err: 17 | print(err) 18 | return err.output 19 | 20 | return output 21 | 22 | 23 | def downscale_clip_wrapper(file_name): 24 | in_name = f"{folder_path}/{file_name}" 25 | out_name = f"{output_path}/{file_name}" 26 | 27 | log = downscale_clip(in_name, out_name) 28 | return file_name, log 29 | 30 | 31 | if __name__ == '__main__': 32 | root_path = "/home/salman/data/kinetics/400" 33 | split = "val" 34 | 35 | folder_path = f'{root_path}/{split}' 36 | output_path = f'{root_path}/{split}_256' 37 | os.makedirs(output_path, exist_ok=True) 38 | 39 | file_list = os.listdir(folder_path) 40 | completed_file_list = set(os.listdir(output_path)) 41 | file_list = [x for x in file_list if x not in completed_file_list] 42 | 43 | # file_list = file_list[:100] 44 | print(f"Starting to downsample {len(file_list)} 
video files.") 45 | 46 | # split = len(file_list) // 10 47 | # list_of_lists = [file_list[x * split:(x + 1) * split] for x in range(10)] 48 | # list_of_lists[-1].extend(file_list[10 * split:]) 49 | 50 | for file in tqdm(file_list): 51 | _, log = downscale_clip_wrapper(file) 52 | 53 | # status_lst = Parallel(n_jobs=16)(delayed(downscale_clip_wrapper)(row) for row in file_list) 54 | # status_lst = Parallel(n_jobs=16)(downscale_clip_wrapper(row) for row in file_list) 55 | # with open(f"{root_path}/downsample_{split}_logs.txt", "w") as fo: 56 | # fo.writelines([f"{x[0], x[1]}\n" for x in status_lst]) 57 | -------------------------------------------------------------------------------- /datasets/rand_conv.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from PIL import Image 5 | from einops import rearrange, repeat 6 | 7 | 8 | class RandConv(nn.Module): 9 | 10 | def __init__(self, kernel_size=3, alpha=0.7, temporal_input=False): 11 | super(RandConv, self).__init__() 12 | self.m = nn.Conv2d(3, 3, kernel_size, stride=1, padding=kernel_size // 2, bias=False) 13 | self.std_normal = 1 / (np.sqrt(3) * kernel_size) 14 | self.alpha = alpha 15 | self.temporal_input = temporal_input 16 | 17 | def forward(self, image): 18 | with torch.no_grad(): 19 | self.m.weight = torch.nn.Parameter(torch.normal(mean=torch.zeros_like(self.m.weight), 20 | std=torch.ones_like(self.m.weight) * self.std_normal)) 21 | if self.temporal_input: 22 | batch_dim = image.shape[0] 23 | filtered_im = rearrange(image, "b c t h w -> (b t) c h w") 24 | filtered_im = self.m(filtered_im) 25 | filtered_im = rearrange(filtered_im, "(b t) c h w -> b c t h w", b=batch_dim) 26 | else: 27 | filtered_im = self.m(image) 28 | return self.alpha * image + (1 - self.alpha) * filtered_im 29 | 30 | 31 | if __name__ == '__main__': 32 | 33 | filter_im = RandConv(temporal_input=True) 34 | mean = torch.Tensor([0.485, 0.456, 0.406]) 35 | std = torch.Tensor([0.229, 0.224, 0.225]) 36 | 37 | 38 | def whiten(input): 39 | return (input - mean.reshape(1, -1, 1, 1, 1)) / std.reshape(1, -1, 1, 1, 1) 40 | 41 | 42 | def dewhiten(input): 43 | return torch.clip(input * std.reshape(1, -1, 1, 1, 1) + mean.reshape(1, -1, 1, 1, 1), 0, 1) 44 | 45 | 46 | img = Image.open("/Users/kanchana/Documents/current/video_research/repo/dino/data/rand_conv/raw.jpg") 47 | # img_arr = torch.Tensor(np.array(img).transpose(2, 0, 1)).unsqueeze(0) 48 | img_arr = repeat(torch.Tensor(np.array(img)).unsqueeze(0), "b h w c -> b c t h w", t=8) 49 | img_arr = img_arr.float() / 255. 50 | 51 | for idx in range(10): 52 | with torch.no_grad(): 53 | out = dewhiten(filter_im(whiten(img_arr))) 54 | # out = alpha * img_arr + (1 - alpha) * out 55 | im_vis = out[0, :, 0].permute(1, 2, 0).detach().numpy() * 255. 
56 | im_vis = Image.fromarray(im_vis.astype(np.uint8)) 57 | im_vis.save(f"/Users/kanchana/Documents/current/video_research/repo/dino/data/rand_conv/{idx + 1:05d}.jpg") 58 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Project specific 2 | .idea 3 | .DS_Store 4 | checkpoints 5 | /data 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | -------------------------------------------------------------------------------- /datasets/multigrid_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | """Helper functions for multigrid training.""" 4 | 5 | import numpy as np 6 | from torch._six import int_classes as _int_classes 7 | from torch.utils.data.sampler import Sampler 8 | 9 | 10 | class ShortCycleBatchSampler(Sampler): 11 | """ 12 | Extend Sampler to support "short cycle" sampling. 
13 | See paper "A Multigrid Method for Efficiently Training Video Models", 14 | Wu et al., 2019 (https://arxiv.org/abs/1912.00998) for details. 15 | """ 16 | 17 | def __init__(self, sampler, batch_size, drop_last, cfg): 18 | if not isinstance(sampler, Sampler): 19 | raise ValueError( 20 | "sampler should be an instance of " 21 | "torch.utils.data.Sampler, but got sampler={}".format(sampler) 22 | ) 23 | if ( 24 | not isinstance(batch_size, _int_classes) 25 | or isinstance(batch_size, bool) 26 | or batch_size <= 0 27 | ): 28 | raise ValueError( 29 | "batch_size should be a positive integer value, " 30 | "but got batch_size={}".format(batch_size) 31 | ) 32 | if not isinstance(drop_last, bool): 33 | raise ValueError( 34 | "drop_last should be a boolean value, but got " 35 | "drop_last={}".format(drop_last) 36 | ) 37 | self.sampler = sampler 38 | self.drop_last = drop_last 39 | 40 | bs_factor = [ 41 | int( 42 | round( 43 | ( 44 | float(cfg.DATA.TRAIN_CROP_SIZE) 45 | / (s * cfg.MULTIGRID.DEFAULT_S) 46 | ) 47 | ** 2 48 | ) 49 | ) 50 | for s in cfg.MULTIGRID.SHORT_CYCLE_FACTORS 51 | ] 52 | 53 | self.batch_sizes = [ 54 | batch_size * bs_factor[0], 55 | batch_size * bs_factor[1], 56 | batch_size, 57 | ] 58 | 59 | def __iter__(self): 60 | counter = 0 61 | batch_size = self.batch_sizes[0] 62 | batch = [] 63 | for idx in self.sampler: 64 | batch.append((idx, counter % 3)) 65 | if len(batch) == batch_size: 66 | yield batch 67 | counter += 1 68 | batch_size = self.batch_sizes[counter % 3] 69 | batch = [] 70 | if len(batch) > 0 and not self.drop_last: 71 | yield batch 72 | 73 | def __len__(self): 74 | avg_batch_size = sum(self.batch_sizes) / 3.0 75 | if self.drop_last: 76 | return int(np.floor(len(self.sampler) / avg_batch_size)) 77 | else: 78 | return int(np.ceil(len(self.sampler) / avg_batch_size)) 79 | -------------------------------------------------------------------------------- /utils/parser.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | """Argument parser functions.""" 4 | 5 | import argparse 6 | import sys 7 | 8 | from utils.defaults import get_cfg 9 | 10 | 11 | def parse_args(): 12 | """ 13 | Parse the following arguments for a default parser for PySlowFast users. 14 | Args: 15 | shard_id (int): shard id for the current machine. Starts from 0 to 16 | num_shards - 1. If single machine is used, then set shard id to 0. 17 | num_shards (int): number of shards using by the job. 18 | init_method (str): initialization method to launch the job with multiple 19 | devices. Options includes TCP or shared file-system for 20 | initialization. details can be find in 21 | https://pytorch.org/docs/stable/distributed.html#tcp-initialization 22 | cfg (str): path to the config file. 23 | opts (argument): provide addtional options from the command line, it 24 | overwrites the config loaded from file. 25 | """ 26 | parser = argparse.ArgumentParser( 27 | description="Provide SlowFast video training and testing pipeline." 
28 | ) 29 | parser.add_argument( 30 | "--shard_id", 31 | help="The shard id of current node, Starts from 0 to num_shards - 1", 32 | default=0, 33 | type=int, 34 | ) 35 | parser.add_argument( 36 | "--num_shards", 37 | help="Number of shards using by the job", 38 | default=1, 39 | type=int, 40 | ) 41 | parser.add_argument( 42 | "--init_method", 43 | help="Initialization method, includes TCP or shared file-system", 44 | default="tcp://localhost:9999", 45 | type=str, 46 | ) 47 | parser.add_argument( 48 | "--cfg", 49 | dest="cfg_file", 50 | help="Path to the config file", 51 | default="configs/Kinetics/SLOWFAST_4x16_R50.yaml", 52 | type=str, 53 | ) 54 | parser.add_argument( 55 | "opts", 56 | help="See slowfast/config/defaults.py for all options", 57 | default=None, 58 | nargs=argparse.REMAINDER, 59 | ) 60 | if len(sys.argv) == 1: 61 | parser.print_help() 62 | return parser.parse_args() 63 | 64 | 65 | def load_config(args): 66 | """ 67 | Given the arguemnts, load and initialize the configs. 68 | Args: 69 | args (argument): arguments includes `shard_id`, `num_shards`, 70 | `init_method`, `cfg_file`, and `opts`. 71 | """ 72 | # Setup cfg. 73 | cfg = get_cfg() 74 | # Load config from cfg. 75 | if args.cfg_file is not None: 76 | cfg.merge_from_file(args.cfg_file) 77 | # Load config from command line, overwrite config from opts. 78 | if args.opts is not None: 79 | cfg.merge_from_list(args.opts) 80 | 81 | # Inherit parameters from args. 82 | if hasattr(args, "num_shards") and hasattr(args, "shard_id"): 83 | cfg.NUM_SHARDS = args.num_shards 84 | cfg.SHARD_ID = args.shard_id 85 | if hasattr(args, "rng_seed"): 86 | cfg.RNG_SEED = args.rng_seed 87 | if hasattr(args, "output_dir"): 88 | cfg.OUTPUT_DIR = args.output_dir 89 | 90 | return cfg 91 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | """Functions for computing metrics.""" 4 | 5 | import torch 6 | import numpy as np 7 | 8 | def topks_correct(preds, labels, ks): 9 | """ 10 | Given the predictions, labels, and a list of top-k values, compute the 11 | number of correct predictions for each top-k value. 12 | 13 | Args: 14 | preds (array): array of predictions. Dimension is batchsize 15 | N x ClassNum. 16 | labels (array): array of labels. Dimension is batchsize N. 17 | ks (list): list of top-k values. For example, ks = [1, 5] correspods 18 | to top-1 and top-5. 19 | 20 | Returns: 21 | topks_correct (list): list of numbers, where the `i`-th entry 22 | corresponds to the number of top-`ks[i]` correct predictions. 23 | """ 24 | assert preds.size(0) == labels.size( 25 | 0 26 | ), "Batch dim of predictions and labels must match" 27 | # Find the top max_k predictions for each sample 28 | _top_max_k_vals, top_max_k_inds = torch.topk( 29 | preds, max(ks), dim=1, largest=True, sorted=True 30 | ) 31 | # (batch_size, max_k) -> (max_k, batch_size). 32 | top_max_k_inds = top_max_k_inds.t() 33 | # (batch_size, ) -> (max_k, batch_size). 34 | rep_max_k_labels = labels.view(1, -1).expand_as(top_max_k_inds) 35 | # (i, j) = 1 if top i-th prediction for the j-th sample is correct. 36 | top_max_k_correct = top_max_k_inds.eq(rep_max_k_labels) 37 | # Compute the number of topk correct predictions for each k. 
38 | topks_correct = [top_max_k_correct[:k, :].float().sum() for k in ks] 39 | return topks_correct 40 | 41 | 42 | def topk_errors(preds, labels, ks): 43 | """ 44 | Computes the top-k error for each k. 45 | Args: 46 | preds (array): array of predictions. Dimension is N. 47 | labels (array): array of labels. Dimension is N. 48 | ks (list): list of ks to calculate the top accuracies. 49 | """ 50 | num_topks_correct = topks_correct(preds, labels, ks) 51 | return [(1.0 - x / preds.size(0)) * 100.0 for x in num_topks_correct] 52 | 53 | 54 | def topk_accuracies(preds, labels, ks): 55 | """ 56 | Computes the top-k accuracy for each k. 57 | Args: 58 | preds (array): array of predictions. Dimension is N. 59 | labels (array): array of labels. Dimension is N. 60 | ks (list): list of ks to calculate the top accuracies. 61 | """ 62 | num_topks_correct = topks_correct(preds, labels, ks) 63 | return [(x / preds.size(0)) * 100.0 for x in num_topks_correct] 64 | 65 | def multitask_topks_correct(preds, labels, ks=(1,)): 66 | """ 67 | Args: 68 | preds: tuple(torch.FloatTensor), each tensor should be of shape 69 | [batch_size, class_count], class_count can vary on a per task basis, i.e. 70 | outputs[i].shape[1] can be different to outputs[j].shape[j]. 71 | labels: tuple(torch.LongTensor), each tensor should be of shape [batch_size] 72 | ks: tuple(int), compute accuracy at top-k for the values of k specified 73 | in this parameter. 74 | Returns: 75 | tuple(float), same length at topk with the corresponding accuracy@k in. 76 | """ 77 | max_k = int(np.max(ks)) 78 | task_count = len(preds) 79 | batch_size = labels[0].size(0) 80 | all_correct = torch.zeros(max_k, batch_size).type(torch.ByteTensor) 81 | if torch.cuda.is_available(): 82 | all_correct = all_correct.cuda() 83 | for output, label in zip(preds, labels): 84 | _, max_k_idx = output.topk(max_k, dim=1, largest=True, sorted=True) 85 | # Flip batch_size, class_count as .view doesn't work on non-contiguous 86 | max_k_idx = max_k_idx.t() 87 | correct_for_task = max_k_idx.eq(label.view(1, -1).expand_as(max_k_idx)) 88 | all_correct.add_(correct_for_task) 89 | 90 | multitask_topks_correct = [ 91 | torch.ge(all_correct[:k].float().sum(0), task_count).float().sum(0) for k in ks 92 | ] 93 | 94 | return multitask_topks_correct 95 | -------------------------------------------------------------------------------- /datasets/preprocessing/resize_videos.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | import os.path as osp 5 | import sys 6 | from multiprocessing import Pool 7 | 8 | 9 | def resize_videos(vid_item): 10 | """Generate resized video cache. 11 | 12 | Args: 13 | vid_item (list): Video item containing video full path, 14 | video relative path. 15 | 16 | Returns: 17 | bool: Whether generate video cache successfully. 
18 | """ 19 | full_path, vid_path = vid_item 20 | out_full_path = osp.join(args.out_dir, vid_path) 21 | dir_name = osp.dirname(vid_path) 22 | out_dir = osp.join(args.out_dir, dir_name) 23 | if not osp.exists(out_dir): 24 | os.makedirs(out_dir) 25 | result = os.popen( 26 | f'ffprobe -hide_banner -loglevel error -select_streams v:0 -show_entries stream=width,height -of csv=p=0 {full_path}' # noqa:E501 27 | ) 28 | try: 29 | w, h = [int(d) for d in result.readline().rstrip().split(',')] 30 | if w > h: 31 | cmd = (f'ffmpeg -hide_banner -loglevel error -i {full_path} ' 32 | f'-vf {"mpdecimate," if args.remove_dup else ""}' 33 | f'scale=-2:{args.scale} ' 34 | f'{"-vsync vfr" if args.remove_dup else ""} ' 35 | f'-c:v libx264 {"-g 16" if args.dense else ""} ' 36 | f'-an {out_full_path} -y') 37 | else: 38 | cmd = (f'ffmpeg -hide_banner -loglevel error -i {full_path} ' 39 | f'-vf {"mpdecimate," if args.remove_dup else ""}' 40 | f'scale={args.scale}:-2 ' 41 | f'{"-vsync vfr" if args.remove_dup else ""} ' 42 | f'-c:v libx264 {"-g 16" if args.dense else ""} ' 43 | f'-an {out_full_path} -y') 44 | os.popen(cmd) 45 | print(f'{vid_path} done') 46 | sys.stdout.flush() 47 | except Exception as e: 48 | print(e) 49 | return False 50 | 51 | return True 52 | 53 | 54 | def parse_args(): 55 | parser = argparse.ArgumentParser( 56 | description='Generate the resized cache of original videos') 57 | parser.add_argument('src_dir', type=str, help='source video directory') 58 | parser.add_argument('out_dir', type=str, help='output video directory') 59 | parser.add_argument('--save', type=str, help='path to save output') 60 | parser.add_argument( 61 | '--dense', 62 | action='store_true', 63 | help='whether to generate a faster cache') 64 | parser.add_argument( 65 | '--level', 66 | type=int, 67 | choices=[1, 2], 68 | default=2, 69 | help='directory level of data') 70 | parser.add_argument( 71 | '--remove-dup', 72 | action='store_true', 73 | help='whether to remove duplicated frames') 74 | parser.add_argument( 75 | '--ext', 76 | type=str, 77 | default='mp4', 78 | choices=['avi', 'mp4', 'webm', 'mkv'], 79 | help='video file extensions') 80 | parser.add_argument( 81 | '--scale', 82 | type=int, 83 | default=256, 84 | help='resize image short side length keeping ratio') 85 | parser.add_argument( 86 | '--num-worker', type=int, default=8, help='number of workers') 87 | args = parser.parse_args() 88 | 89 | return args 90 | 91 | 92 | if __name__ == '__main__': 93 | args = parse_args() 94 | 95 | if not osp.isdir(args.out_dir): 96 | print(f'Creating folder: {args.out_dir}') 97 | os.makedirs(args.out_dir) 98 | 99 | print('Reading videos from folder: ', args.src_dir) 100 | print('Extension of videos: ', args.ext) 101 | fullpath_list = glob.glob(args.src_dir + '/*' * args.level + '.' 
+ 102 | args.ext) 103 | done_fullpath_list = glob.glob(args.out_dir + '/*' * args.level + args.ext) 104 | print('Total number of videos found: ', len(fullpath_list)) 105 | print('Total number of videos transfer finished: ', 106 | len(done_fullpath_list)) 107 | if args.level == 2: 108 | vid_list = list( 109 | map( 110 | lambda p: osp.join( 111 | osp.basename(osp.dirname(p)), osp.basename(p)), 112 | fullpath_list)) 113 | elif args.level == 1: 114 | vid_list = list(map(osp.basename, fullpath_list)) 115 | pool = Pool(args.num_worker) 116 | out = pool.map(resize_videos, zip(fullpath_list, vid_list)) 117 | bad_videos = [osp.basename(x) for x, y in zip(fullpath_list, out) if not y] 118 | with open(f"{args.save}", "w") as fo: 119 | fo.writelines([f"{x}\n" for x in bad_videos]) 120 | -------------------------------------------------------------------------------- /datasets/preprocessing/flow_vis.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Tom Runia 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to conditions. 14 | # 15 | # Author: Tom Runia 16 | # Date Created: 2018-08-03 17 | 18 | import numpy as np 19 | 20 | 21 | def make_colorwheel(): 22 | """ 23 | Generates a color wheel for optical flow visualization as presented in: 24 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 25 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 26 | 27 | Code follows the original C++ source code of Daniel Scharstein. 28 | Code follows the the Matlab source code of Deqing Sun. 29 | 30 | Returns: 31 | np.ndarray: Color wheel 32 | """ 33 | 34 | RY = 15 35 | YG = 6 36 | GC = 4 37 | CB = 11 38 | BM = 13 39 | MR = 6 40 | 41 | ncols = RY + YG + GC + CB + BM + MR 42 | colorwheel = np.zeros((ncols, 3)) 43 | col = 0 44 | 45 | # RY 46 | colorwheel[0:RY, 0] = 255 47 | colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY) 48 | col = col + RY 49 | # YG 50 | colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG) 51 | colorwheel[col:col + YG, 1] = 255 52 | col = col + YG 53 | # GC 54 | colorwheel[col:col + GC, 1] = 255 55 | colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC) 56 | col = col + GC 57 | # CB 58 | colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB) 59 | colorwheel[col:col + CB, 2] = 255 60 | col = col + CB 61 | # BM 62 | colorwheel[col:col + BM, 2] = 255 63 | colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM) 64 | col = col + BM 65 | # MR 66 | colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR) 67 | colorwheel[col:col + MR, 0] = 255 68 | return colorwheel 69 | 70 | 71 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 72 | """ 73 | Applies the flow color wheel to (possibly clipped) flow components u and v. 
74 | 75 | According to the C++ source code of Daniel Scharstein 76 | According to the Matlab source code of Deqing Sun 77 | 78 | Args: 79 | u (np.ndarray): Input horizontal flow of shape [H,W] 80 | v (np.ndarray): Input vertical flow of shape [H,W] 81 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 82 | 83 | Returns: 84 | np.ndarray: Flow visualization image of shape [H,W,3] 85 | """ 86 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 87 | colorwheel = make_colorwheel() # shape [55x3] 88 | ncols = colorwheel.shape[0] 89 | rad = np.sqrt(np.square(u) + np.square(v)) 90 | a = np.arctan2(-v, -u) / np.pi 91 | fk = (a + 1) / 2 * (ncols - 1) 92 | k0 = np.floor(fk).astype(np.int32) 93 | k1 = k0 + 1 94 | k1[k1 == ncols] = 0 95 | f = fk - k0 96 | for i in range(colorwheel.shape[1]): 97 | tmp = colorwheel[:, i] 98 | col0 = tmp[k0] / 255.0 99 | col1 = tmp[k1] / 255.0 100 | col = (1 - f) * col0 + f * col1 101 | idx = (rad <= 1) 102 | col[idx] = 1 - rad[idx] * (1 - col[idx]) 103 | col[~idx] = col[~idx] * 0.75 # out of range 104 | # Note the 2-i => BGR instead of RGB 105 | ch_idx = 2 - i if convert_to_bgr else i 106 | flow_image[:, :, ch_idx] = np.floor(255 * col) 107 | return flow_image 108 | 109 | 110 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 111 | """ 112 | Expects a two dimensional flow image of shape. 113 | 114 | Args: 115 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 116 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 117 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 118 | 119 | Returns: 120 | np.ndarray: Flow visualization image of shape [H,W,3] 121 | """ 122 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 123 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 124 | if clip_flow is not None: 125 | flow_uv = np.clip(flow_uv, 0, clip_flow) 126 | u = flow_uv[:, :, 0] 127 | v = flow_uv[:, :, 1] 128 | rad = np.sqrt(np.square(u) + np.square(v)) 129 | rad_max = np.max(rad) 130 | epsilon = 1e-5 131 | u = u / (rad_max + epsilon) 132 | v = v / (rad_max + epsilon) 133 | return flow_uv_to_colors(u, v, convert_to_bgr) 134 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Self-Supervised Video Transformer (CVPR'22-Oral) 2 | 3 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/self-supervised-video-transformer/self-supervised-action-recognition-linear-on-3)](https://paperswithcode.com/sota/self-supervised-action-recognition-linear-on-3?p=self-supervised-video-transformer) 4 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/self-supervised-video-transformer/self-supervised-action-recognition-linear-on)](https://paperswithcode.com/sota/self-supervised-action-recognition-linear-on?p=self-supervised-video-transformer) 5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/self-supervised-video-transformer/self-supervised-action-recognition-linear-on-1)](https://paperswithcode.com/sota/self-supervised-action-recognition-linear-on-1?p=self-supervised-video-transformer) 6 | 7 | 
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/self-supervised-video-transformer/action-recognition-in-videos-on-ucf101)](https://paperswithcode.com/sota/action-recognition-in-videos-on-ucf101?p=self-supervised-video-transformer) 8 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/self-supervised-video-transformer/action-recognition-in-videos-on-hmdb-51)](https://paperswithcode.com/sota/action-recognition-in-videos-on-hmdb-51?p=self-supervised-video-transformer) 9 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/self-supervised-video-transformer/action-recognition-in-videos-on-something)](https://paperswithcode.com/sota/action-recognition-in-videos-on-something?p=self-supervised-video-transformer) 10 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/self-supervised-video-transformer/action-classification-on-kinetics-400)](https://paperswithcode.com/sota/action-classification-on-kinetics-400?p=self-supervised-video-transformer) 11 | 12 | [Kanchana Ranasinghe](https://kahnchana.github.io), 13 | [Muzammal Naseer](https://muzammal-naseer.netlify.app/), 14 | [Salman Khan](https://salman-h-khan.github.io), 15 | [Fahad Shahbaz Khan](https://sites.google.com/view/fahadkhans/home), 16 | [Michael Ryoo](http://michaelryoo.com) 17 | 18 | **[Paper Link](https://arxiv.org/abs/2112.01514)** | **[Project Page](https://kahnchana.github.io/svt)** 19 | 20 | 21 | > **Abstract:** 22 | >*In this paper, we propose self-supervised training for video transformers using unlabelled video data. From a given video, we create local and global spatiotemporal views with varying spatial sizes and frame rates. Our self-supervised objective seeks to match the features of these different views representing the same video, to be invariant to spatiotemporal variations in actions. To the best of our knowledge, the proposed approach is the first to alleviate the dependency on negative samples or dedicated memory banks in Self-supervised Video Transformer (SVT). Further, owing to the flexibility of Transformer models, SVT supports slow-fast video processing within a single architecture using dynamically adjusted positional encodings and supports long-term relationship modeling along spatiotemporal dimensions. Our approach performs well on four action recognition benchmarks (Kinetics-400, UCF-101, HMDB-51, and SSv2) and converges faster with small batch sizes.* 23 | 24 | 25 |

26 | ![intro_image](.github/intro.png)
27 | 

28 | 
29 | 
30 | 
31 | ## Usage & Data
32 | Refer to `requirements.txt` to install all Python dependencies. We use Python 3.7 with PyTorch 1.7.1.
33 | 
34 | We download the official version of Kinetics-400 from [here](https://github.com/cvdfoundation/kinetics-dataset) and resize the videos using the code [here](https://github.com/open-mmlab/mmaction2/tree/master/tools/data/kinetics).
35 | 
36 | 
37 | ## Self-supervised Training
38 | For self-supervised pre-training of models on the Kinetics-400 dataset, use the scripts in the `scripts` directory as follows. Change the dataset paths as required.
39 | 
40 | ```
41 | ./scripts/train.sh
42 | ```
43 | 
44 | 
45 | ## Downstream Evaluation
46 | Scripts to perform evaluation (linear or kNN) on selected downstream tasks are given below. Paths to datasets and pre-trained models must be set appropriately. Note that in the case of linear evaluation, a linear layer is fine-tuned on the new dataset, and this training can be time-consuming on a single GPU.
47 | 
48 | ```
49 | ./scripts/eval_linear.sh
50 | ./scripts/eval_knn.sh
51 | ```
52 | 
53 | 
54 | ## Pretrained Models
55 | Our pre-trained models can be found under [releases](https://github.com/kahnchana/svt/releases/tag/v1.0).
56 | 
57 | 
58 | ## Citation
59 | If you find our work, this repository, or the pretrained models useful, please consider giving a star :star: and a citation.
60 | ```bibtex
61 | @inproceedings{ranasinghe2022selfsupervised,
62 |   title={Self-supervised Video Transformer},
63 |   author={Kanchana Ranasinghe and Muzammal Naseer and Salman Khan and Fahad Shahbaz Khan and Michael Ryoo},
64 |   booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition},
65 |   month = {June},
66 |   year={2022}
67 | }
68 | ```
69 | 
70 | 
71 | ## Acknowledgements
72 | Our code is based on the [DINO](https://github.com/facebookresearch/dino) and [TimeSformer](https://github.com/facebookresearch/TimeSformer) repositories. We thank the authors for releasing their code. If you use our model, please consider citing these works as well.
73 | 
--------------------------------------------------------------------------------
/datasets/loader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | 
3 | """Data loader."""
4 | 
5 | import itertools
6 | import numpy as np
7 | import torch
8 | from torch.utils.data._utils.collate import default_collate
9 | from torch.utils.data.distributed import DistributedSampler
10 | from torch.utils.data.sampler import RandomSampler
11 | 
12 | from timesformer.datasets.multigrid_helper import ShortCycleBatchSampler
13 | 
14 | from . import data_utils as utils
15 | from .build import build_dataset
16 | 
17 | 
18 | def detection_collate(batch):
19 |     """
20 |     Collate function for the detection task. Concatenate bboxes, labels and
21 |     metadata from different samples in the first dimension instead of
22 |     stacking them to have a batch-size dimension.
23 |     Args:
24 |         batch (tuple or list): data batch to collate.
25 |     Returns:
26 |         (tuple): collated detection data batch.
27 | """ 28 | inputs, labels, video_idx, extra_data = zip(*batch) 29 | inputs, video_idx = default_collate(inputs), default_collate(video_idx) 30 | labels = torch.tensor(np.concatenate(labels, axis=0)).float() 31 | 32 | collated_extra_data = {} 33 | for key in extra_data[0].keys(): 34 | data = [d[key] for d in extra_data] 35 | if key == "boxes" or key == "ori_boxes": 36 | # Append idx info to the bboxes before concatenating them. 37 | bboxes = [ 38 | np.concatenate( 39 | [np.full((data[i].shape[0], 1), float(i)), data[i]], axis=1 40 | ) 41 | for i in range(len(data)) 42 | ] 43 | bboxes = np.concatenate(bboxes, axis=0) 44 | collated_extra_data[key] = torch.tensor(bboxes).float() 45 | elif key == "metadata": 46 | collated_extra_data[key] = torch.tensor( 47 | list(itertools.chain(*data)) 48 | ).view(-1, 2) 49 | else: 50 | collated_extra_data[key] = default_collate(data) 51 | 52 | return inputs, labels, video_idx, collated_extra_data 53 | 54 | 55 | def construct_loader(cfg, split, is_precise_bn=False): 56 | """ 57 | Constructs the data loader for the given dataset. 58 | Args: 59 | cfg (CfgNode): configs. Details can be found in 60 | slowfast/config/defaults.py 61 | split (str): the split of the data loader. Options include `train`, 62 | `val`, and `test`. 63 | """ 64 | assert split in ["train", "val", "test"] 65 | if split in ["train"]: 66 | dataset_name = cfg.TRAIN.DATASET 67 | batch_size = int(cfg.TRAIN.BATCH_SIZE / max(1, cfg.NUM_GPUS)) 68 | shuffle = True 69 | drop_last = True 70 | elif split in ["val"]: 71 | dataset_name = cfg.TRAIN.DATASET 72 | batch_size = int(cfg.TRAIN.BATCH_SIZE / max(1, cfg.NUM_GPUS)) 73 | shuffle = False 74 | drop_last = False 75 | elif split in ["test"]: 76 | dataset_name = cfg.TEST.DATASET 77 | batch_size = int(cfg.TEST.BATCH_SIZE / max(1, cfg.NUM_GPUS)) 78 | shuffle = False 79 | drop_last = False 80 | 81 | # Construct the dataset 82 | dataset = build_dataset(dataset_name, cfg, split) 83 | 84 | if cfg.MULTIGRID.SHORT_CYCLE and split in ["train"] and not is_precise_bn: 85 | # Create a sampler for multi-process training 86 | sampler = utils.create_sampler(dataset, shuffle, cfg) 87 | batch_sampler = ShortCycleBatchSampler( 88 | sampler, batch_size=batch_size, drop_last=drop_last, cfg=cfg 89 | ) 90 | # Create a loader 91 | loader = torch.utils.data.DataLoader( 92 | dataset, 93 | batch_sampler=batch_sampler, 94 | num_workers=cfg.DATA_LOADER.NUM_WORKERS, 95 | pin_memory=cfg.DATA_LOADER.PIN_MEMORY, 96 | worker_init_fn=utils.loader_worker_init_fn(dataset), 97 | ) 98 | else: 99 | # Create a sampler for multi-process training 100 | sampler = utils.create_sampler(dataset, shuffle, cfg) 101 | # Create a loader 102 | loader = torch.utils.data.DataLoader( 103 | dataset, 104 | batch_size=batch_size, 105 | shuffle=(False if sampler else shuffle), 106 | sampler=sampler, 107 | num_workers=cfg.DATA_LOADER.NUM_WORKERS, 108 | pin_memory=cfg.DATA_LOADER.PIN_MEMORY, 109 | drop_last=drop_last, 110 | collate_fn=detection_collate if cfg.DETECTION.ENABLE else None, 111 | worker_init_fn=utils.loader_worker_init_fn(dataset), 112 | ) 113 | return loader 114 | 115 | 116 | def shuffle_dataset(loader, cur_epoch): 117 | """ " 118 | Shuffles the data. 119 | Args: 120 | loader (loader): data loader to perform shuffle. 121 | cur_epoch (int): number of the current epoch. 
122 | """ 123 | sampler = ( 124 | loader.batch_sampler.sampler 125 | if isinstance(loader.batch_sampler, ShortCycleBatchSampler) 126 | else loader.sampler 127 | ) 128 | assert isinstance( 129 | sampler, (RandomSampler, DistributedSampler) 130 | ), "Sampler type '{}' not supported".format(type(sampler)) 131 | # RandomSampler handles shuffling automatically 132 | if isinstance(sampler, DistributedSampler): 133 | # DistributedSampler shuffles data based on epoch 134 | sampler.set_epoch(cur_epoch) 135 | -------------------------------------------------------------------------------- /models/vit_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Ross Wightman 2 | # Various utility functions 3 | 4 | import math 5 | import warnings 6 | from itertools import repeat 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | TORCH_MAJOR = int(torch.__version__.split('.')[0]) 13 | TORCH_MINOR = int(torch.__version__.split('.')[1]) 14 | if TORCH_MAJOR == 1 and TORCH_MINOR < 8: 15 | from torch._six import container_abcs, int_classes 16 | else: 17 | import collections.abc as container_abcs 18 | 19 | DEFAULT_CROP_PCT = 0.875 20 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) 21 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) 22 | IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) 23 | IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) 24 | IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255) 25 | IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3) 26 | 27 | 28 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 29 | def norm_cdf(x): 30 | # Computes standard normal cumulative distribution function 31 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 32 | 33 | if (mean < a - 2 * std) or (mean > b + 2 * std): 34 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 35 | "The distribution of values may be incorrect.", 36 | stacklevel=2) 37 | 38 | with torch.no_grad(): 39 | # Values are generated by using a truncated uniform distribution and 40 | # then using the inverse CDF for the normal distribution. 41 | # Get upper and lower cdf values 42 | l = norm_cdf((a - mean) / std) 43 | u = norm_cdf((b - mean) / std) 44 | 45 | # Uniformly fill tensor with values from [l, u], then translate to 46 | # [2l-1, 2u-1]. 47 | tensor.uniform_(2 * l - 1, 2 * u - 1) 48 | 49 | # Use inverse cdf transform for normal distribution to get truncated 50 | # standard normal 51 | tensor.erfinv_() 52 | 53 | # Transform to proper mean, std 54 | tensor.mul_(std * math.sqrt(2.)) 55 | tensor.add_(mean) 56 | 57 | # Clamp to ensure it's in the proper range 58 | tensor.clamp_(min=a, max=b) 59 | return tensor 60 | 61 | 62 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 63 | # type: (Tensor, float, float, float, float) -> Tensor 64 | r"""Fills the input Tensor with values drawn from a truncated 65 | normal distribution. The values are effectively drawn from the 66 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 67 | with values outside :math:`[a, b]` redrawn until they are within 68 | the bounds. The method used for generating the random values works 69 | best when :math:`a \leq \text{mean} \leq b`. 
70 | Args: 71 | tensor: an n-dimensional `torch.Tensor` 72 | mean: the mean of the normal distribution 73 | std: the standard deviation of the normal distribution 74 | a: the minimum cutoff value 75 | b: the maximum cutoff value 76 | Examples: 77 | >>> w = torch.empty(3, 5) 78 | >>> nn.init.trunc_normal_(w) 79 | """ 80 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 81 | 82 | 83 | # From PyTorch internals 84 | def _ntuple(n): 85 | def parse(x): 86 | if isinstance(x, container_abcs.Iterable): 87 | return x 88 | return tuple(repeat(x, n)) 89 | 90 | return parse 91 | 92 | 93 | to_2tuple = _ntuple(2) 94 | 95 | 96 | # Calculate symmetric padding for a convolution 97 | def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int: 98 | padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 99 | return padding 100 | 101 | 102 | def get_padding_value(padding, kernel_size, **kwargs): 103 | dynamic = False 104 | if isinstance(padding, str): 105 | # for any string padding, the padding will be calculated for you, one of three ways 106 | padding = padding.lower() 107 | if padding == 'same': 108 | # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact 109 | if is_static_pad(kernel_size, **kwargs): 110 | # static case, no extra overhead 111 | padding = get_padding(kernel_size, **kwargs) 112 | else: 113 | # dynamic 'SAME' padding, has runtime/GPU memory overhead 114 | padding = 0 115 | dynamic = True 116 | elif padding == 'valid': 117 | # 'VALID' padding, same as padding=0 118 | padding = 0 119 | else: 120 | # Default to PyTorch style 'same'-ish symmetric padding 121 | padding = get_padding(kernel_size, **kwargs) 122 | return padding, dynamic 123 | 124 | 125 | # Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution 126 | def get_same_padding(x: int, k: int, s: int, d: int): 127 | return max((int(math.ceil(x // s)) - 1) * s + (k - 1) * d + 1 - x, 0) 128 | 129 | 130 | # Can SAME padding for given args be done statically? 131 | def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_): 132 | return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 133 | 134 | 135 | # Dynamically pad input x with 'SAME' padding for conv with specified args 136 | # def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0): 137 | def pad_same(x, k, s, d=(1, 1), value=0): 138 | ih, iw = x.size()[-2:] 139 | pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1]) 140 | if pad_h > 0 or pad_w > 0: 141 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value) 142 | return x 143 | 144 | 145 | def adaptive_pool_feat_mult(pool_type='avg'): 146 | if pool_type == 'catavgmax': 147 | return 2 148 | else: 149 | return 1 150 | 151 | 152 | def drop_path(x, drop_prob: float = 0., training: bool = False): 153 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 154 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 155 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 156 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 157 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 158 | 'survival rate' as the argument. 159 | """ 160 | if drop_prob == 0. 
or not training: 161 | return x 162 | keep_prob = 1 - drop_prob 163 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 164 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 165 | random_tensor.floor_() # binarize 166 | output = x.div(keep_prob) * random_tensor 167 | return output 168 | 169 | 170 | class DropPath(nn.Module): 171 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 172 | """ 173 | 174 | def __init__(self, drop_prob=None): 175 | super(DropPath, self).__init__() 176 | self.drop_prob = drop_prob 177 | 178 | def forward(self, x): 179 | return drop_path(x, self.drop_prob, self.training) 180 | -------------------------------------------------------------------------------- /utils/meters.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | """Meters.""" 4 | 5 | import datetime 6 | 7 | import numpy as np 8 | import torch 9 | from fvcore.common.timer import Timer 10 | from sklearn.metrics import average_precision_score 11 | 12 | import utils.logging as logging 13 | import utils.metrics as metrics 14 | 15 | logger = logging.get_logger(__name__) 16 | 17 | 18 | class TestMeter(object): 19 | """ 20 | Perform the multi-view ensemble for testing: each video with an unique index 21 | will be sampled with multiple clips, and the predictions of the clips will 22 | be aggregated to produce the final prediction for the video. 23 | The accuracy is calculated with the given ground truth labels. 24 | """ 25 | 26 | def __init__( 27 | self, 28 | num_videos, 29 | num_clips, 30 | num_cls, 31 | overall_iters, 32 | multi_label=False, 33 | ensemble_method="sum", 34 | ): 35 | """ 36 | Construct tensors to store the predictions and labels. Expect to get 37 | num_clips predictions from each video, and calculate the metrics on 38 | num_videos videos. 39 | Args: 40 | num_videos (int): number of videos to test. 41 | num_clips (int): number of clips sampled from each video for 42 | aggregating the final prediction for the video. 43 | num_cls (int): number of classes for each prediction. 44 | overall_iters (int): overall iterations for testing. 45 | multi_label (bool): if True, use map as the metric. 46 | ensemble_method (str): method to perform the ensemble, options 47 | include "sum", and "max". 48 | """ 49 | 50 | self.iter_timer = Timer() 51 | self.data_timer = Timer() 52 | self.net_timer = Timer() 53 | self.num_clips = num_clips 54 | self.overall_iters = overall_iters 55 | self.multi_label = multi_label 56 | self.ensemble_method = ensemble_method 57 | # Initialize tensors. 58 | self.video_preds = torch.zeros((num_videos, num_cls)) 59 | if multi_label: 60 | self.video_preds -= 1e10 61 | 62 | self.video_labels = ( 63 | torch.zeros((num_videos, num_cls)) 64 | if multi_label 65 | else torch.zeros((num_videos)).long() 66 | ) 67 | self.clip_count = torch.zeros((num_videos)).long() 68 | self.topk_accs = [] 69 | self.stats = {} 70 | 71 | # Reset metric. 72 | self.reset() 73 | 74 | def reset(self): 75 | """ 76 | Reset the metric. 77 | """ 78 | self.clip_count.zero_() 79 | self.video_preds.zero_() 80 | if self.multi_label: 81 | self.video_preds -= 1e10 82 | self.video_labels.zero_() 83 | 84 | def update_stats(self, preds, labels, clip_ids): 85 | """ 86 | Collect the predictions from the current batch and perform on-the-flight 87 | summation as ensemble. 
88 | Args: 89 | preds (tensor): predictions from the current batch. Dimension is 90 | N x C where N is the batch size and C is the channel size 91 | (num_cls). 92 | labels (tensor): the corresponding labels of the current batch. 93 | Dimension is N. 94 | clip_ids (tensor): clip indexes of the current batch, dimension is 95 | N. 96 | """ 97 | for ind in range(preds.shape[0]): 98 | vid_id = int(clip_ids[ind]) // self.num_clips 99 | if self.video_labels[vid_id].sum() > 0: 100 | assert torch.equal( 101 | self.video_labels[vid_id].type(torch.FloatTensor), 102 | labels[ind].type(torch.FloatTensor), 103 | ) 104 | self.video_labels[vid_id] = labels[ind] 105 | if self.ensemble_method == "sum": 106 | self.video_preds[vid_id] += preds[ind] 107 | elif self.ensemble_method == "max": 108 | self.video_preds[vid_id] = torch.max( 109 | self.video_preds[vid_id], preds[ind] 110 | ) 111 | else: 112 | raise NotImplementedError( 113 | "Ensemble Method {} is not supported".format( 114 | self.ensemble_method 115 | ) 116 | ) 117 | self.clip_count[vid_id] += 1 118 | 119 | def log_iter_stats(self, cur_iter): 120 | """ 121 | Log the stats. 122 | Args: 123 | cur_iter (int): the current iteration of testing. 124 | """ 125 | eta_sec = self.iter_timer.seconds() * (self.overall_iters - cur_iter) 126 | eta = str(datetime.timedelta(seconds=int(eta_sec))) 127 | stats = { 128 | "split": "test_iter", 129 | "cur_iter": "{}".format(cur_iter + 1), 130 | "eta": eta, 131 | "time_diff": self.iter_timer.seconds(), 132 | } 133 | logging.log_json_stats(stats) 134 | 135 | def iter_tic(self): 136 | """ 137 | Start to record time. 138 | """ 139 | self.iter_timer.reset() 140 | self.data_timer.reset() 141 | 142 | def iter_toc(self): 143 | """ 144 | Stop to record time. 145 | """ 146 | self.iter_timer.pause() 147 | self.net_timer.pause() 148 | 149 | def data_toc(self): 150 | self.data_timer.pause() 151 | self.net_timer.reset() 152 | 153 | def finalize_metrics(self, ks=(1, 5)): 154 | """ 155 | Calculate and log the final ensembled metrics. 156 | ks (tuple): list of top-k values for topk_accuracies. For example, 157 | ks = (1, 5) correspods to top-1 and top-5 accuracy. 158 | """ 159 | if not all(self.clip_count == self.num_clips): 160 | logger.warning( 161 | "clip count {} ~= num clips {}".format( 162 | ", ".join( 163 | [ 164 | "{}: {}".format(i, k) 165 | for i, k in enumerate(self.clip_count.tolist()) 166 | ] 167 | ), 168 | self.num_clips, 169 | ) 170 | ) 171 | 172 | self.stats = {"split": "test_final"} 173 | if self.multi_label: 174 | map = get_map( 175 | self.video_preds.cpu().numpy(), self.video_labels.cpu().numpy() 176 | ) 177 | self.stats["map"] = map 178 | else: 179 | num_topks_correct = metrics.topks_correct( 180 | self.video_preds, self.video_labels, ks 181 | ) 182 | topks = [ 183 | (x / self.video_preds.size(0)) * 100.0 184 | for x in num_topks_correct 185 | ] 186 | 187 | assert len({len(ks), len(topks)}) == 1 188 | for k, topk in zip(ks, topks): 189 | self.stats["top{}_acc".format(k)] = "{:.{prec}f}".format( 190 | topk, prec=2 191 | ) 192 | logging.log_json_stats(self.stats) 193 | 194 | 195 | def get_map(preds, labels): 196 | """ 197 | Compute mAP for multi-label case. 198 | Args: 199 | preds (numpy tensor): num_examples x num_classes. 200 | labels (numpy tensor): num_examples x num_classes. 201 | Returns: 202 | mean_ap (int): final mAP score. 
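    Note (added for clarity): classes with no positive labels are dropped before
    scoring, and if `average_precision_score` raises a ValueError the AP list
    falls back to [0], i.e. an mAP of 0 (see the implementation below).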
203 | """ 204 | 205 | logger.info("Getting mAP for {} examples".format(preds.shape[0])) 206 | 207 | preds = preds[:, ~(np.all(labels == 0, axis=0))] 208 | labels = labels[:, ~(np.all(labels == 0, axis=0))] 209 | aps = [0] 210 | try: 211 | aps = average_precision_score(labels, preds, average=None) 212 | except ValueError: 213 | print( 214 | "Average precision requires a sufficient number of samples \ 215 | in a batch which are missing in this sample." 216 | ) 217 | 218 | mean_ap = np.mean(aps) 219 | return mean_ap 220 | -------------------------------------------------------------------------------- /models/s3d.py: -------------------------------------------------------------------------------- 1 | # modified from https://raw.githubusercontent.com/qijiezhao/s3d.pytorch/master/S3DG_Pytorch.py 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | ## pytorch default: torch.nn.BatchNorm3d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 7 | ## tensorflow s3d code: torch.nn.BatchNorm3d(num_features, eps=1e-3, momentum=0.001, affine=True, track_running_stats=True) 8 | 9 | class BasicConv3d(nn.Module): 10 | def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0): 11 | super(BasicConv3d, self).__init__() 12 | self.conv = nn.Conv3d(in_planes, out_planes, 13 | kernel_size=kernel_size, stride=stride, 14 | padding=padding, bias=False) 15 | 16 | # self.bn = nn.BatchNorm3d(out_planes, eps=1e-3, momentum=0.001, affine=True) 17 | self.bn = nn.BatchNorm3d(out_planes) 18 | self.relu = nn.ReLU(inplace=True) 19 | 20 | # init 21 | self.conv.weight.data.normal_(mean=0, std=0.01) # original s3d is truncated normal within 2 std 22 | self.bn.weight.data.fill_(1) 23 | self.bn.bias.data.zero_() 24 | 25 | def forward(self, x): 26 | x = self.conv(x) 27 | x = self.bn(x) 28 | x = self.relu(x) 29 | return x 30 | 31 | 32 | class STConv3d(nn.Module): 33 | def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0): 34 | super(STConv3d, self).__init__() 35 | if isinstance(stride, tuple): 36 | t_stride = stride[0] 37 | stride = stride[-1] 38 | else: # int 39 | t_stride = stride 40 | 41 | self.conv1 = nn.Conv3d(in_planes, out_planes, kernel_size=(1, kernel_size, kernel_size), 42 | stride=(1, stride, stride), padding=(0, padding, padding), bias=False) 43 | self.conv2 = nn.Conv3d(out_planes, out_planes, kernel_size=(kernel_size, 1, 1), 44 | stride=(t_stride, 1, 1), padding=(padding, 0, 0), bias=False) 45 | 46 | # self.bn1=nn.BatchNorm3d(out_planes, eps=1e-3, momentum=0.001, affine=True) 47 | # self.bn2=nn.BatchNorm3d(out_planes, eps=1e-3, momentum=0.001, affine=True) 48 | self.bn1 = nn.BatchNorm3d(out_planes) 49 | self.bn2 = nn.BatchNorm3d(out_planes) 50 | self.relu = nn.ReLU(inplace=True) 51 | 52 | # init 53 | self.conv1.weight.data.normal_(mean=0, std=0.01) # original s3d is truncated normal within 2 std 54 | self.conv2.weight.data.normal_(mean=0, std=0.01) # original s3d is truncated normal within 2 std 55 | self.bn1.weight.data.fill_(1) 56 | self.bn1.bias.data.zero_() 57 | self.bn2.weight.data.fill_(1) 58 | self.bn2.bias.data.zero_() 59 | 60 | def forward(self, x): 61 | x = self.conv1(x) 62 | x = self.bn1(x) 63 | x = self.relu(x) 64 | x = self.conv2(x) 65 | x = self.bn2(x) 66 | x = self.relu(x) 67 | return x 68 | 69 | 70 | class SelfGating(nn.Module): 71 | def __init__(self, input_dim): 72 | super(SelfGating, self).__init__() 73 | self.fc = nn.Linear(input_dim, input_dim) 74 | 75 | def forward(self, input_tensor): 76 | """Feature gating as used in S3D-G""" 77 | 
spatiotemporal_average = torch.mean(input_tensor, dim=[2, 3, 4]) 78 | weights = self.fc(spatiotemporal_average) 79 | weights = torch.sigmoid(weights) 80 | return weights[:, :, None, None, None] * input_tensor 81 | 82 | 83 | class SepInception(nn.Module): 84 | def __init__(self, in_planes, out_planes, gating=False): 85 | super(SepInception, self).__init__() 86 | 87 | assert len(out_planes) == 6 88 | assert isinstance(out_planes, list) 89 | 90 | [num_out_0_0a, 91 | num_out_1_0a, num_out_1_0b, 92 | num_out_2_0a, num_out_2_0b, 93 | num_out_3_0b] = out_planes 94 | 95 | self.branch0 = nn.Sequential( 96 | BasicConv3d(in_planes, num_out_0_0a, kernel_size=1, stride=1), 97 | ) 98 | self.branch1 = nn.Sequential( 99 | BasicConv3d(in_planes, num_out_1_0a, kernel_size=1, stride=1), 100 | STConv3d(num_out_1_0a, num_out_1_0b, kernel_size=3, stride=1, padding=1), 101 | ) 102 | self.branch2 = nn.Sequential( 103 | BasicConv3d(in_planes, num_out_2_0a, kernel_size=1, stride=1), 104 | STConv3d(num_out_2_0a, num_out_2_0b, kernel_size=3, stride=1, padding=1), 105 | ) 106 | self.branch3 = nn.Sequential( 107 | nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1), 108 | BasicConv3d(in_planes, num_out_3_0b, kernel_size=1, stride=1), 109 | ) 110 | 111 | self.out_channels = sum([num_out_0_0a, num_out_1_0b, num_out_2_0b, num_out_3_0b]) 112 | 113 | self.gating = gating 114 | if gating: 115 | self.gating_b0 = SelfGating(num_out_0_0a) 116 | self.gating_b1 = SelfGating(num_out_1_0b) 117 | self.gating_b2 = SelfGating(num_out_2_0b) 118 | self.gating_b3 = SelfGating(num_out_3_0b) 119 | 120 | def forward(self, x): 121 | x0 = self.branch0(x) 122 | x1 = self.branch1(x) 123 | x2 = self.branch2(x) 124 | x3 = self.branch3(x) 125 | if self.gating: 126 | x0 = self.gating_b0(x0) 127 | x1 = self.gating_b1(x1) 128 | x2 = self.gating_b2(x2) 129 | x3 = self.gating_b3(x3) 130 | 131 | out = torch.cat((x0, x1, x2, x3), 1) 132 | 133 | return out 134 | 135 | 136 | class S3D(nn.Module): 137 | 138 | def __init__(self, input_channel=3, gating=False, slow=False): 139 | super(S3D, self).__init__() 140 | self.gating = gating 141 | self.slow = slow 142 | 143 | if slow: 144 | self.Conv_1a = STConv3d(input_channel, 64, kernel_size=7, stride=(1, 2, 2), padding=3) 145 | else: # normal 146 | self.Conv_1a = STConv3d(input_channel, 64, kernel_size=7, stride=2, padding=3) 147 | 148 | self.block1 = nn.Sequential(self.Conv_1a) # (64, 32, 112, 112) 149 | 150 | ################################### 151 | 152 | self.MaxPool_2a = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)) 153 | self.Conv_2b = BasicConv3d(64, 64, kernel_size=1, stride=1) 154 | self.Conv_2c = STConv3d(64, 192, kernel_size=3, stride=1, padding=1) 155 | 156 | self.block2 = nn.Sequential( 157 | self.MaxPool_2a, # (64, 32, 56, 56) 158 | self.Conv_2b, # (64, 32, 56, 56) 159 | self.Conv_2c) # (192, 32, 56, 56) 160 | 161 | ################################### 162 | 163 | self.MaxPool_3a = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)) 164 | self.Mixed_3b = SepInception(in_planes=192, out_planes=[64, 96, 128, 16, 32, 32], gating=gating) 165 | self.Mixed_3c = SepInception(in_planes=256, out_planes=[128, 128, 192, 32, 96, 64], gating=gating) 166 | 167 | self.block3 = nn.Sequential( 168 | self.MaxPool_3a, # (192, 32, 28, 28) 169 | self.Mixed_3b, # (256, 32, 28, 28) 170 | self.Mixed_3c) # (480, 32, 28, 28) 171 | 172 | ################################### 173 | 174 | self.MaxPool_4a = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1)) 
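        # Note (added for clarity): unlike MaxPool_2a/MaxPool_3a, which pool only
        # spatially with kernel (1, 3, 3), MaxPool_4a above (and MaxPool_5a below)
        # also downsamples time, which is why the temporal dimension drops from 32
        # to 16 and then to 8 in the shape annotations that follow.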
175 |         self.Mixed_4b = SepInception(in_planes=480, out_planes=[192, 96, 208, 16, 48, 64], gating=gating)
176 |         self.Mixed_4c = SepInception(in_planes=512, out_planes=[160, 112, 224, 24, 64, 64], gating=gating)
177 |         self.Mixed_4d = SepInception(in_planes=512, out_planes=[128, 128, 256, 24, 64, 64], gating=gating)
178 |         self.Mixed_4e = SepInception(in_planes=512, out_planes=[112, 144, 288, 32, 64, 64], gating=gating)
179 |         self.Mixed_4f = SepInception(in_planes=528, out_planes=[256, 160, 320, 32, 128, 128], gating=gating)
180 | 
181 |         self.block4 = nn.Sequential(
182 |             self.MaxPool_4a,  # (480, 16, 14, 14)
183 |             self.Mixed_4b,  # (512, 16, 14, 14)
184 |             self.Mixed_4c,  # (512, 16, 14, 14)
185 |             self.Mixed_4d,  # (512, 16, 14, 14)
186 |             self.Mixed_4e,  # (528, 16, 14, 14)
187 |             self.Mixed_4f)  # (832, 16, 14, 14)
188 | 
189 |         ###################################
190 | 
191 |         self.MaxPool_5a = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 0, 0))
192 |         self.Mixed_5b = SepInception(in_planes=832, out_planes=[256, 160, 320, 32, 128, 128], gating=gating)
193 |         self.Mixed_5c = SepInception(in_planes=832, out_planes=[384, 192, 384, 48, 128, 128], gating=gating)
194 | 
195 |         self.block5 = nn.Sequential(
196 |             self.MaxPool_5a,  # (832, 8, 7, 7)
197 |             self.Mixed_5b,  # (832, 8, 7, 7)
198 |             self.Mixed_5c)  # (1024, 8, 7, 7)
199 | 
200 |         ###################################
201 | 
202 |         # self.AvgPool_0a = nn.AvgPool3d(kernel_size=(2, 7, 7), stride=1)
203 |         # self.Dropout_0b = nn.Dropout3d(dropout_keep_prob)
204 |         # self.Conv_0c = nn.Conv3d(1024, num_classes, kernel_size=1, stride=1, bias=True)
205 | 
206 |         # self.classifier = nn.Sequential(
207 |         #     self.AvgPool_0a,
208 |         #     self.Dropout_0b,
209 |         #     self.Conv_0c)
210 | 
211 |     def forward(self, x):
212 |         x = self.block1(x)
213 |         x = self.block2(x)
214 |         x = self.block3(x)
215 |         x = self.block4(x)
216 |         x = self.block5(x)
217 |         return x
218 | 
219 | 
220 | if __name__ == '__main__':
221 |     model = S3D()  # headless backbone: the classifier head above is commented out, so S3D takes no num_classes argument
222 | 
--------------------------------------------------------------------------------
/models/helpers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
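# (Added note, hedged: this module adapts 2-D ViT checkpoints to the video model.
# load_pretrained() below interpolates the non-CLS part of `pos_embed` to the new
# number of patches and re-attaches the CLS embedding, resizes `time_embed` to the
# configured number of frames, and, for divided space-time attention, initialises
# the `temporal_attn` / `temporal_norm1` weights by copying the corresponding
# spatial `attn` / `norm1` weights when no temporal weights exist in the checkpoint.)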
2 | # Copyright 2020 Ross Wightman 3 | # Modified model creation / weight loading / state_dict helpers 4 | 5 | import logging 6 | import math 7 | import os 8 | from collections import OrderedDict 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | import torch.utils.model_zoo as model_zoo 13 | 14 | _logger = logging.getLogger(__name__) 15 | 16 | 17 | def load_state_dict(checkpoint_path, use_ema=False): 18 | if checkpoint_path and os.path.isfile(checkpoint_path): 19 | checkpoint = torch.load(checkpoint_path, map_location='cpu') 20 | elif checkpoint_path.startswith("https://"): 21 | checkpoint = torch.hub.load_state_dict_from_url(checkpoint_path, map_location='cpu') 22 | else: 23 | _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) 24 | raise FileNotFoundError() 25 | 26 | state_dict_key = 'state_dict' 27 | if isinstance(checkpoint, dict): 28 | if use_ema and 'state_dict_ema' in checkpoint: 29 | state_dict_key = 'state_dict_ema' 30 | if state_dict_key and state_dict_key in checkpoint: 31 | new_state_dict = OrderedDict() 32 | for k, v in checkpoint[state_dict_key].items(): 33 | # strip `module.` prefix 34 | name = k[7:] if k.startswith('module') else k 35 | new_state_dict[name] = v 36 | state_dict = new_state_dict 37 | elif 'model_state' in checkpoint: 38 | state_dict_key = 'model_state' 39 | new_state_dict = OrderedDict() 40 | for k, v in checkpoint[state_dict_key].items(): 41 | # strip `model.` prefix 42 | name = k[6:] if k.startswith('model') else k 43 | new_state_dict[name] = v 44 | state_dict = new_state_dict 45 | else: 46 | state_dict = checkpoint 47 | _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path)) 48 | return state_dict 49 | 50 | 51 | def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True): 52 | state_dict = load_state_dict(checkpoint_path, use_ema) 53 | model.load_state_dict(state_dict, strict=strict) 54 | 55 | 56 | def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True): 57 | resume_epoch = None 58 | if os.path.isfile(checkpoint_path): 59 | checkpoint = torch.load(checkpoint_path, map_location='cpu') 60 | if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: 61 | if log_info: 62 | _logger.info('Restoring model state from checkpoint...') 63 | new_state_dict = OrderedDict() 64 | for k, v in checkpoint['state_dict'].items(): 65 | name = k[7:] if k.startswith('module') else k 66 | new_state_dict[name] = v 67 | model.load_state_dict(new_state_dict) 68 | 69 | if optimizer is not None and 'optimizer' in checkpoint: 70 | if log_info: 71 | _logger.info('Restoring optimizer state from checkpoint...') 72 | optimizer.load_state_dict(checkpoint['optimizer']) 73 | 74 | if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint: 75 | if log_info: 76 | _logger.info('Restoring AMP loss scaler state from checkpoint...') 77 | loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key]) 78 | 79 | if 'epoch' in checkpoint: 80 | resume_epoch = checkpoint['epoch'] 81 | if 'version' in checkpoint and checkpoint['version'] > 1: 82 | resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save 83 | 84 | if log_info: 85 | _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) 86 | else: 87 | model.load_state_dict(checkpoint) 88 | if log_info: 89 | _logger.info("Loaded checkpoint '{}'".format(checkpoint_path)) 90 | return resume_epoch 91 | else: 92 | _logger.error("No checkpoint found at 
'{}'".format(checkpoint_path)) 93 | raise FileNotFoundError() 94 | 95 | 96 | def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, img_size=224, num_frames=8, 97 | num_patches=196, attention_type='divided_space_time', pretrained_model="", strict=True): 98 | if cfg is None: 99 | cfg = getattr(model, 'default_cfg') 100 | if cfg is None or 'url' not in cfg or not cfg['url']: 101 | _logger.warning("Pretrained model URL is invalid, using random initialization.") 102 | return 103 | 104 | if len(pretrained_model) == 0: 105 | state_dict = model_zoo.load_url(cfg['url'], progress=False, map_location='cpu') 106 | else: 107 | try: 108 | state_dict = load_state_dict(pretrained_model)['model'] 109 | except: 110 | state_dict = load_state_dict(pretrained_model) 111 | 112 | if filter_fn is not None: 113 | state_dict = filter_fn(state_dict) 114 | 115 | if in_chans == 1: 116 | conv1_name = cfg['first_conv'] 117 | _logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel' % conv1_name) 118 | conv1_weight = state_dict[conv1_name + '.weight'] 119 | conv1_type = conv1_weight.dtype 120 | conv1_weight = conv1_weight.float() 121 | O, I, J, K = conv1_weight.shape 122 | if I > 3: 123 | assert conv1_weight.shape[1] % 3 == 0 124 | # For models with space2depth stems 125 | conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K) 126 | conv1_weight = conv1_weight.sum(dim=2, keepdim=False) 127 | else: 128 | conv1_weight = conv1_weight.sum(dim=1, keepdim=True) 129 | conv1_weight = conv1_weight.to(conv1_type) 130 | state_dict[conv1_name + '.weight'] = conv1_weight 131 | elif in_chans != 3: 132 | conv1_name = cfg['first_conv'] 133 | conv1_weight = state_dict[conv1_name + '.weight'] 134 | conv1_type = conv1_weight.dtype 135 | conv1_weight = conv1_weight.float() 136 | O, I, J, K = conv1_weight.shape 137 | if I != 3: 138 | _logger.warning('Deleting first conv (%s) from pretrained weights.' % conv1_name) 139 | del state_dict[conv1_name + '.weight'] 140 | strict = False 141 | else: 142 | _logger.info('Repeating first conv (%s) weights in channel dim.' 
% conv1_name) 143 | repeat = int(math.ceil(in_chans / 3)) 144 | conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :] 145 | conv1_weight *= (3 / float(in_chans)) 146 | conv1_weight = conv1_weight.to(conv1_type) 147 | state_dict[conv1_name + '.weight'] = conv1_weight 148 | 149 | classifier_name = cfg['classifier'] 150 | if num_classes == 1000 and cfg['num_classes'] == 1001: 151 | # special case for imagenet trained models with extra background class in pretrained weights 152 | classifier_weight = state_dict[classifier_name + '.weight'] 153 | state_dict[classifier_name + '.weight'] = classifier_weight[1:] 154 | classifier_bias = state_dict[classifier_name + '.bias'] 155 | state_dict[classifier_name + '.bias'] = classifier_bias[1:] 156 | elif not (classifier_name + '.weight') in state_dict: 157 | pass 158 | elif num_classes != state_dict[classifier_name + '.weight'].size(0): 159 | # print('Removing the last fully connected layer due to dimensions mismatch ('+str(num_classes)+ ' != '+str(state_dict[classifier_name + '.weight'].size(0))+').', flush=True) 160 | # completely discard fully connected for all other differences between pretrained and created model 161 | del state_dict[classifier_name + '.weight'] 162 | del state_dict[classifier_name + '.bias'] 163 | strict = False 164 | 165 | ## Resizing the positional embeddings in case they don't match 166 | if num_patches + 1 != state_dict['pos_embed'].size(1): 167 | pos_embed = state_dict['pos_embed'] 168 | cls_pos_embed = pos_embed[0, 0, :].unsqueeze(0).unsqueeze(1) 169 | other_pos_embed = pos_embed[0, 1:, :].unsqueeze(0).transpose(1, 2) 170 | new_pos_embed = F.interpolate(other_pos_embed, size=(num_patches), mode='nearest') 171 | new_pos_embed = new_pos_embed.transpose(1, 2) 172 | new_pos_embed = torch.cat((cls_pos_embed, new_pos_embed), 1) 173 | state_dict['pos_embed'] = new_pos_embed 174 | 175 | ## Resizing time embeddings in case they don't match 176 | if 'time_embed' in state_dict and num_frames != state_dict['time_embed'].size(1): 177 | time_embed = state_dict['time_embed'].transpose(1, 2) 178 | new_time_embed = F.interpolate(time_embed, size=(num_frames), mode='nearest') 179 | state_dict['time_embed'] = new_time_embed.transpose(1, 2) 180 | 181 | ## Initializing temporal attention 182 | if attention_type == 'divided_space_time': 183 | new_state_dict = state_dict.copy() 184 | for key in state_dict: 185 | if 'blocks' in key and 'attn' in key: 186 | new_key = key.replace('attn', 'temporal_attn') 187 | if not new_key in state_dict: 188 | new_state_dict[new_key] = state_dict[key] 189 | else: 190 | new_state_dict[new_key] = state_dict[new_key] 191 | if 'blocks' in key and 'norm1' in key: 192 | new_key = key.replace('norm1', 'temporal_norm1') 193 | if not new_key in state_dict: 194 | new_state_dict[new_key] = state_dict[key] 195 | else: 196 | new_state_dict[new_key] = state_dict[new_key] 197 | state_dict = new_state_dict 198 | 199 | ## Loading the weights 200 | msg = model.load_state_dict(state_dict, strict=False) 201 | print(f"Loaded model inside with msg: {msg}") 202 | -------------------------------------------------------------------------------- /datasets/ssv2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
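# (Added overview, hedged: this loader reads the official SSv2 annotation json
# files, maps each video id to a class index, and samples cfg.DATA.NUM_FRAMES
# frames per clip with the segment-based scheme in __getitem__ below. As an
# illustration with made-up numbers: for a 100-frame video and num_frames=8,
# seg_size = 99 / 8, roughly 12.4; training draws one random index inside each of
# the 8 segments, while val/test take the segment mid-points, approximately
# [6, 18, 31, 43, 56, 68, 80, 93].)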
2 | 3 | import json 4 | import numpy as np 5 | import os 6 | import random 7 | from itertools import chain as chain 8 | import torch 9 | import torch.utils.data 10 | from fvcore.common.file_io import PathManager 11 | 12 | import timesformer.utils.logging as logging 13 | 14 | from . import data_utils as utils 15 | from .build import DATASET_REGISTRY 16 | 17 | logger = logging.get_logger(__name__) 18 | 19 | 20 | @DATASET_REGISTRY.register() 21 | class Ssv2(torch.utils.data.Dataset): 22 | """ 23 | Something-Something v2 (SSV2) video loader. Construct the SSV2 video loader, 24 | then sample clips from the videos. For training and validation, a single 25 | clip is randomly sampled from every video with random cropping, scaling, and 26 | flipping. For testing, multiple clips are uniformaly sampled from every 27 | video with uniform cropping. For uniform cropping, we take the left, center, 28 | and right crop if the width is larger than height, or take top, center, and 29 | bottom crop if the height is larger than the width. 30 | """ 31 | 32 | def __init__(self, cfg, mode, num_retries=10): 33 | """ 34 | Load Something-Something V2 data (frame paths, labels, etc. ) to a given 35 | Dataset object. The dataset could be downloaded from Something-Something 36 | official website (https://20bn.com/datasets/something-something). 37 | Please see datasets/DATASET.md for more information about the data format. 38 | Args: 39 | cfg (CfgNode): configs. 40 | mode (string): Options includes `train`, `val`, or `test` mode. 41 | For the train and val mode, the data loader will take data 42 | from the train or val set, and sample one clip per video. 43 | For the test mode, the data loader will take data from test set, 44 | and sample multiple clips per video. 45 | num_retries (int): number of retries for reading frames from disk. 46 | """ 47 | # Only support train, val, and test mode. 48 | assert mode in [ 49 | "train", 50 | "val", 51 | "test", 52 | ], "Split '{}' not supported for Something-Something V2".format(mode) 53 | self.mode = mode 54 | self.cfg = cfg 55 | 56 | self._video_meta = {} 57 | self._num_retries = num_retries 58 | # For training or validation mode, one single clip is sampled from every 59 | # video. For testing, NUM_ENSEMBLE_VIEWS clips are sampled from every 60 | # video. For every clip, NUM_SPATIAL_CROPS is cropped spatially from 61 | # the frames. 62 | if self.mode in ["train", "val"]: 63 | self._num_clips = 1 64 | elif self.mode in ["test"]: 65 | self._num_clips = ( 66 | cfg.TEST.NUM_ENSEMBLE_VIEWS * cfg.TEST.NUM_SPATIAL_CROPS 67 | ) 68 | 69 | logger.info("Constructing Something-Something V2 {}...".format(mode)) 70 | self._construct_loader() 71 | 72 | def _construct_loader(self): 73 | """ 74 | Construct the video loader. 75 | """ 76 | # Loading label names. 77 | with PathManager.open( 78 | os.path.join( 79 | self.cfg.DATA.PATH_TO_DATA_DIR, 80 | "something-something-v2-labels.json", 81 | ), 82 | "r", 83 | ) as f: 84 | label_dict = json.load(f) 85 | 86 | # Loading labels. 
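        # (Added note, hedged: each entry in something-something-v2-{train,validation}.json
        # is expected to look roughly like
        # {"id": "74225", "template": "Pushing [something] from left to right", ...};
        # the bracketed placeholders are stripped below, and the cleaned template is
        # looked up in something-something-v2-labels.json, which maps label strings
        # to class-id strings.)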
87 | label_file = os.path.join( 88 | self.cfg.DATA.PATH_TO_DATA_DIR, 89 | "something-something-v2-{}.json".format( 90 | "train" if self.mode == "train" else "validation" 91 | ), 92 | ) 93 | with PathManager.open(label_file, "r") as f: 94 | label_json = json.load(f) 95 | 96 | self._video_names = [] 97 | self._labels = [] 98 | for video in label_json: 99 | video_name = video["id"] 100 | template = video["template"] 101 | template = template.replace("[", "") 102 | template = template.replace("]", "") 103 | label = int(label_dict[template]) 104 | self._video_names.append(video_name) 105 | self._labels.append(label) 106 | 107 | path_to_file = os.path.join( 108 | self.cfg.DATA.PATH_TO_DATA_DIR, 109 | "{}.csv".format("train" if self.mode == "train" else "val"), 110 | ) 111 | assert PathManager.exists(path_to_file), "{} dir not found".format( 112 | path_to_file 113 | ) 114 | 115 | self._path_to_videos, _ = utils.load_image_lists( 116 | path_to_file, self.cfg.DATA.PATH_PREFIX 117 | ) 118 | 119 | assert len(self._path_to_videos) == len(self._video_names), ( 120 | len(self._path_to_videos), 121 | len(self._video_names), 122 | ) 123 | 124 | 125 | # From dict to list. 126 | new_paths, new_labels = [], [] 127 | for index in range(len(self._video_names)): 128 | if self._video_names[index] in self._path_to_videos: 129 | new_paths.append(self._path_to_videos[self._video_names[index]]) 130 | new_labels.append(self._labels[index]) 131 | 132 | self._labels = new_labels 133 | self._path_to_videos = new_paths 134 | 135 | # Extend self when self._num_clips > 1 (during testing). 136 | self._path_to_videos = list( 137 | chain.from_iterable( 138 | [[x] * self._num_clips for x in self._path_to_videos] 139 | ) 140 | ) 141 | self._labels = list( 142 | chain.from_iterable([[x] * self._num_clips for x in self._labels]) 143 | ) 144 | self._spatial_temporal_idx = list( 145 | chain.from_iterable( 146 | [ 147 | range(self._num_clips) 148 | for _ in range(len(self._path_to_videos)) 149 | ] 150 | ) 151 | ) 152 | logger.info( 153 | "Something-Something V2 dataloader constructed " 154 | " (size: {}) from {}".format( 155 | len(self._path_to_videos), path_to_file 156 | ) 157 | ) 158 | 159 | def __getitem__(self, index): 160 | """ 161 | Given the video index, return the list of frames, label, and video 162 | index if the video frames can be fetched. 163 | Args: 164 | index (int): the video index provided by the pytorch sampler. 165 | Returns: 166 | frames (tensor): the frames of sampled from the video. The dimension 167 | is `channel` x `num frames` x `height` x `width`. 168 | label (int): the label of the current video. 169 | index (int): the index of the video. 170 | """ 171 | short_cycle_idx = None 172 | # When short cycle is used, input index is a tupple. 173 | if isinstance(index, tuple): 174 | index, short_cycle_idx = index 175 | 176 | if self.mode in ["train", "val"]: #or self.cfg.MODEL.ARCH in ['resformer', 'vit']: 177 | # -1 indicates random sampling. 178 | spatial_sample_index = -1 179 | min_scale = self.cfg.DATA.TRAIN_JITTER_SCALES[0] 180 | max_scale = self.cfg.DATA.TRAIN_JITTER_SCALES[1] 181 | crop_size = self.cfg.DATA.TRAIN_CROP_SIZE 182 | if short_cycle_idx in [0, 1]: 183 | crop_size = int( 184 | round( 185 | self.cfg.MULTIGRID.SHORT_CYCLE_FACTORS[short_cycle_idx] 186 | * self.cfg.MULTIGRID.DEFAULT_S 187 | ) 188 | ) 189 | if self.cfg.MULTIGRID.DEFAULT_S > 0: 190 | # Decreasing the scale is equivalent to using a larger "span" 191 | # in a sampling grid. 
192 | min_scale = int( 193 | round( 194 | float(min_scale) 195 | * crop_size 196 | / self.cfg.MULTIGRID.DEFAULT_S 197 | ) 198 | ) 199 | elif self.mode in ["test"]: 200 | # spatial_sample_index is in [0, 1, 2]. Corresponding to left, 201 | # center, or right if width is larger than height, and top, middle, 202 | # or bottom if height is larger than width. 203 | spatial_sample_index = ( 204 | self._spatial_temporal_idx[index] 205 | % self.cfg.TEST.NUM_SPATIAL_CROPS 206 | ) 207 | if self.cfg.TEST.NUM_SPATIAL_CROPS == 1: 208 | spatial_sample_index = 1 209 | 210 | min_scale, max_scale, crop_size = [self.cfg.DATA.TEST_CROP_SIZE] * 3 211 | # The testing is deterministic and no jitter should be performed. 212 | # min_scale, max_scale, and crop_size are expect to be the same. 213 | assert len({min_scale, max_scale, crop_size}) == 1 214 | else: 215 | raise NotImplementedError( 216 | "Does not support {} mode".format(self.mode) 217 | ) 218 | 219 | label = self._labels[index] 220 | 221 | num_frames = self.cfg.DATA.NUM_FRAMES 222 | video_length = len(self._path_to_videos[index]) 223 | 224 | 225 | seg_size = float(video_length - 1) / num_frames 226 | seq = [] 227 | for i in range(num_frames): 228 | start = int(np.round(seg_size * i)) 229 | end = int(np.round(seg_size * (i + 1))) 230 | if self.mode == "train": 231 | seq.append(random.randint(start, end)) 232 | else: 233 | seq.append((start + end) // 2) 234 | 235 | frames = torch.as_tensor( 236 | utils.retry_load_images( 237 | [self._path_to_videos[index][frame] for frame in seq], 238 | self._num_retries, 239 | ) 240 | ) 241 | 242 | # Perform color normalization. 243 | frames = utils.tensor_normalize( 244 | frames, self.cfg.DATA.MEAN, self.cfg.DATA.STD 245 | ) 246 | 247 | # T H W C -> C T H W. 248 | frames = frames.permute(3, 0, 1, 2) 249 | frames = utils.spatial_sampling( 250 | frames, 251 | spatial_idx=spatial_sample_index, 252 | min_scale=min_scale, 253 | max_scale=max_scale, 254 | crop_size=crop_size, 255 | random_horizontal_flip=self.cfg.DATA.RANDOM_FLIP, 256 | inverse_uniform_sampling=self.cfg.DATA.INV_UNIFORM_SAMPLE, 257 | ) 258 | #if not self.cfg.RESFORMER.ACTIVE: 259 | if not self.cfg.MODEL.ARCH in ['vit']: 260 | frames = utils.pack_pathway_output(self.cfg, frames) 261 | else: 262 | # Perform temporal sampling from the fast pathway. 263 | frames = torch.index_select( 264 | frames, 265 | 1, 266 | torch.linspace( 267 | 0, frames.shape[1] - 1, self.cfg.DATA.NUM_FRAMES 268 | 269 | ).long(), 270 | ) 271 | return frames, label, index, {} 272 | 273 | def __len__(self): 274 | """ 275 | Returns: 276 | (int): the number of videos in the dataset. 277 | """ 278 | return len(self._path_to_videos) 279 | -------------------------------------------------------------------------------- /eval_knn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
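# (Added overview, a hedged sketch of the evaluation flow implemented below:
# features for the train and val splits are extracted with the frozen backbone,
# L2-normalised, and classified with a temperature-weighted k-NN vote. The core
# of knn_classifier(), in toy form with illustrative names:
#
#     similarity = test_feats @ train_feats.T         # cosine similarity (features are L2-normalised)
#     distances, indices = similarity.topk(k, dim=1)  # k nearest training clips
#     weights = (distances / T).exp()                 # temperature-scaled vote weights, T = 0.07 by default
#
# each neighbour then votes for its label with its weight, and the class with the
# largest summed weight gives the top-1 prediction.)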
14 | 15 | import argparse 16 | import os 17 | import torch 18 | import torch.backends.cudnn as cudnn 19 | import torch.distributed as dist 20 | import torch.utils.data 21 | from torch import nn 22 | 23 | from datasets.hmdb51 import HMDB51 24 | from datasets.ucf101 import UCF101 25 | from models import get_vit_base_patch16_224 26 | from utils import utils 27 | from utils.parser import load_config 28 | 29 | 30 | def extract_feature_pipeline(args): 31 | # ============ preparing data ... ============ 32 | config = load_config(args) 33 | # config.DATA.PATH_TO_DATA_DIR = f"{os.path.expanduser('~')}/repo/mmaction2/data/{args.dataset}/knn_splits" 34 | # config.DATA.PATH_PREFIX = f"{os.path.expanduser('~')}/repo/mmaction2/data/{args.dataset}/videos" 35 | config.TEST.NUM_SPATIAL_CROPS = 1 36 | if args.dataset == "ucf101": 37 | dataset_train = UCFReturnIndexDataset(cfg=config, mode="train", num_retries=10) 38 | dataset_val = UCFReturnIndexDataset(cfg=config, mode="val", num_retries=10) 39 | elif args.dataset == "hmdb51": 40 | dataset_train = HMDBReturnIndexDataset(cfg=config, mode="train", num_retries=10) 41 | dataset_val = HMDBReturnIndexDataset(cfg=config, mode="val", num_retries=10) 42 | else: 43 | raise NotImplementedError(f"invalid dataset: {args.dataset}") 44 | 45 | sampler = torch.utils.data.DistributedSampler(dataset_train, shuffle=False) 46 | data_loader_train = torch.utils.data.DataLoader( 47 | dataset_train, 48 | sampler=sampler, 49 | batch_size=args.batch_size_per_gpu, 50 | num_workers=args.num_workers, 51 | pin_memory=True, 52 | drop_last=False, 53 | ) 54 | data_loader_val = torch.utils.data.DataLoader( 55 | dataset_val, 56 | batch_size=args.batch_size_per_gpu, 57 | num_workers=args.num_workers, 58 | pin_memory=True, 59 | drop_last=False, 60 | ) 61 | print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.") 62 | 63 | # ============ building network ... ============ 64 | model = get_vit_base_patch16_224(cfg=config, no_head=True) 65 | ckpt = torch.load(args.pretrained_weights) 66 | # select_ckpt = "teacher" 67 | renamed_checkpoint = {x[len("backbone."):]: y for x, y in ckpt.items() if x.startswith("backbone.")} 68 | msg = model.load_state_dict(renamed_checkpoint, strict=False) 69 | print(f"Loaded model with msg: {msg}") 70 | model.cuda() 71 | model.eval() 72 | 73 | # ============ extract features ... 
============ 74 | print("Extracting features for train set...") 75 | train_features = extract_features(model, data_loader_train) 76 | print("Extracting features for val set...") 77 | test_features = extract_features(model, data_loader_val) 78 | 79 | if utils.get_rank() == 0: 80 | train_features = nn.functional.normalize(train_features, dim=1, p=2) 81 | test_features = nn.functional.normalize(test_features, dim=1, p=2) 82 | 83 | train_labels = torch.tensor([s for s in dataset_train._labels]).long() 84 | test_labels = torch.tensor([s for s in dataset_val._labels]).long() 85 | # save features and labels 86 | if args.dump_features and dist.get_rank() == 0: 87 | torch.save(train_features.cpu(), os.path.join(args.dump_features, "trainfeat.pth")) 88 | torch.save(test_features.cpu(), os.path.join(args.dump_features, "testfeat.pth")) 89 | torch.save(train_labels.cpu(), os.path.join(args.dump_features, "trainlabels.pth")) 90 | torch.save(test_labels.cpu(), os.path.join(args.dump_features, "testlabels.pth")) 91 | return train_features, test_features, train_labels, test_labels 92 | 93 | 94 | @torch.no_grad() 95 | def extract_features(model, data_loader): 96 | metric_logger = utils.MetricLogger(delimiter=" ") 97 | features = None 98 | for samples, index in metric_logger.log_every(data_loader, 10): 99 | samples = samples.cuda(non_blocking=True) 100 | index = index.cuda(non_blocking=True) 101 | feats = model(samples).clone() 102 | 103 | # init storage feature matrix 104 | if dist.get_rank() == 0 and features is None: 105 | features = torch.zeros(len(data_loader.dataset), feats.shape[-1]) 106 | # if args.use_cuda: 107 | features = features.cuda(non_blocking=True) 108 | print(f"Storing features into tensor of shape {features.shape}") 109 | 110 | # get indexes from all processes 111 | y_all = torch.empty(dist.get_world_size(), index.size(0), dtype=index.dtype, device=index.device) 112 | y_l = list(y_all.unbind(0)) 113 | y_all_reduce = torch.distributed.all_gather(y_l, index, async_op=True) 114 | y_all_reduce.wait() 115 | index_all = torch.cat(y_l) 116 | 117 | # share features between processes 118 | feats_all = torch.empty( 119 | dist.get_world_size(), 120 | feats.size(0), 121 | feats.size(1), 122 | dtype=feats.dtype, 123 | device=feats.device, 124 | ) 125 | output_l = list(feats_all.unbind(0)) 126 | output_all_reduce = torch.distributed.all_gather(output_l, feats, async_op=True) 127 | output_all_reduce.wait() 128 | 129 | # update storage feature matrix 130 | if dist.get_rank() == 0: 131 | # if args.use_cuda: 132 | features.index_copy_(0, index_all, torch.cat(output_l)) 133 | # else: 134 | # features.index_copy_(0, index_all.cpu(), torch.cat(output_l).cpu()) 135 | return features 136 | 137 | 138 | @torch.no_grad() 139 | def knn_classifier(train_features, train_labels, test_features, test_labels, k, T, num_classes=1000): 140 | top1, top5, total = 0.0, 0.0, 0 141 | train_features = train_features.t() 142 | num_test_images, num_chunks = test_labels.shape[0], 100 143 | imgs_per_chunk = num_test_images // num_chunks 144 | retrieval_one_hot = torch.zeros(k, num_classes).cuda() 145 | for idx in range(0, num_test_images, imgs_per_chunk): 146 | # get the features for test images 147 | features = test_features[ 148 | idx : min((idx + imgs_per_chunk), num_test_images), : 149 | ] 150 | targets = test_labels[idx : min((idx + imgs_per_chunk), num_test_images)] 151 | batch_size = targets.shape[0] 152 | 153 | # calculate the dot product and compute top-k neighbors 154 | similarity = torch.mm(features, train_features) 155 | 
distances, indices = similarity.topk(k, largest=True, sorted=True) 156 | candidates = train_labels.view(1, -1).expand(batch_size, -1) 157 | retrieved_neighbors = torch.gather(candidates, 1, indices) 158 | 159 | retrieval_one_hot.resize_(batch_size * k, num_classes).zero_() 160 | retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1) 161 | distances_transform = distances.clone().div_(T).exp_() 162 | probs = torch.sum( 163 | torch.mul( 164 | retrieval_one_hot.view(batch_size, -1, num_classes), 165 | distances_transform.view(batch_size, -1, 1), 166 | ), 167 | 1, 168 | ) 169 | _, predictions = probs.sort(1, True) 170 | 171 | # find the predictions that match the target 172 | correct = predictions.eq(targets.data.view(-1, 1)) 173 | top1 = top1 + correct.narrow(1, 0, 1).sum().item() 174 | top5 = top5 + correct.narrow(1, 0, min(5, k)).sum().item() # guard against k < 5 175 | total += targets.size(0) 176 | top1 = top1 * 100.0 / total 177 | top5 = top5 * 100.0 / total 178 | return top1, top5 179 | 180 | 181 | class UCFReturnIndexDataset(UCF101): 182 | def __getitem__(self, idx): 183 | img, _, _, _ = super(UCFReturnIndexDataset, self).__getitem__(idx) 184 | return img, idx 185 | 186 | 187 | class HMDBReturnIndexDataset(HMDB51): 188 | def __getitem__(self, idx): 189 | img, _, _, _ = super(HMDBReturnIndexDataset, self).__getitem__(idx) 190 | return img, idx 191 | 192 | if __name__ == '__main__': 193 | parser = argparse.ArgumentParser('Evaluation with weighted k-NN on video datasets') 194 | parser.add_argument('--batch_size_per_gpu', default=128, type=int, help='Per-GPU batch-size') 195 | parser.add_argument('--nb_knn', default=[10, 20, 100, 200], nargs='+', type=int, 196 | help='Number of nearest neighbors to use. 20 usually works best.') 197 | parser.add_argument('--temperature', default=0.07, type=float, 198 | help='Temperature used in the voting coefficient') 199 | parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.") 200 | parser.add_argument('--use_cuda', default=True, type=utils.bool_flag, 201 | help="Should we store the features on GPU? 
We recommend setting this to False if you encounter OOM") 202 | parser.add_argument('--arch', default='vit_small', type=str, 203 | choices=['vit_tiny', 'vit_small', 'vit_base', 'timesformer'], help='Architecture (support only ViT atm).') 204 | parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.') 205 | parser.add_argument("--checkpoint_key", default="teacher", type=str, 206 | help='Key to use in the checkpoint (example: "teacher")') 207 | parser.add_argument('--dump_features', default=None, 208 | help='Path where to save computed features, empty for no saving') 209 | parser.add_argument('--load_features', default=None, help="""If the features have 210 | already been computed, where to find them.""") 211 | parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.') 212 | parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up 213 | distributed training; see https://pytorch.org/docs/stable/distributed.html""") 214 | parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.") 215 | parser.add_argument('--data_path', default='/path/to/imagenet/', type=str) 216 | 217 | parser.add_argument('--dataset', default="ucf101", help='Dataset: ucf101 / hmdb51') 218 | parser.add_argument("--cfg", dest="cfg_file", help="Path to the config file", type=str, 219 | default="models/configs/Kinetics/TimeSformer_divST_8x32_224.yaml") 220 | parser.add_argument("--opts", help="See utils/defaults.py for all options", default=None, nargs=argparse.REMAINDER) 221 | 222 | args = parser.parse_args() 223 | 224 | utils.init_distributed_mode(args) 225 | print("git:\n {}\n".format(utils.get_sha())) 226 | print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) 227 | cudnn.benchmark = True 228 | 229 | if args.load_features: 230 | train_features = torch.load(os.path.join(args.load_features, "trainfeat.pth")) 231 | test_features = torch.load(os.path.join(args.load_features, "testfeat.pth")) 232 | train_labels = torch.load(os.path.join(args.load_features, "trainlabels.pth")) 233 | test_labels = torch.load(os.path.join(args.load_features, "testlabels.pth")) 234 | else: 235 | # need to extract features ! 
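        # extract_feature_pipeline (defined above) builds the UCF101/HMDB51 return-index
        # datasets, runs the frozen backbone over every train/val clip, all-gathers the
        # features across ranks, L2-normalizes them on rank 0, and optionally dumps them
        # to --dump_features so a later run can skip extraction via --load_features.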
236 | train_features, test_features, train_labels, test_labels = extract_feature_pipeline(args) 237 | 238 | if utils.get_rank() == 0: 239 | if args.use_cuda: 240 | train_features = train_features.cuda() 241 | test_features = test_features.cuda() 242 | train_labels = train_labels.cuda() 243 | test_labels = test_labels.cuda() 244 | 245 | print("Features are ready!\nStart the k-NN classification.") 246 | for k in args.nb_knn: 247 | top1, top5 = knn_classifier(train_features, train_labels, 248 | test_features, test_labels, k, args.temperature) 249 | print(f"{k}-NN classifier result: Top1: {top1}, Top5: {top5}") 250 | dist.barrier() 251 | -------------------------------------------------------------------------------- /datasets/ucf101.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import warnings 4 | 5 | import torch 6 | import torch.utils.data 7 | 8 | from datasets.data_utils import get_random_sampling_rate, tensor_normalize, spatial_sampling, pack_pathway_output 9 | from datasets.decoder import decode 10 | from datasets.video_container import get_video_container 11 | from datasets.transform import VideoDataAugmentationDINO 12 | from einops import rearrange 13 | 14 | 15 | class UCF101(torch.utils.data.Dataset): 16 | """ 17 | UCF101 video loader. Construct the UCF101 video loader, then sample 18 | clips from the videos. For training and validation, a single clip is 19 | randomly sampled from every video with random cropping, scaling, and 20 | flipping. For testing, multiple clips are uniformaly sampled from every 21 | video with uniform cropping. For uniform cropping, we take the left, center, 22 | and right crop if the width is larger than height, or take top, center, and 23 | bottom crop if the height is larger than the width. 24 | """ 25 | 26 | def __init__(self, cfg, mode, num_retries=10): 27 | """ 28 | Construct the UCF101 video loader with a given csv file. The format of 29 | the csv file is: 30 | ``` 31 | path_to_video_1 label_1 32 | path_to_video_2 label_2 33 | ... 34 | path_to_video_N label_N 35 | ``` 36 | Args: 37 | cfg (CfgNode): configs. 38 | mode (string): Options includes `train`, `val`, or `test` mode. 39 | For the train mode, the data loader will take data from the 40 | train set, and sample one clip per video. For the val and 41 | test mode, the data loader will take data from relevent set, 42 | and sample multiple clips per video. 43 | num_retries (int): number of retries. 44 | """ 45 | # Only support train, val, and test mode. 46 | assert mode in ["train", "val", "test"], "Split '{}' not supported for UCF101".format(mode) 47 | self.mode = mode 48 | self.cfg = cfg 49 | 50 | self._video_meta = {} 51 | self._num_retries = num_retries 52 | self._split_idx = mode 53 | # For training mode, one single clip is sampled from every video. For validation or testing, NUM_ENSEMBLE_VIEWS 54 | # clips are sampled from every video. For every clip, NUM_SPATIAL_CROPS is cropped spatially from the frames. 55 | if self.mode in ["train"]: 56 | self._num_clips = 1 57 | elif self.mode in ["val", "test"]: 58 | self._num_clips = ( 59 | cfg.TEST.NUM_ENSEMBLE_VIEWS * cfg.TEST.NUM_SPATIAL_CROPS 60 | ) 61 | 62 | print("Constructing UCF101 {}...".format(mode)) 63 | self._construct_loader() 64 | 65 | def _construct_loader(self): 66 | """ 67 | Construct the video loader. 
68 | """ 69 | path_to_file = os.path.join( 70 | self.cfg.DATA.PATH_TO_DATA_DIR, "ucf101_{}_split_1_videos.txt".format(self.mode) 71 | ) 72 | assert os.path.exists(path_to_file), "{} dir not found".format( 73 | path_to_file 74 | ) 75 | 76 | self._path_to_videos = [] 77 | self._labels = [] 78 | self._spatial_temporal_idx = [] 79 | with open(path_to_file, "r") as f: 80 | for clip_idx, path_label in enumerate(f.read().splitlines()): 81 | assert ( 82 | len(path_label.split(self.cfg.DATA.PATH_LABEL_SEPARATOR)) 83 | == 2 84 | ) 85 | path, label = path_label.split( 86 | self.cfg.DATA.PATH_LABEL_SEPARATOR 87 | ) 88 | for idx in range(self._num_clips): 89 | self._path_to_videos.append( 90 | os.path.join(self.cfg.DATA.PATH_PREFIX, path) 91 | ) 92 | self._labels.append(int(label)) 93 | self._spatial_temporal_idx.append(idx) 94 | self._video_meta[clip_idx * self._num_clips + idx] = {} 95 | assert (len(self._path_to_videos) > 0), f"Failed to load UCF101 split {self._split_idx} from {path_to_file}" 96 | print(f"Constructing UCF101 dataloader (size: {len(self._path_to_videos)}) from {path_to_file}") 97 | 98 | def __getitem__(self, index): 99 | """ 100 | Given the video index, return the list of frames, label, and video 101 | index if the video can be fetched and decoded successfully, otherwise 102 | repeatly find a random video that can be decoded as a replacement. 103 | Args: 104 | index (int): the video index provided by the pytorch sampler. 105 | Returns: 106 | frames (tensor): the frames of sampled from the video. The dimension 107 | is `channel` x `num frames` x `height` x `width`. 108 | label (int): the label of the current video. 109 | index (int): if the video provided by pytorch sampler can be 110 | decoded, then return the index of the video. If not, return the 111 | index of the video replacement that can be decoded. 112 | """ 113 | short_cycle_idx = None 114 | # When short cycle is used, input index is a tupple. 115 | if isinstance(index, tuple): 116 | index, short_cycle_idx = index 117 | 118 | if self.mode in ["train"]: 119 | # -1 indicates random sampling. 120 | temporal_sample_index = -1 121 | spatial_sample_index = -1 122 | min_scale = self.cfg.DATA.TRAIN_JITTER_SCALES[0] 123 | max_scale = self.cfg.DATA.TRAIN_JITTER_SCALES[1] 124 | crop_size = self.cfg.DATA.TRAIN_CROP_SIZE 125 | if short_cycle_idx in [0, 1]: 126 | crop_size = int( 127 | round( 128 | self.cfg.MULTIGRID.SHORT_CYCLE_FACTORS[short_cycle_idx] 129 | * self.cfg.MULTIGRID.DEFAULT_S 130 | ) 131 | ) 132 | if self.cfg.MULTIGRID.DEFAULT_S > 0: 133 | # Decreasing the scale is equivalent to using a larger "span" 134 | # in a sampling grid. 135 | min_scale = int( 136 | round( 137 | float(min_scale) 138 | * crop_size 139 | / self.cfg.MULTIGRID.DEFAULT_S 140 | ) 141 | ) 142 | elif self.mode in ["val", "test"]: 143 | temporal_sample_index = (self._spatial_temporal_idx[index] // self.cfg.TEST.NUM_SPATIAL_CROPS) 144 | # spatial_sample_index is in [0, 1, 2]. Corresponding to left, 145 | # center, or right if width is larger than height, and top, middle, 146 | # or bottom if height is larger than width. 
147 | spatial_sample_index = ( 148 | (self._spatial_temporal_idx[index] % self.cfg.TEST.NUM_SPATIAL_CROPS) 149 | if self.cfg.TEST.NUM_SPATIAL_CROPS > 1 else 1 150 | ) 151 | min_scale, max_scale, crop_size = ( 152 | [self.cfg.DATA.TEST_CROP_SIZE] * 3 if self.cfg.TEST.NUM_SPATIAL_CROPS > 1 153 | else [self.cfg.DATA.TRAIN_JITTER_SCALES[0]] * 2 + [self.cfg.DATA.TEST_CROP_SIZE] 154 | ) 155 | # The testing is deterministic and no jitter should be performed. 156 | # min_scale, max_scale, and crop_size are expect to be the same. 157 | assert len({min_scale, max_scale}) == 1 158 | else: 159 | raise NotImplementedError( 160 | "Does not support {} mode".format(self.mode) 161 | ) 162 | sampling_rate = get_random_sampling_rate( 163 | self.cfg.MULTIGRID.LONG_CYCLE_SAMPLING_RATE, 164 | self.cfg.DATA.SAMPLING_RATE, 165 | ) 166 | # Try to decode and sample a clip from a video. If the video can not be 167 | # decoded, repeatedly find a random video replacement that can be decoded. 168 | for i_try in range(self._num_retries): 169 | video_container = None 170 | try: 171 | video_container = get_video_container( 172 | self._path_to_videos[index], 173 | self.cfg.DATA_LOADER.ENABLE_MULTI_THREAD_DECODE, 174 | self.cfg.DATA.DECODING_BACKEND, 175 | ) 176 | except Exception as e: 177 | print( 178 | "Failed to load video from {} with error {}".format( 179 | self._path_to_videos[index], e 180 | ) 181 | ) 182 | # Select a random video if the current video was not able to access. 183 | if video_container is None: 184 | warnings.warn( 185 | "Failed to meta load video idx {} from {}; trial {}".format( 186 | index, self._path_to_videos[index], i_try 187 | ) 188 | ) 189 | if self.mode not in ["val", "test"] and i_try > self._num_retries // 2: 190 | # let's try another one 191 | index = random.randint(0, len(self._path_to_videos) - 1) 192 | continue 193 | 194 | # Decode video. Meta info is used to perform selective decoding. 195 | frames = decode( 196 | container=video_container, 197 | sampling_rate=sampling_rate, 198 | num_frames=self.cfg.DATA.NUM_FRAMES, 199 | clip_idx=temporal_sample_index, 200 | num_clips=self.cfg.TEST.NUM_ENSEMBLE_VIEWS, 201 | video_meta=self._video_meta[index], 202 | target_fps=self.cfg.DATA.TARGET_FPS, 203 | backend=self.cfg.DATA.DECODING_BACKEND, 204 | max_spatial_scale=min_scale, 205 | ) 206 | 207 | # If decoding failed (wrong format, video is too short, and etc), 208 | # select another video. 209 | if frames is None: 210 | warnings.warn( 211 | "Failed to decode video idx {} from {}; trial {}".format( 212 | index, self._path_to_videos[index], i_try 213 | ) 214 | ) 215 | if self.mode not in ["test"] and i_try > self._num_retries // 2: 216 | # let's try another one 217 | index = random.randint(0, len(self._path_to_videos) - 1) 218 | continue 219 | 220 | label = self._labels[index] 221 | 222 | # Perform color normalization. 223 | frames = tensor_normalize( 224 | frames, self.cfg.DATA.MEAN, self.cfg.DATA.STD 225 | ) 226 | frames = frames.permute(3, 0, 1, 2) 227 | 228 | # Perform data augmentation. 229 | frames = spatial_sampling( 230 | frames, 231 | spatial_idx=spatial_sample_index, 232 | min_scale=min_scale, 233 | max_scale=max_scale, 234 | crop_size=crop_size, 235 | random_horizontal_flip=self.cfg.DATA.RANDOM_FLIP, 236 | inverse_uniform_sampling=self.cfg.DATA.INV_UNIFORM_SAMPLE, 237 | ) 238 | 239 | # if not self.cfg.MODEL.ARCH in ['vit']: 240 | # frames = pack_pathway_output(self.cfg, frames) 241 | # else: 242 | # Perform temporal sampling from the fast pathway. 
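            # The SlowFast-style pathway packing below is kept commented out; the
            # ViT/TimeSformer backbones used in this repo consume the clip directly as the
            # single `channel` x `num frames` x `height` x `width` tensor returned below.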
# frames = [torch.index_select( 244 | # x, 245 | # 1, 246 | # torch.linspace( 247 | # 0, x.shape[1] - 1, self.cfg.DATA.NUM_FRAMES 248 | # ).long(), 249 | # ) for x in frames] 250 | 251 | return frames, label, index, {} 252 | else: 253 | raise RuntimeError( 254 | "Failed to fetch video after {} retries.".format( 255 | self._num_retries 256 | ) 257 | ) 258 | 259 | def __len__(self): 260 | """ 261 | Returns: 262 | (int): the number of videos in the dataset. 263 | """ 264 | return len(self._path_to_videos) 265 | 266 | 267 | if __name__ == '__main__': 268 | 269 | from utils.parser import parse_args, load_config 270 | from tqdm import tqdm 271 | 272 | args = parse_args() 273 | args.cfg_file = "models/configs/Kinetics/TimeSformer_divST_8x32_224.yaml" 274 | config = load_config(args) 275 | config.DATA.PATH_TO_DATA_DIR = "/home/kanchanaranasinghe/repo/mmaction2/data/ucf101/splits" 276 | config.DATA.PATH_PREFIX = "/home/kanchanaranasinghe/repo/mmaction2/data/ucf101/videos" 277 | dataset = UCF101(cfg=config, mode="train", num_retries=10) 278 | dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=4) 279 | print(f"Loaded train dataset of length: {len(dataset)}") 280 | for idx, i in enumerate(dataloader): 281 | print(idx, i[0].shape, i[1:]) 282 | if idx > 2: 283 | break 284 | 285 | test_dataset = UCF101(cfg=config, mode="val", num_retries=10) 286 | test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=4) 287 | print(f"Loaded test dataset of length: {len(test_dataset)}") 288 | for idx, i in enumerate(test_dataloader): 289 | print(idx, i[0].shape, i[1:]) 290 | if idx > 2: 291 | break 292 | -------------------------------------------------------------------------------- /datasets/hmdb51.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import warnings 4 | 5 | import torch 6 | import torch.utils.data 7 | 8 | from datasets.data_utils import get_random_sampling_rate, tensor_normalize, spatial_sampling, pack_pathway_output 9 | from datasets.decoder import decode 10 | from datasets.video_container import get_video_container 11 | from datasets.transform import VideoDataAugmentationDINO 12 | from einops import rearrange 13 | 14 | 15 | class HMDB51(torch.utils.data.Dataset): 16 | """ 17 | HMDB51 video loader. Construct the HMDB51 video loader, then sample 18 | clips from the videos. For training and validation, a single clip is 19 | randomly sampled from every video with random cropping, scaling, and 20 | flipping. For testing, multiple clips are uniformly sampled from every 21 | video with uniform cropping. For uniform cropping, we take the left, center, 22 | and right crop if the width is larger than height, or take top, center, and 23 | bottom crop if the height is larger than the width. 24 | """ 25 | 26 | def __init__(self, cfg, mode, num_retries=10): 27 | """ 28 | Construct the HMDB51 video loader with a given csv file. The format of 29 | the csv file is: 30 | ``` 31 | path_to_video_1 label_1 32 | path_to_video_2 label_2 33 | ... 34 | path_to_video_N label_N 35 | ``` 36 | Args: 37 | cfg (CfgNode): configs. 38 | mode (string): Options include `train`, `val`, or `test` mode. 39 | For the train mode, the data loader will take data from the 40 | train set, and sample one clip per video. For the val and 41 | test mode, the data loader will take data from the relevant set, 42 | and sample multiple clips per video. 43 | num_retries (int): number of retries. 
44 | """ 45 | # Only support train, val, and test mode. 46 | assert mode in ["train", "val", "test"], "Split '{}' not supported for UCF101".format(mode) 47 | self.mode = mode 48 | self.cfg = cfg 49 | 50 | self._video_meta = {} 51 | self._num_retries = num_retries 52 | self._split_idx = mode 53 | # For training mode, one single clip is sampled from every video. For validation or testing, NUM_ENSEMBLE_VIEWS 54 | # clips are sampled from every video. For every clip, NUM_SPATIAL_CROPS is cropped spatially from the frames. 55 | if self.mode in ["train"]: 56 | self._num_clips = 1 57 | elif self.mode in ["val", "test"]: 58 | self._num_clips = ( 59 | cfg.TEST.NUM_ENSEMBLE_VIEWS * cfg.TEST.NUM_SPATIAL_CROPS 60 | ) 61 | 62 | print("Constructing HMDB51 {}...".format(mode)) 63 | self._construct_loader() 64 | 65 | def _construct_loader(self): 66 | """ 67 | Construct the video loader. 68 | """ 69 | path_to_file = os.path.join( 70 | self.cfg.DATA.PATH_TO_DATA_DIR, "hmdb51_{}_split_1_videos.txt".format(self.mode) 71 | ) 72 | assert os.path.exists(path_to_file), "{} dir not found".format( 73 | path_to_file 74 | ) 75 | 76 | self._path_to_videos = [] 77 | self._labels = [] 78 | self._spatial_temporal_idx = [] 79 | with open(path_to_file, "r") as f: 80 | for clip_idx, path_label in enumerate(f.read().splitlines()): 81 | assert ( 82 | len(path_label.split(self.cfg.DATA.PATH_LABEL_SEPARATOR)) 83 | == 2 84 | ) 85 | path, label = path_label.split( 86 | self.cfg.DATA.PATH_LABEL_SEPARATOR 87 | ) 88 | for idx in range(self._num_clips): 89 | self._path_to_videos.append( 90 | os.path.join(self.cfg.DATA.PATH_PREFIX, path) 91 | ) 92 | self._labels.append(int(label)) 93 | self._spatial_temporal_idx.append(idx) 94 | self._video_meta[clip_idx * self._num_clips + idx] = {} 95 | assert (len(self._path_to_videos) > 0), f"Failed to load UCF101 split {self._split_idx} from {path_to_file}" 96 | print(f"Constructing HMDB51 dataloader (size: {len(self._path_to_videos)}) from {path_to_file}") 97 | 98 | def __getitem__(self, index): 99 | """ 100 | Given the video index, return the list of frames, label, and video 101 | index if the video can be fetched and decoded successfully, otherwise 102 | repeatly find a random video that can be decoded as a replacement. 103 | Args: 104 | index (int): the video index provided by the pytorch sampler. 105 | Returns: 106 | frames (tensor): the frames of sampled from the video. The dimension 107 | is `channel` x `num frames` x `height` x `width`. 108 | label (int): the label of the current video. 109 | index (int): if the video provided by pytorch sampler can be 110 | decoded, then return the index of the video. If not, return the 111 | index of the video replacement that can be decoded. 112 | """ 113 | short_cycle_idx = None 114 | # When short cycle is used, input index is a tupple. 115 | if isinstance(index, tuple): 116 | index, short_cycle_idx = index 117 | 118 | if self.mode in ["train"]: 119 | # -1 indicates random sampling. 120 | temporal_sample_index = -1 121 | spatial_sample_index = -1 122 | min_scale = self.cfg.DATA.TRAIN_JITTER_SCALES[0] 123 | max_scale = self.cfg.DATA.TRAIN_JITTER_SCALES[1] 124 | crop_size = self.cfg.DATA.TRAIN_CROP_SIZE 125 | if short_cycle_idx in [0, 1]: 126 | crop_size = int( 127 | round( 128 | self.cfg.MULTIGRID.SHORT_CYCLE_FACTORS[short_cycle_idx] 129 | * self.cfg.MULTIGRID.DEFAULT_S 130 | ) 131 | ) 132 | if self.cfg.MULTIGRID.DEFAULT_S > 0: 133 | # Decreasing the scale is equivalent to using a larger "span" 134 | # in a sampling grid. 
135 | min_scale = int( 136 | round( 137 | float(min_scale) 138 | * crop_size 139 | / self.cfg.MULTIGRID.DEFAULT_S 140 | ) 141 | ) 142 | elif self.mode in ["val", "test"]: 143 | temporal_sample_index = (self._spatial_temporal_idx[index] // self.cfg.TEST.NUM_SPATIAL_CROPS) 144 | # spatial_sample_index is in [0, 1, 2]. Corresponding to left, 145 | # center, or right if width is larger than height, and top, middle, 146 | # or bottom if height is larger than width. 147 | spatial_sample_index = ( 148 | (self._spatial_temporal_idx[index] % self.cfg.TEST.NUM_SPATIAL_CROPS) 149 | if self.cfg.TEST.NUM_SPATIAL_CROPS > 1 else 1 150 | ) 151 | min_scale, max_scale, crop_size = ( 152 | [self.cfg.DATA.TEST_CROP_SIZE] * 3 if self.cfg.TEST.NUM_SPATIAL_CROPS > 1 153 | else [self.cfg.DATA.TRAIN_JITTER_SCALES[0]] * 2 + [self.cfg.DATA.TEST_CROP_SIZE] 154 | ) 155 | # The testing is deterministic and no jitter should be performed. 156 | # min_scale, max_scale, and crop_size are expect to be the same. 157 | assert len({min_scale, max_scale}) == 1 158 | else: 159 | raise NotImplementedError( 160 | "Does not support {} mode".format(self.mode) 161 | ) 162 | sampling_rate = get_random_sampling_rate( 163 | self.cfg.MULTIGRID.LONG_CYCLE_SAMPLING_RATE, 164 | self.cfg.DATA.SAMPLING_RATE, 165 | ) 166 | # Try to decode and sample a clip from a video. If the video can not be 167 | # decoded, repeatedly find a random video replacement that can be decoded. 168 | for i_try in range(self._num_retries): 169 | video_container = None 170 | try: 171 | video_container = get_video_container( 172 | self._path_to_videos[index], 173 | self.cfg.DATA_LOADER.ENABLE_MULTI_THREAD_DECODE, 174 | self.cfg.DATA.DECODING_BACKEND, 175 | ) 176 | except Exception as e: 177 | print( 178 | "Failed to load video from {} with error {}".format( 179 | self._path_to_videos[index], e 180 | ) 181 | ) 182 | # Select a random video if the current video was not able to access. 183 | if video_container is None: 184 | warnings.warn( 185 | "Failed to meta load video idx {} from {}; trial {}".format( 186 | index, self._path_to_videos[index], i_try 187 | ) 188 | ) 189 | if self.mode not in ["val", "test"] and i_try > self._num_retries // 2: 190 | # let's try another one 191 | index = random.randint(0, len(self._path_to_videos) - 1) 192 | continue 193 | 194 | # Decode video. Meta info is used to perform selective decoding. 195 | frames = decode( 196 | container=video_container, 197 | sampling_rate=sampling_rate, 198 | num_frames=self.cfg.DATA.NUM_FRAMES, 199 | clip_idx=temporal_sample_index, 200 | num_clips=self.cfg.TEST.NUM_ENSEMBLE_VIEWS, 201 | video_meta=self._video_meta[index], 202 | target_fps=self.cfg.DATA.TARGET_FPS, 203 | backend=self.cfg.DATA.DECODING_BACKEND, 204 | max_spatial_scale=min_scale, 205 | ) 206 | 207 | # If decoding failed (wrong format, video is too short, and etc), 208 | # select another video. 209 | if frames is None: 210 | warnings.warn( 211 | "Failed to decode video idx {} from {}; trial {}".format( 212 | index, self._path_to_videos[index], i_try 213 | ) 214 | ) 215 | if self.mode not in ["test"] and i_try > self._num_retries // 2: 216 | # let's try another one 217 | index = random.randint(0, len(self._path_to_videos) - 1) 218 | continue 219 | 220 | label = self._labels[index] 221 | 222 | # Perform color normalization. 223 | frames = tensor_normalize( 224 | frames, self.cfg.DATA.MEAN, self.cfg.DATA.STD 225 | ) 226 | frames = frames.permute(3, 0, 1, 2) 227 | 228 | # Perform data augmentation. 
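            # spatial_sampling (see datasets/data_utils.py) applies random short-side
            # scale jitter, random crop and optional horizontal flip when spatial_idx is
            # -1 (train), and a deterministic resize plus uniform crop at the given
            # position for val/test.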
229 | frames = spatial_sampling( 230 | frames, 231 | spatial_idx=spatial_sample_index, 232 | min_scale=min_scale, 233 | max_scale=max_scale, 234 | crop_size=crop_size, 235 | random_horizontal_flip=self.cfg.DATA.RANDOM_FLIP, 236 | inverse_uniform_sampling=self.cfg.DATA.INV_UNIFORM_SAMPLE, 237 | ) 238 | 239 | # if not self.cfg.MODEL.ARCH in ['vit']: 240 | # frames = pack_pathway_output(self.cfg, frames) 241 | # else: 242 | # Perform temporal sampling from the fast pathway. 243 | # frames = [torch.index_select( 244 | # x, 245 | # 1, 246 | # torch.linspace( 247 | # 0, x.shape[1] - 1, self.cfg.DATA.NUM_FRAMES 248 | # ).long(), 249 | # ) for x in frames] 250 | 251 | return frames, label, index, {} 252 | else: 253 | raise RuntimeError( 254 | "Failed to fetch video after {} retries.".format( 255 | self._num_retries 256 | ) 257 | ) 258 | 259 | def __len__(self): 260 | """ 261 | Returns: 262 | (int): the number of videos in the dataset. 263 | """ 264 | return len(self._path_to_videos) 265 | 266 | 267 | if __name__ == '__main__': 268 | 269 | from utils.parser import parse_args, load_config 270 | from tqdm import tqdm 271 | 272 | args = parse_args() 273 | args.cfg_file = "models/configs/Kinetics/TimeSformer_divST_8x32_224.yaml" 274 | config = load_config(args) 275 | config.DATA.PATH_TO_DATA_DIR = "/home/kanchanaranasinghe/repo/mmaction2/data/hmdb51/splits" 276 | config.DATA.PATH_PREFIX = "/home/kanchanaranasinghe/repo/mmaction2/data/hmdb51/videos" 277 | dataset = HMDB51(cfg=config, mode="train", num_retries=10) 278 | dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=4) 279 | print(f"Loaded train dataset of length: {len(dataset)}") 280 | for idx, i in tqdm(enumerate(dataloader)): 281 | # continue 282 | print(idx, i[0].shape, i[1:]) 283 | if idx > 2: 284 | break 285 | 286 | test_dataset = HMDB51(cfg=config, mode="val", num_retries=10) 287 | test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=4) 288 | print(f"Loaded test dataset of length: {len(test_dataset)}") 289 | for idx, i in tqdm(enumerate(test_dataloader)): 290 | # continue 291 | print(idx, i[0].shape, i[1:]) 292 | if idx > 2: 293 | break 294 | -------------------------------------------------------------------------------- /vision_transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Mostly copy-paste from timm library. 16 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 17 | """ 18 | import math 19 | from functools import partial 20 | 21 | import torch 22 | import torch.nn as nn 23 | 24 | from utils.utils import trunc_normal_ 25 | 26 | 27 | def drop_path(x, drop_prob: float = 0., training: bool = False): 28 | if drop_prob == 0. 
or not training: 29 | return x 30 | keep_prob = 1 - drop_prob 31 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 32 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 33 | random_tensor.floor_() # binarize 34 | output = x.div(keep_prob) * random_tensor 35 | return output 36 | 37 | 38 | class DropPath(nn.Module): 39 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 40 | """ 41 | def __init__(self, drop_prob=None): 42 | super(DropPath, self).__init__() 43 | self.drop_prob = drop_prob 44 | 45 | def forward(self, x): 46 | return drop_path(x, self.drop_prob, self.training) 47 | 48 | 49 | class Mlp(nn.Module): 50 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 51 | super().__init__() 52 | out_features = out_features or in_features 53 | hidden_features = hidden_features or in_features 54 | self.fc1 = nn.Linear(in_features, hidden_features) 55 | self.act = act_layer() 56 | self.fc2 = nn.Linear(hidden_features, out_features) 57 | self.drop = nn.Dropout(drop) 58 | 59 | def forward(self, x): 60 | x = self.fc1(x) 61 | x = self.act(x) 62 | x = self.drop(x) 63 | x = self.fc2(x) 64 | x = self.drop(x) 65 | return x 66 | 67 | 68 | class Attention(nn.Module): 69 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 70 | super().__init__() 71 | self.num_heads = num_heads 72 | head_dim = dim // num_heads 73 | self.scale = qk_scale or head_dim ** -0.5 74 | 75 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 76 | self.attn_drop = nn.Dropout(attn_drop) 77 | self.proj = nn.Linear(dim, dim) 78 | self.proj_drop = nn.Dropout(proj_drop) 79 | 80 | def forward(self, x): 81 | B, N, C = x.shape 82 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 83 | q, k, v = qkv[0], qkv[1], qkv[2] 84 | 85 | attn = (q @ k.transpose(-2, -1)) * self.scale 86 | attn = attn.softmax(dim=-1) 87 | attn = self.attn_drop(attn) 88 | 89 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 90 | x = self.proj(x) 91 | x = self.proj_drop(x) 92 | return x, attn 93 | 94 | 95 | class Block(nn.Module): 96 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 97 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 98 | super().__init__() 99 | self.norm1 = norm_layer(dim) 100 | self.attn = Attention( 101 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 102 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 103 | self.norm2 = norm_layer(dim) 104 | mlp_hidden_dim = int(dim * mlp_ratio) 105 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 106 | 107 | def forward(self, x, return_attention=False): 108 | y, attn = self.attn(self.norm1(x)) 109 | if return_attention: 110 | return attn 111 | x = x + self.drop_path(y) 112 | x = x + self.drop_path(self.mlp(self.norm2(x))) 113 | return x 114 | 115 | 116 | class PatchEmbed(nn.Module): 117 | """ Image to Patch Embedding 118 | """ 119 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 120 | super().__init__() 121 | num_patches = (img_size // patch_size) * (img_size // patch_size) 122 | self.img_size = img_size 123 | self.patch_size = patch_size 124 | self.num_patches = num_patches 125 | 126 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 127 | 128 | def forward(self, x): 129 | B, C, H, W = x.shape 130 | x = self.proj(x).flatten(2).transpose(1, 2) 131 | return x 132 | 133 | 134 | class VisionTransformer(nn.Module): 135 | """ Vision Transformer """ 136 | def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12, 137 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 138 | drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs): 139 | super().__init__() 140 | self.num_features = self.embed_dim = embed_dim 141 | 142 | self.patch_embed = PatchEmbed( 143 | img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 144 | num_patches = self.patch_embed.num_patches 145 | 146 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 147 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 148 | self.pos_drop = nn.Dropout(p=drop_rate) 149 | 150 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 151 | self.blocks = nn.ModuleList([ 152 | Block( 153 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 154 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 155 | for i in range(depth)]) 156 | self.norm = norm_layer(embed_dim) 157 | 158 | # Classifier head 159 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 160 | 161 | trunc_normal_(self.pos_embed, std=.02) 162 | trunc_normal_(self.cls_token, std=.02) 163 | self.apply(self._init_weights) 164 | 165 | def _init_weights(self, m): 166 | if isinstance(m, nn.Linear): 167 | trunc_normal_(m.weight, std=.02) 168 | if isinstance(m, nn.Linear) and m.bias is not None: 169 | nn.init.constant_(m.bias, 0) 170 | elif isinstance(m, nn.LayerNorm): 171 | nn.init.constant_(m.bias, 0) 172 | nn.init.constant_(m.weight, 1.0) 173 | 174 | def interpolate_pos_encoding(self, x, w, h): 175 | npatch = x.shape[1] - 1 176 | N = self.pos_embed.shape[1] - 1 177 | if npatch == N and w == h: 178 | return self.pos_embed 179 | class_pos_embed = self.pos_embed[:, 0] 180 | patch_pos_embed = self.pos_embed[:, 1:] 181 | dim = x.shape[-1] 182 | w0 = w // self.patch_embed.patch_size 183 | h0 = h // self.patch_embed.patch_size 184 | # we add a small number to avoid floating point error in the interpolation 185 | # see discussion at https://github.com/facebookresearch/dino/issues/8 186 | w0, h0 = w0 + 0.1, h0 + 0.1 187 | patch_pos_embed = nn.functional.interpolate( 188 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), 
dim).permute(0, 3, 1, 2), 189 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 190 | mode='bicubic', 191 | ) 192 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 193 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 194 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) 195 | 196 | def prepare_tokens(self, x): 197 | B, nc, w, h = x.shape 198 | x = self.patch_embed(x) # patch linear embedding 199 | 200 | # add the [CLS] token to the embed patch tokens 201 | cls_tokens = self.cls_token.expand(B, -1, -1) 202 | x = torch.cat((cls_tokens, x), dim=1) 203 | 204 | # add positional encoding to each token 205 | x = x + self.interpolate_pos_encoding(x, w, h) 206 | 207 | return self.pos_drop(x) 208 | 209 | def forward(self, x): 210 | x = self.prepare_tokens(x) 211 | for blk in self.blocks: 212 | x = blk(x) 213 | x = self.norm(x) 214 | return x[:, 0] 215 | 216 | def get_intermediate_layers(self, x, n=1): 217 | x = self.prepare_tokens(x) 218 | # we return the output tokens from the `n` last blocks 219 | output = [] 220 | for i, blk in enumerate(self.blocks): 221 | x = blk(x) 222 | if len(self.blocks) - i <= n: 223 | output.append(self.norm(x)) 224 | return output 225 | 226 | 227 | def vit_tiny(patch_size=16, **kwargs): 228 | model = VisionTransformer( 229 | patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, 230 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 231 | return model 232 | 233 | 234 | def vit_small(patch_size=16, **kwargs): 235 | model = VisionTransformer( 236 | patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, 237 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 238 | return model 239 | 240 | 241 | def vit_base(patch_size=16, **kwargs): 242 | model = VisionTransformer( 243 | patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, 244 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 245 | return model 246 | 247 | 248 | class DINOHead(nn.Module): 249 | def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256): 250 | super().__init__() 251 | nlayers = max(nlayers, 1) 252 | if nlayers == 1: 253 | self.mlp = nn.Linear(in_dim, bottleneck_dim) 254 | else: 255 | layers = [nn.Linear(in_dim, hidden_dim)] 256 | if use_bn: 257 | layers.append(nn.BatchNorm1d(hidden_dim)) 258 | layers.append(nn.GELU()) 259 | for _ in range(nlayers - 2): 260 | layers.append(nn.Linear(hidden_dim, hidden_dim)) 261 | if use_bn: 262 | layers.append(nn.BatchNorm1d(hidden_dim)) 263 | layers.append(nn.GELU()) 264 | layers.append(nn.Linear(hidden_dim, bottleneck_dim)) 265 | self.mlp = nn.Sequential(*layers) 266 | self.apply(self._init_weights) 267 | self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 268 | self.last_layer.weight_g.data.fill_(1) 269 | if norm_last_layer: 270 | self.last_layer.weight_g.requires_grad = False 271 | 272 | def _init_weights(self, m): 273 | if isinstance(m, nn.Linear): 274 | trunc_normal_(m.weight, std=.02) 275 | if isinstance(m, nn.Linear) and m.bias is not None: 276 | nn.init.constant_(m.bias, 0) 277 | 278 | def forward(self, x): 279 | x = self.mlp(x) 280 | x = nn.functional.normalize(x, dim=-1, p=2) 281 | x = self.last_layer(x) 282 | return x 283 | 284 | 285 | class MultiDINOHead(nn.Module): 286 | def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, 
hidden_dim=2048, 287 | bottleneck_dim=256): 288 | super().__init__() 289 | nlayers = max(nlayers, 1) 290 | if nlayers == 1: 291 | self.mlp = nn.Linear(in_dim, bottleneck_dim) 292 | self.aux_mlp = nn.Linear(in_dim, bottleneck_dim) 293 | else: 294 | layers = [nn.Linear(in_dim, hidden_dim)] 295 | if use_bn: 296 | layers.append(nn.BatchNorm1d(hidden_dim)) 297 | layers.append(nn.GELU()) 298 | for _ in range(nlayers - 2): 299 | layers.append(nn.Linear(hidden_dim, hidden_dim)) 300 | if use_bn: 301 | layers.append(nn.BatchNorm1d(hidden_dim)) 302 | layers.append(nn.GELU()) 303 | layers.append(nn.Linear(hidden_dim, bottleneck_dim)) 304 | self.mlp = nn.Sequential(*layers) 305 | 306 | aux_layers = [nn.Linear(in_dim, hidden_dim)] 307 | if use_bn: 308 | aux_layers.append(nn.BatchNorm1d(hidden_dim)) 309 | aux_layers.append(nn.GELU()) 310 | for _ in range(nlayers - 2): 311 | aux_layers.append(nn.Linear(hidden_dim, hidden_dim)) 312 | if use_bn: 313 | aux_layers.append(nn.BatchNorm1d(hidden_dim)) 314 | aux_layers.append(nn.GELU()) 315 | aux_layers.append(nn.Linear(hidden_dim, bottleneck_dim)) 316 | self.aux_mlp = nn.Sequential(*aux_layers) 317 | 318 | self.apply(self._init_weights) 319 | self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 320 | self.last_layer.weight_g.data.fill_(1) 321 | if norm_last_layer: 322 | self.last_layer.weight_g.requires_grad = False 323 | 324 | self.aux_last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 325 | self.aux_last_layer.weight_g.data.fill_(1) 326 | if norm_last_layer: 327 | self.aux_last_layer.weight_g.requires_grad = False 328 | 329 | def _init_weights(self, m): 330 | if isinstance(m, nn.Linear): 331 | trunc_normal_(m.weight, std=.02) 332 | if isinstance(m, nn.Linear) and m.bias is not None: 333 | nn.init.constant_(m.bias, 0) 334 | 335 | def forward(self, x): 336 | rgb_x, aux_x = x[0], x[1] 337 | 338 | rgb_x = self.mlp(rgb_x) 339 | rgb_x = nn.functional.normalize(rgb_x, dim=-1, p=2) 340 | rgb_x = self.last_layer(rgb_x) 341 | 342 | aux_x = self.aux_mlp(aux_x) 343 | aux_x = nn.functional.normalize(aux_x, dim=-1, p=2) 344 | aux_x = self.aux_last_layer(aux_x) 345 | return rgb_x, aux_x 346 | -------------------------------------------------------------------------------- /datasets/data_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import logging 4 | import numpy as np 5 | import os 6 | import random 7 | import time 8 | from collections import defaultdict 9 | import cv2 10 | import torch 11 | from fvcore.common.file_io import PathManager 12 | from torch.utils.data.distributed import DistributedSampler 13 | 14 | import datasets.transform as transform 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | def retry_load_images(image_paths, retry=10, backend="pytorch"): 20 | """ 21 | This function is to load images with support of retrying for failed load. 22 | 23 | Args: 24 | image_paths (list): paths of images needed to be loaded. 25 | retry (int, optional): maximum time of loading retrying. Defaults to 10. 26 | backend (str): `pytorch` or `cv2`. 27 | 28 | Returns: 29 | imgs (list): list of loaded images. 
30 | """ 31 | for i in range(retry): 32 | imgs = [] 33 | for image_path in image_paths: 34 | with PathManager.open(image_path, "rb") as f: 35 | img_str = np.frombuffer(f.read(), np.uint8) 36 | img = cv2.imdecode(img_str, flags=cv2.IMREAD_COLOR) 37 | imgs.append(img) 38 | 39 | if all(img is not None for img in imgs): 40 | if backend == "pytorch": 41 | imgs = torch.as_tensor(np.stack(imgs)) 42 | return imgs 43 | else: 44 | logger.warn("Reading failed. Will retry.") 45 | time.sleep(1.0) 46 | if i == retry - 1: 47 | raise Exception("Failed to load images {}".format(image_paths)) 48 | 49 | 50 | def get_sequence(center_idx, half_len, sample_rate, num_frames): 51 | """ 52 | Sample frames among the corresponding clip. 53 | 54 | Args: 55 | center_idx (int): center frame idx for current clip 56 | half_len (int): half of the clip length 57 | sample_rate (int): sampling rate for sampling frames inside of the clip 58 | num_frames (int): number of expected sampled frames 59 | 60 | Returns: 61 | seq (list): list of indexes of sampled frames in this clip. 62 | """ 63 | seq = list(range(center_idx - half_len, center_idx + half_len, sample_rate)) 64 | 65 | for seq_idx in range(len(seq)): 66 | if seq[seq_idx] < 0: 67 | seq[seq_idx] = 0 68 | elif seq[seq_idx] >= num_frames: 69 | seq[seq_idx] = num_frames - 1 70 | return seq 71 | 72 | 73 | def pack_pathway_output(cfg, frames): 74 | """ 75 | Prepare output as a list of tensors. Each tensor corresponding to a 76 | unique pathway. 77 | Args: 78 | frames (tensor): frames of images sampled from the video. The 79 | dimension is `channel` x `num frames` x `height` x `width`. 80 | Returns: 81 | frame_list (list): list of tensors with the dimension of 82 | `channel` x `num frames` x `height` x `width`. 83 | """ 84 | if cfg.DATA.REVERSE_INPUT_CHANNEL: 85 | frames = frames[[2, 1, 0], :, :, :] 86 | if cfg.MODEL.ARCH in cfg.MODEL.SINGLE_PATHWAY_ARCH: 87 | frame_list = [frames] 88 | elif cfg.MODEL.ARCH in cfg.MODEL.MULTI_PATHWAY_ARCH: 89 | fast_pathway = frames 90 | # Perform temporal sampling from the fast pathway. 91 | slow_pathway = torch.index_select( 92 | frames, 93 | 1, 94 | torch.linspace( 95 | 0, frames.shape[1] - 1, frames.shape[1] // cfg.SLOWFAST.ALPHA 96 | ).long(), 97 | ) 98 | frame_list = [slow_pathway, fast_pathway] 99 | else: 100 | raise NotImplementedError( 101 | "Model arch {} is not in {}".format( 102 | cfg.MODEL.ARCH, 103 | cfg.MODEL.SINGLE_PATHWAY_ARCH + cfg.MODEL.MULTI_PATHWAY_ARCH, 104 | ) 105 | ) 106 | return frame_list 107 | 108 | 109 | def spatial_sampling( 110 | frames, 111 | spatial_idx=-1, 112 | min_scale=256, 113 | max_scale=320, 114 | crop_size=224, 115 | random_horizontal_flip=True, 116 | inverse_uniform_sampling=False, 117 | ): 118 | """ 119 | Perform spatial sampling on the given video frames. If spatial_idx is 120 | -1, perform random scale, random crop, and random flip on the given 121 | frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling 122 | with the given spatial_idx. 123 | Args: 124 | frames (tensor): frames of images sampled from the video. The 125 | dimension is `num frames` x `height` x `width` x `channel`. 126 | spatial_idx (int): if -1, perform random spatial sampling. If 0, 1, 127 | or 2, perform left, center, right crop if width is larger than 128 | height, and perform top, center, buttom crop if height is larger 129 | than width. 130 | min_scale (int): the minimal size of scaling. 131 | max_scale (int): the maximal size of scaling. 
132 | crop_size (int): the size of height and width used to crop the 133 | frames. 134 | inverse_uniform_sampling (bool): if True, sample uniformly in 135 | [1 / max_scale, 1 / min_scale] and take a reciprocal to get the 136 | scale. If False, take a uniform sample from [min_scale, 137 | max_scale]. 138 | Returns: 139 | frames (tensor): spatially sampled frames. 140 | """ 141 | assert spatial_idx in [-1, 0, 1, 2] 142 | if spatial_idx == -1: 143 | frames, _ = transform.random_short_side_scale_jitter( 144 | images=frames, 145 | min_size=min_scale, 146 | max_size=max_scale, 147 | inverse_uniform_sampling=inverse_uniform_sampling, 148 | ) 149 | frames, _ = transform.random_crop(frames, crop_size) 150 | if random_horizontal_flip: 151 | frames, _ = transform.horizontal_flip(0.5, frames) 152 | else: 153 | # The testing is deterministic and no jitter should be performed. 154 | # min_scale, max_scale, and crop_size are expect to be the same. 155 | #assert len({min_scale, max_scale, crop_size}) == 1 156 | frames, _ = transform.random_short_side_scale_jitter( 157 | frames, min_scale, max_scale 158 | ) 159 | frames, _ = transform.uniform_crop(frames, crop_size, spatial_idx) 160 | return frames 161 | 162 | def spatial_sampling_2crops( 163 | frames, 164 | spatial_idx=-1, 165 | min_scale=256, 166 | max_scale=320, 167 | crop_size=224, 168 | random_horizontal_flip=True, 169 | inverse_uniform_sampling=False, 170 | ): 171 | """ 172 | Perform spatial sampling on the given video frames. If spatial_idx is 173 | -1, perform random scale, random crop, and random flip on the given 174 | frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling 175 | with the given spatial_idx. 176 | Args: 177 | frames (tensor): frames of images sampled from the video. The 178 | dimension is `num frames` x `height` x `width` x `channel`. 179 | spatial_idx (int): if -1, perform random spatial sampling. If 0, 1, 180 | or 2, perform left, center, right crop if width is larger than 181 | height, and perform top, center, buttom crop if height is larger 182 | than width. 183 | min_scale (int): the minimal size of scaling. 184 | max_scale (int): the maximal size of scaling. 185 | crop_size (int): the size of height and width used to crop the 186 | frames. 187 | inverse_uniform_sampling (bool): if True, sample uniformly in 188 | [1 / max_scale, 1 / min_scale] and take a reciprocal to get the 189 | scale. If False, take a uniform sample from [min_scale, 190 | max_scale]. 191 | Returns: 192 | frames (tensor): spatially sampled frames. 193 | """ 194 | assert spatial_idx in [-1, 0, 1, 2] 195 | if spatial_idx == -1: 196 | frames, _ = transform.random_short_side_scale_jitter( 197 | images=frames, 198 | min_size=min_scale, 199 | max_size=max_scale, 200 | inverse_uniform_sampling=inverse_uniform_sampling, 201 | ) 202 | frames, _ = transform.random_crop(frames, crop_size) 203 | if random_horizontal_flip: 204 | frames, _ = transform.horizontal_flip(0.5, frames) 205 | else: 206 | # The testing is deterministic and no jitter should be performed. 207 | # min_scale, max_scale, and crop_size are expect to be the same. 208 | #assert len({min_scale, max_scale, crop_size}) == 1 209 | frames, _ = transform.random_short_side_scale_jitter( 210 | frames, min_scale, max_scale 211 | ) 212 | frames, _ = transform.uniform_crop_2crops(frames, crop_size, spatial_idx) 213 | return frames 214 | 215 | 216 | def as_binary_vector(labels, num_classes): 217 | """ 218 | Construct binary label vector given a list of label indices. 
219 | Args: 220 | labels (list): The input label list. 221 | num_classes (int): Number of classes of the label vector. 222 | Returns: 223 | labels (numpy array): the resulting binary vector. 224 | """ 225 | label_arr = np.zeros((num_classes,)) 226 | 227 | for lbl in set(labels): 228 | label_arr[lbl] = 1.0 229 | return label_arr 230 | 231 | 232 | def aggregate_labels(label_list): 233 | """ 234 | Join a list of label list. 235 | Args: 236 | labels (list): The input label list. 237 | Returns: 238 | labels (list): The joint list of all lists in input. 239 | """ 240 | all_labels = [] 241 | for labels in label_list: 242 | for l in labels: 243 | all_labels.append(l) 244 | return list(set(all_labels)) 245 | 246 | 247 | def convert_to_video_level_labels(labels): 248 | """ 249 | Aggregate annotations from all frames of a video to form video-level labels. 250 | Args: 251 | labels (list): The input label list. 252 | Returns: 253 | labels (list): Same as input, but with each label replaced by 254 | a video-level one. 255 | """ 256 | for video_id in range(len(labels)): 257 | video_level_labels = aggregate_labels(labels[video_id]) 258 | for i in range(len(labels[video_id])): 259 | labels[video_id][i] = video_level_labels 260 | return labels 261 | 262 | 263 | def load_image_lists(frame_list_file, prefix="", return_list=False): 264 | """ 265 | Load image paths and labels from a "frame list". 266 | Each line of the frame list contains: 267 | `original_vido_id video_id frame_id path labels` 268 | Args: 269 | frame_list_file (string): path to the frame list. 270 | prefix (str): the prefix for the path. 271 | return_list (bool): if True, return a list. If False, return a dict. 272 | Returns: 273 | image_paths (list or dict): list of list containing path to each frame. 274 | If return_list is False, then return in a dict form. 275 | labels (list or dict): list of list containing label of each frame. 276 | If return_list is False, then return in a dict form. 277 | """ 278 | image_paths = defaultdict(list) 279 | labels = defaultdict(list) 280 | with PathManager.open(frame_list_file, "r") as f: 281 | assert f.readline().startswith("original_vido_id") 282 | for line in f: 283 | row = line.split() 284 | # original_vido_id video_id frame_id path labels 285 | assert len(row) == 5 286 | video_name = row[0] 287 | if prefix == "": 288 | path = row[3] 289 | else: 290 | path = os.path.join(prefix, row[3]) 291 | image_paths[video_name].append(path) 292 | frame_labels = row[-1].replace('"', "") 293 | if frame_labels != "": 294 | labels[video_name].append( 295 | [int(x) for x in frame_labels.split(",")] 296 | ) 297 | else: 298 | labels[video_name].append([]) 299 | 300 | if return_list: 301 | keys = image_paths.keys() 302 | image_paths = [image_paths[key] for key in keys] 303 | labels = [labels[key] for key in keys] 304 | return image_paths, labels 305 | return dict(image_paths), dict(labels) 306 | 307 | 308 | def tensor_normalize(tensor, mean, std): 309 | """ 310 | Normalize a given tensor by subtracting the mean and dividing the std. 311 | Args: 312 | tensor (tensor): tensor to normalize. 313 | mean (tensor or list): mean value to subtract. 314 | std (tensor or list): std to divide. 
315 | """ 316 | if tensor.dtype == torch.uint8: 317 | tensor = tensor.float() 318 | tensor = tensor / 255.0 319 | if type(mean) == list: 320 | mean = torch.tensor(mean) 321 | if type(std) == list: 322 | std = torch.tensor(std) 323 | tensor = tensor - mean 324 | tensor = tensor / std 325 | return tensor 326 | 327 | 328 | def get_random_sampling_rate(long_cycle_sampling_rate, sampling_rate): 329 | """ 330 | When multigrid training uses a fewer number of frames, we randomly 331 | increase the sampling rate so that some clips cover the original span. 332 | """ 333 | if long_cycle_sampling_rate > 0: 334 | assert long_cycle_sampling_rate >= sampling_rate 335 | return random.randint(sampling_rate, long_cycle_sampling_rate) 336 | else: 337 | return sampling_rate 338 | 339 | 340 | def revert_tensor_normalize(tensor, mean, std): 341 | """ 342 | Revert normalization for a given tensor by multiplying by the std and adding the mean. 343 | Args: 344 | tensor (tensor): tensor to revert normalization. 345 | mean (tensor or list): mean value to add. 346 | std (tensor or list): std to multiply. 347 | """ 348 | if type(mean) == list: 349 | mean = torch.tensor(mean) 350 | if type(std) == list: 351 | std = torch.tensor(std) 352 | tensor = tensor * std 353 | tensor = tensor + mean 354 | return tensor 355 | 356 | 357 | def create_sampler(dataset, shuffle, cfg): 358 | """ 359 | Create sampler for the given dataset. 360 | Args: 361 | dataset (torch.utils.data.Dataset): the given dataset. 362 | shuffle (bool): set to ``True`` to have the data reshuffled 363 | at every epoch. 364 | cfg (CfgNode): configs. Details can be found in 365 | slowfast/config/defaults.py 366 | Returns: 367 | sampler (Sampler): the created sampler. 368 | """ 369 | sampler = DistributedSampler(dataset) if cfg.NUM_GPUS > 1 else None 370 | 371 | return sampler 372 | 373 | 374 | def loader_worker_init_fn(dataset): 375 | """ 376 | Create init function passed to pytorch data loader. 377 | Args: 378 | dataset (torch.utils.data.Dataset): the given dataset. 379 | """ 380 | return None 381 | -------------------------------------------------------------------------------- /datasets/kinetics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import glob 3 | import os 4 | import random 5 | import warnings 6 | from PIL import Image 7 | import torch 8 | import torch.utils.data 9 | import torchvision 10 | import kornia 11 | 12 | from datasets.transform import resize 13 | from datasets.data_utils import get_random_sampling_rate, tensor_normalize, spatial_sampling, pack_pathway_output 14 | from datasets.decoder import decode 15 | from datasets.video_container import get_video_container 16 | from datasets.transform import VideoDataAugmentationDINO 17 | from einops import rearrange 18 | 19 | 20 | class Kinetics(torch.utils.data.Dataset): 21 | """ 22 | Kinetics video loader. Construct the Kinetics video loader, then sample 23 | clips from the videos. For training and validation, a single clip is 24 | randomly sampled from every video with random cropping, scaling, and 25 | flipping. For testing, multiple clips are uniformaly sampled from every 26 | video with uniform cropping. For uniform cropping, we take the left, center, 27 | and right crop if the width is larger than height, or take top, center, and 28 | bottom crop if the height is larger than the width. 
29 | """ 30 | 31 | def __init__(self, cfg, mode, num_retries=10, get_flow=False): 32 | """ 33 | Construct the Kinetics video loader with a given csv file. The format of 34 | the csv file is: 35 | ``` 36 | path_to_video_1 label_1 37 | path_to_video_2 label_2 38 | ... 39 | path_to_video_N label_N 40 | ``` 41 | Args: 42 | cfg (CfgNode): configs. 43 | mode (string): Options includes `train`, `val`, or `test` mode. 44 | For the train and val mode, the data loader will take data 45 | from the train or val set, and sample one clip per video. 46 | For the test mode, the data loader will take data from test set, 47 | and sample multiple clips per video. 48 | num_retries (int): number of retries. 49 | """ 50 | # Only support train, val, and test mode. 51 | assert mode in [ 52 | "train", 53 | "val", 54 | "test", 55 | ], "Split '{}' not supported for Kinetics".format(mode) 56 | self.mode = mode 57 | self.cfg = cfg 58 | if get_flow: 59 | assert mode == "train", "invalid: flow only for train mode" 60 | self.get_flow = get_flow 61 | 62 | self._video_meta = {} 63 | self._num_retries = num_retries 64 | # For training or validation mode, one single clip is sampled from every 65 | # video. For testing, NUM_ENSEMBLE_VIEWS clips are sampled from every 66 | # video. For every clip, NUM_SPATIAL_CROPS is cropped spatially from 67 | # the frames. 68 | if self.mode in ["train", "val"]: 69 | self._num_clips = 1 70 | elif self.mode in ["test"]: 71 | self._num_clips = ( 72 | cfg.TEST.NUM_ENSEMBLE_VIEWS * cfg.TEST.NUM_SPATIAL_CROPS 73 | ) 74 | 75 | print("Constructing Kinetics {}...".format(mode)) 76 | self._construct_loader() 77 | 78 | def _construct_loader(self): 79 | """ 80 | Construct the video loader. 81 | """ 82 | path_to_file = os.path.join( 83 | self.cfg.DATA.PATH_TO_DATA_DIR, "{}.csv".format(self.mode) 84 | ) 85 | assert os.path.exists(path_to_file), "{} dir not found".format( 86 | path_to_file 87 | ) 88 | 89 | self._path_to_videos = [] 90 | self._labels = [] 91 | self._spatial_temporal_idx = [] 92 | with open(path_to_file, "r") as f: 93 | for clip_idx, path_label in enumerate(f.read().splitlines()): 94 | assert ( 95 | len(path_label.split(self.cfg.DATA.PATH_LABEL_SEPARATOR)) 96 | == 2 97 | ) 98 | path, label = path_label.split( 99 | self.cfg.DATA.PATH_LABEL_SEPARATOR 100 | ) 101 | for idx in range(self._num_clips): 102 | self._path_to_videos.append( 103 | os.path.join(self.cfg.DATA.PATH_PREFIX, path) 104 | ) 105 | self._labels.append(int(label)) 106 | self._spatial_temporal_idx.append(idx) 107 | self._video_meta[clip_idx * self._num_clips + idx] = {} 108 | assert ( 109 | len(self._path_to_videos) > 0 110 | ), "Failed to load Kinetics split {} from {}".format( 111 | self._split_idx, path_to_file 112 | ) 113 | print( 114 | "Constructing kinetics dataloader (size: {}) from {}".format( 115 | len(self._path_to_videos), path_to_file 116 | ) 117 | ) 118 | 119 | def __getitem__(self, index): 120 | """ 121 | Given the video index, return the list of frames, label, and video 122 | index if the video can be fetched and decoded successfully, otherwise 123 | repeatly find a random video that can be decoded as a replacement. 124 | Args: 125 | index (int): the video index provided by the pytorch sampler. 126 | Returns: 127 | frames (tensor): the frames of sampled from the video. The dimension 128 | is `channel` x `num frames` x `height` x `width`. 129 | label (int): the label of the current video. 130 | index (int): if the video provided by pytorch sampler can be 131 | decoded, then return the index of the video. 
119 |     def __getitem__(self, index):
120 |         """
121 |         Given the video index, return the list of frames, label, and video
122 |         index if the video can be fetched and decoded successfully, otherwise
123 |         repeatedly find a random video that can be decoded as a replacement.
124 |         Args:
125 |             index (int): the video index provided by the pytorch sampler.
126 |         Returns:
127 |             frames (tensor): the frames sampled from the video. The dimension
128 |                 is `channel` x `num frames` x `height` x `width`.
129 |             label (int): the label of the current video.
130 |             index (int): if the video provided by pytorch sampler can be
131 |                 decoded, then return the index of the video. If not, return the
132 |                 index of the video replacement that can be decoded.
133 |         """
134 |         short_cycle_idx = None
135 |         # When short cycle is used, input index is a tuple.
136 |         if isinstance(index, tuple):
137 |             index, short_cycle_idx = index
138 | 
139 |         if self.mode in ["train", "val"]:
140 |             # -1 indicates random sampling.
141 |             temporal_sample_index = -1
142 |             spatial_sample_index = -1
143 |             min_scale = self.cfg.DATA.TRAIN_JITTER_SCALES[0]
144 |             max_scale = self.cfg.DATA.TRAIN_JITTER_SCALES[1]
145 |             crop_size = self.cfg.DATA.TRAIN_CROP_SIZE
146 |             if short_cycle_idx in [0, 1]:
147 |                 crop_size = int(
148 |                     round(
149 |                         self.cfg.MULTIGRID.SHORT_CYCLE_FACTORS[short_cycle_idx]
150 |                         * self.cfg.MULTIGRID.DEFAULT_S
151 |                     )
152 |                 )
153 |             if self.cfg.MULTIGRID.DEFAULT_S > 0:
154 |                 # Decreasing the scale is equivalent to using a larger "span"
155 |                 # in a sampling grid.
156 |                 min_scale = int(
157 |                     round(
158 |                         float(min_scale)
159 |                         * crop_size
160 |                         / self.cfg.MULTIGRID.DEFAULT_S
161 |                     )
162 |                 )
163 |         elif self.mode in ["test"]:
164 |             temporal_sample_index = (
165 |                 self._spatial_temporal_idx[index]
166 |                 // self.cfg.TEST.NUM_SPATIAL_CROPS
167 |             )
168 |             # spatial_sample_index is in [0, 1, 2]. Corresponding to left,
169 |             # center, or right if width is larger than height, and top, middle,
170 |             # or bottom if height is larger than width.
171 |             spatial_sample_index = (
172 |                 (
173 |                     self._spatial_temporal_idx[index]
174 |                     % self.cfg.TEST.NUM_SPATIAL_CROPS
175 |                 )
176 |                 if self.cfg.TEST.NUM_SPATIAL_CROPS > 1
177 |                 else 1
178 |             )
179 |             min_scale, max_scale, crop_size = (
180 |                 [self.cfg.DATA.TEST_CROP_SIZE] * 3
181 |                 if self.cfg.TEST.NUM_SPATIAL_CROPS > 1
182 |                 else [self.cfg.DATA.TRAIN_JITTER_SCALES[0]] * 2
183 |                 + [self.cfg.DATA.TEST_CROP_SIZE]
184 |             )
185 |             # The testing is deterministic and no jitter should be performed.
186 |             # min_scale, max_scale, and crop_size are expected to be the same.
187 |             assert len({min_scale, max_scale}) == 1
188 |         else:
189 |             raise NotImplementedError(
190 |                 "Does not support {} mode".format(self.mode)
191 |             )
192 |         sampling_rate = get_random_sampling_rate(
193 |             self.cfg.MULTIGRID.LONG_CYCLE_SAMPLING_RATE,
194 |             self.cfg.DATA.SAMPLING_RATE,
195 |         )
196 |         # Try to decode and sample a clip from a video. If the video can not be
197 |         # decoded, repeatedly find a random video replacement that can be decoded.
198 |         for i_try in range(self._num_retries):
199 |             video_container = None
200 |             try:
201 |                 video_container = get_video_container(
202 |                     self._path_to_videos[index],
203 |                     self.cfg.DATA_LOADER.ENABLE_MULTI_THREAD_DECODE,
204 |                     self.cfg.DATA.DECODING_BACKEND,
205 |                 )
206 |             except Exception as e:
207 |                 print(
208 |                     "Failed to load video from {} with error {}".format(
209 |                         self._path_to_videos[index], e
210 |                     )
211 |                 )
212 |             # Select a random video if the current video could not be accessed.
213 |             if video_container is None:
214 |                 warnings.warn(
215 |                     "Failed to meta load video idx {} from {}; trial {}".format(
216 |                         index, self._path_to_videos[index], i_try
217 |                     )
218 |                 )
219 |                 if self.mode not in ["test"] and i_try > self._num_retries // 2:
220 |                     # let's try another one
221 |                     index = random.randint(0, len(self._path_to_videos) - 1)
222 |                 continue
223 | 
224 |             # Decode video. Meta info is used to perform selective decoding.
225 |             frames = decode(
226 |                 container=video_container,
227 |                 sampling_rate=sampling_rate,
228 |                 num_frames=self.cfg.DATA.NUM_FRAMES,
229 |                 clip_idx=temporal_sample_index,
230 |                 num_clips=self.cfg.TEST.NUM_ENSEMBLE_VIEWS,
231 |                 video_meta=self._video_meta[index],
232 |                 target_fps=self.cfg.DATA.TARGET_FPS,
233 |                 backend=self.cfg.DATA.DECODING_BACKEND,
234 |                 max_spatial_scale=min_scale,
235 |                 temporal_aug=self.mode == "train" and not self.cfg.DATA.NO_RGB_AUG,
236 |                 two_token=self.cfg.MODEL.TWO_TOKEN,
237 |                 rand_fr=self.cfg.DATA.RAND_FR
238 |             )
239 | 
240 |             # If decoding failed (wrong format, video is too short, etc.),
241 |             # select another video.
242 |             if frames is None:
243 |                 warnings.warn(
244 |                     "Failed to decode video idx {} from {}; trial {}".format(
245 |                         index, self._path_to_videos[index], i_try
246 |                     )
247 |                 )
248 |                 if self.mode not in ["test"] and i_try > self._num_retries // 2:
249 |                     # let's try another one
250 |                     index = random.randint(0, len(self._path_to_videos) - 1)
251 |                 continue
252 | 
253 |             label = self._labels[index]
254 | 
255 |             if self.mode in ["test", "val"] or self.cfg.DATA.NO_RGB_AUG:
256 |                 # Perform color normalization.
257 |                 frames = tensor_normalize(
258 |                     frames, self.cfg.DATA.MEAN, self.cfg.DATA.STD
259 |                 )
260 | 
261 |                 # T H W C -> C T H W.
262 |                 frames = frames.permute(3, 0, 1, 2)
263 | 
264 |                 # Perform data augmentation.
265 |                 frames = spatial_sampling(
266 |                     frames,
267 |                     spatial_idx=spatial_sample_index,
268 |                     min_scale=min_scale,
269 |                     max_scale=max_scale,
270 |                     crop_size=crop_size,
271 |                     random_horizontal_flip=self.cfg.DATA.RANDOM_FLIP,
272 |                     inverse_uniform_sampling=self.cfg.DATA.INV_UNIFORM_SAMPLE,
273 |                 )
274 | 
275 |                 if self.cfg.MODEL.ARCH not in ['vit']:
276 |                     frames = pack_pathway_output(self.cfg, frames)
277 |                 else:
278 |                     # Perform temporal sampling from the fast pathway.
279 |                     frames = torch.index_select(
280 |                         frames,
281 |                         1,
282 |                         torch.linspace(
283 |                             0, frames.shape[1] - 1, self.cfg.DATA.NUM_FRAMES
284 | 
285 |                         ).long(),
286 |                     )
287 | 
288 |             else:
289 |                 # T H W C -> T C H W.
290 |                 frames = [rearrange(x, "t h w c -> t c h w") for x in frames]
291 | 
292 |                 # Perform data augmentation.
293 |                 augmentation = VideoDataAugmentationDINO()
294 |                 frames = augmentation(frames, from_list=True, no_aug=self.cfg.DATA.NO_SPATIAL,
295 |                                       two_token=self.cfg.MODEL.TWO_TOKEN)
296 | 
297 |                 # T C H W -> C T H W.
298 |                 frames = [rearrange(x, "t c h w -> c t h w") for x in frames]
299 | 
300 |                 # Perform temporal sampling from the fast pathway.
301 |                 frames = [torch.index_select(
302 |                     x,
303 |                     1,
304 |                     torch.linspace(
305 |                         0, x.shape[1] - 1, x.shape[1] if self.cfg.DATA.RAND_FR else self.cfg.DATA.NUM_FRAMES
306 | 
307 |                     ).long(),
308 |                 ) for x in frames]
309 | 
310 |             meta_data = {}
311 |             if self.get_flow:
312 |                 assert self.mode == "train", "flow only for train"
313 |                 try:
314 |                     flow_path = self._path_to_videos[index].replace("train_d256", "train_flow")[:-4]
315 |                     flow_tensor = self.get_flow_from_folder(flow_path)
316 |                     flow_tensor = kornia.filters.sobel(flow_tensor)
317 |                     if self.cfg.DATA.NO_FLOW_AUG:
318 |                         flow_tensor = resize(flow_tensor, size=self.cfg.DATA.CROP_SIZE, mode="bicubic")
319 |                         flow_tensor = [x for x in flow_tensor]
320 |                     else:
321 |                         flow_tensor = augmentation(flow_tensor)
322 |                         flow_tensor = [rearrange(x, "t c h w -> c t h w") for x in flow_tensor]
323 |                     meta_data["flow"] = flow_tensor
324 |                 except Exception as e:
325 |                     print(e)
326 |                     continue
327 |             return frames, label, index, meta_data
328 | 
329 |         else:
330 |             raise RuntimeError(
331 |                 "Failed to fetch video after {} retries.".format(
332 |                     self._num_retries
333 |                 )
334 |             )
335 | 
336 |     def __len__(self):
337 |         """
338 |         Returns:
339 |             (int): the number of videos in the dataset.
340 |         """
341 |         return len(self._path_to_videos)
342 | 
343 |     @staticmethod
344 |     def get_flow_from_folder(dir_path):
345 |         flow_image_list = sorted(glob.glob(f"{dir_path}/*.jpg"))
346 |         flow_image_list = [Image.open(im_path) for im_path in flow_image_list]
347 |         flow_image_list = [torchvision.transforms.functional.to_tensor(im_path) for im_path in flow_image_list]
348 |         return torch.stack(flow_image_list, dim=0)
349 | 
350 | 
351 | if __name__ == '__main__':
352 | 
353 |     # import torch
354 |     # from timesformer.datasets import Kinetics
355 |     from utils.parser import parse_args, load_config
356 |     from tqdm import tqdm
357 | 
358 |     args = parse_args()
359 |     args.cfg_file = "/home/kanchanaranasinghe/repo/timesformer/configs/Kinetics/TimeSformer_divST_8x32_224.yaml"
360 |     config = load_config(args)
361 |     config.DATA.PATH_TO_DATA_DIR = "/home/kanchanaranasinghe/data/kinetics400/new_annotations"
362 |     # config.DATA.PATH_TO_DATA_DIR = "/home/kanchanaranasinghe/data/kinetics400/k400-mini"
363 |     config.DATA.PATH_PREFIX = "/home/kanchanaranasinghe/data/kinetics400"
364 |     # dataset = Kinetics(cfg=config, mode="val", num_retries=10)
365 |     dataset = Kinetics(cfg=config, mode="train", num_retries=10, get_flow=True)
366 |     print(f"Loaded train dataset of length: {len(dataset)}")
367 |     dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=4)
368 |     for idx, i in enumerate(dataloader):
369 |         print([x.shape for x in i[0]], i[1:3], [x.shape for x in i[3]['flow']])
370 |         break
371 | 
372 |     do_vis = False
373 |     if do_vis:
374 |         from PIL import Image
375 |         from transform import undo_normalize
376 | 
377 |         vis_path = "/home/kanchanaranasinghe/data/kinetics400/vis/spatial_aug"
378 | 
379 |         for aug_idx in range(len(i[0])):
380 |             temp = i[0][aug_idx][3].permute(1, 2, 3, 0)
381 |             temp = undo_normalize(temp, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
382 |             for idx in range(temp.shape[0]):
383 |                 im = Image.fromarray(temp[idx].numpy())
384 |                 im.resize((224, 224)).save(f"{vis_path}/aug_{aug_idx}_fr_{idx:02d}.jpg")
385 | 
--------------------------------------------------------------------------------
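For reference, a standalone sketch (not from the repository; the 10 temporal views and 3 spatial crops are assumed defaults) of how the flat test-mode index stored in `_spatial_temporal_idx` is split back into a temporal view and a spatial crop inside `__getitem__` above:

num_ensemble_views, num_spatial_crops = 10, 3              # assumed test-time settings

for flat_idx in range(num_ensemble_views * num_spatial_crops):
    temporal_sample_index = flat_idx // num_spatial_crops  # which temporal view of the video
    spatial_sample_index = flat_idx % num_spatial_crops    # 0 / 1 / 2 -> left-or-top, centre, right-or-bottom crop
    assert 0 <= temporal_sample_index < num_ensemble_views
    assert spatial_sample_index in (0, 1, 2)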
/eval_linear.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import argparse
15 | import json
16 | import os
17 | import torch
18 | import torch.backends.cudnn as cudnn
19 | from pathlib import Path
20 | from torch import nn
21 | from tqdm import tqdm
22 | 
23 | from datasets import UCF101, HMDB51, Kinetics
24 | from models import get_vit_base_patch16_224, get_aux_token_vit, SwinTransformer3D
25 | from utils import utils
26 | from utils.meters import TestMeter
27 | from utils.parser import load_config
28 | 
29 | 
30 | def eval_linear(args):
31 |     utils.init_distributed_mode(args)
32 |     print("git:\n {}\n".format(utils.get_sha()))
33 |     print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
34 |     cudnn.benchmark = True
35 |     os.makedirs(args.output_dir, exist_ok=True)
36 |     json.dump(vars(args), open(f"{args.output_dir}/config.json", "w"), indent=4)
37 | 
38 |     # ============ preparing data ... ============
39 |     config = load_config(args)
40 |     # config.DATA.PATH_TO_DATA_DIR = f"{os.path.expanduser('~')}/repo/mmaction2/data/{args.dataset}/splits"
41 |     # config.DATA.PATH_PREFIX = f"{os.path.expanduser('~')}/repo/mmaction2/data/{args.dataset}/videos"
42 |     config.TEST.NUM_SPATIAL_CROPS = 1
43 |     if args.dataset == "ucf101":
44 |         dataset_train = UCF101(cfg=config, mode="train", num_retries=10)
45 |         dataset_val = UCF101(cfg=config, mode="val", num_retries=10)
46 |         config.TEST.NUM_SPATIAL_CROPS = 3
47 |         multi_crop_val = UCF101(cfg=config, mode="val", num_retries=10)
48 |     elif args.dataset == "hmdb51":
49 |         dataset_train = HMDB51(cfg=config, mode="train", num_retries=10)
50 |         dataset_val = HMDB51(cfg=config, mode="val", num_retries=10)
51 |         config.TEST.NUM_SPATIAL_CROPS = 3
52 |         multi_crop_val = HMDB51(cfg=config, mode="val", num_retries=10)
53 |     elif args.dataset == "kinetics400":
54 |         dataset_train = Kinetics(cfg=config, mode="train", num_retries=10)
55 |         dataset_val = Kinetics(cfg=config, mode="val", num_retries=10)
56 |         config.TEST.NUM_SPATIAL_CROPS = 3
57 |         multi_crop_val = Kinetics(cfg=config, mode="val", num_retries=10)
58 |     else:
59 |         raise NotImplementedError(f"invalid dataset: {args.dataset}")
60 | 
61 |     sampler = torch.utils.data.distributed.DistributedSampler(dataset_train)
62 |     train_loader = torch.utils.data.DataLoader(
63 |         dataset_train,
64 |         sampler=sampler,
65 |         batch_size=args.batch_size_per_gpu,
66 |         num_workers=args.num_workers,
67 |         pin_memory=True,
68 |     )
69 |     val_loader = torch.utils.data.DataLoader(
70 |         dataset_val,
71 |         batch_size=args.batch_size_per_gpu,
72 |         num_workers=args.num_workers,
73 |         pin_memory=True,
74 |     )
75 | 
76 |     multi_crop_val_loader = torch.utils.data.DataLoader(
77 |         multi_crop_val,
78 |         batch_size=args.batch_size_per_gpu,
79 |         num_workers=args.num_workers,
80 |         pin_memory=True,
81 |     )
82 | 
83 |     print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val videos.")
84 | 
85 |     # ============ building network ... ============
86 |     if config.DATA.USE_FLOW or config.MODEL.TWO_TOKEN:
87 |         model = get_aux_token_vit(cfg=config, no_head=True)
88 |         model_embed_dim = 2 * model.embed_dim
89 |     else:
90 |         if args.arch == "vit_base":
91 |             model = get_vit_base_patch16_224(cfg=config, no_head=True)
92 |             model_embed_dim = model.embed_dim
93 |         elif args.arch == "swin":
94 |             model = SwinTransformer3D(depths=[2, 2, 18, 2], embed_dim=128, num_heads=[4, 8, 16, 32])
95 |             model_embed_dim = 1024
96 |         else:
97 |             raise Exception(f"invalid model: {args.arch}")
98 | 
99 |     ckpt = torch.load(args.pretrained_weights)
100 |     # select_ckpt = 'motion_teacher' if args.use_flow else "teacher"
101 |     if "teacher" in ckpt:
102 |         ckpt = ckpt["teacher"]
103 |     renamed_checkpoint = {x[len("backbone."):]: y for x, y in ckpt.items() if x.startswith("backbone.")}
104 |     msg = model.load_state_dict(renamed_checkpoint, strict=False)
105 |     print(f"Loaded model with msg: {msg}")
106 |     model.cuda()
107 |     model.eval()
108 |     print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
109 |     # load weights to evaluate
110 | 
111 |     linear_classifier = LinearClassifier(model_embed_dim * (args.n_last_blocks + int(args.avgpool_patchtokens)),
112 |                                          num_labels=args.num_labels)
113 |     linear_classifier = linear_classifier.cuda()
114 |     linear_classifier = nn.parallel.DistributedDataParallel(linear_classifier, device_ids=[args.gpu])
115 | 
116 |     if args.lc_pretrained_weights:
117 |         lc_ckpt = torch.load(args.lc_pretrained_weights)
118 |         msg = linear_classifier.load_state_dict(lc_ckpt['state_dict'])
119 |         print(f"Loaded linear classifier weights with msg: {msg}")
120 |         test_stats = validate_network_multi_view(multi_crop_val_loader, model, linear_classifier, args.n_last_blocks,
121 |                                                  args.avgpool_patchtokens, config)
122 |         # test_stats = validate_network(val_loader, model, linear_classifier, args.n_last_blocks, args.avgpool_patchtokens)
123 |         print(test_stats)
124 |         return True
125 | 
126 | 
127 |     # set optimizer
128 |     optimizer = torch.optim.SGD(
129 |         linear_classifier.parameters(),
130 |         args.lr * (args.batch_size_per_gpu * utils.get_world_size()) / 256.,  # linear scaling rule
131 |         momentum=0.9,
132 |         weight_decay=0,  # we do not apply weight decay
133 |     )
134 |     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=0)
135 | 
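As an aside, the "linear scaling rule" applied to the SGD learning rate above works out as follows for the script defaults on a hypothetical 4-GPU run (the world size is the only assumption):

base_lr = 0.001            # script default for --lr
batch_size_per_gpu = 128   # script default for --batch_size_per_gpu
world_size = 4             # hypothetical number of GPUs / processes

effective_lr = base_lr * (batch_size_per_gpu * world_size) / 256.0
print(effective_lr)        # 0.002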
136 |     # Optionally resume from a checkpoint
137 |     to_restore = {"epoch": 0, "best_acc": 0.}
138 |     utils.restart_from_checkpoint(
139 |         os.path.join(args.output_dir, "checkpoint.pth.tar"),
140 |         run_variables=to_restore,
141 |         state_dict=linear_classifier,
142 |         optimizer=optimizer,
143 |         scheduler=scheduler,
144 |     )
145 |     start_epoch = to_restore["epoch"]
146 |     best_acc = to_restore["best_acc"]
147 | 
148 |     for epoch in range(start_epoch, args.epochs):
149 |         train_loader.sampler.set_epoch(epoch)
150 | 
151 |         train_stats = train(model, linear_classifier, optimizer, train_loader, epoch, args.n_last_blocks, args.avgpool_patchtokens)
152 |         scheduler.step()
153 | 
154 |         log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
155 |                      'epoch': epoch}
156 |         if epoch % args.val_freq == 0 or epoch == args.epochs - 1:
157 |             test_stats = validate_network(val_loader, model, linear_classifier, args.n_last_blocks, args.avgpool_patchtokens)
158 |             print(f"Accuracy at epoch {epoch} of the network on the {len(dataset_val)} test videos: {test_stats['acc1']:.1f}%")
159 |             best_acc = max(best_acc, test_stats["acc1"])
160 |             print(f'Max accuracy so far: {best_acc:.2f}%')
161 |             log_stats = {**{k: v for k, v in log_stats.items()},
162 |                          **{f'test_{k}': v for k, v in test_stats.items()}}
163 |         if utils.is_main_process():
164 |             with (Path(args.output_dir) / "log.txt").open("a") as f:
165 |                 f.write(json.dumps(log_stats) + "\n")
166 |             save_dict = {
167 |                 "epoch": epoch + 1,
168 |                 "state_dict": linear_classifier.state_dict(),
169 |                 "optimizer": optimizer.state_dict(),
170 |                 "scheduler": scheduler.state_dict(),
171 |                 "best_acc": best_acc,
172 |             }
173 |             torch.save(save_dict, os.path.join(args.output_dir, "checkpoint.pth.tar"))
174 | 
175 |     test_stats = validate_network_multi_view(multi_crop_val_loader, model, linear_classifier, args.n_last_blocks,
176 |                                              args.avgpool_patchtokens, config)
177 |     print(test_stats)
178 | 
179 |     print("Training of the supervised linear classifier on frozen features completed.\n"
180 |           "Top-1 test accuracy: {acc:.1f}".format(acc=best_acc))
181 | 
182 | 
183 | def train(model, linear_classifier, optimizer, loader, epoch, n, avgpool):
184 |     linear_classifier.train()
185 |     metric_logger = utils.MetricLogger(delimiter=" ")
186 |     metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
187 |     header = 'Epoch: [{}]'.format(epoch)
188 |     for (inp, target, sample_idx, meta) in metric_logger.log_every(loader, 20, header):
189 |         # move to gpu
190 |         inp = inp.cuda(non_blocking=True)
191 |         target = target.cuda(non_blocking=True)
192 | 
193 |         # forward
194 |         with torch.no_grad():
195 |             # intermediate_output = model.get_intermediate_layers(inp, n)
196 |             # output = [x[:, 0] for x in intermediate_output]
197 |             # if avgpool:
198 |             #     output.append(torch.mean(intermediate_output[-1][:, 1:], dim=1))
199 |             # output = torch.cat(output, dim=-1)
200 | 
201 |             output = model(inp)
202 | 
203 |         output = linear_classifier(output)
204 | 
205 |         # compute cross entropy loss
206 |         loss = nn.CrossEntropyLoss()(output, target)
207 | 
208 |         # compute the gradients
209 |         optimizer.zero_grad()
210 |         loss.backward()
211 | 
212 |         # step
213 |         optimizer.step()
214 | 
215 |         # log
216 |         torch.cuda.synchronize()
217 |         metric_logger.update(loss=loss.item())
218 |         metric_logger.update(lr=optimizer.param_groups[0]["lr"])
219 |     # gather the stats from all processes
220 |     metric_logger.synchronize_between_processes()
221 |     print("Averaged stats:", metric_logger)
222 |     return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
223 | 
224 | 
225 | @torch.no_grad()
226 | def validate_network(val_loader, model, linear_classifier, n, avgpool):
227 |     linear_classifier.eval()
228 |     metric_logger = utils.MetricLogger(delimiter=" ")
229 |     header = 'Test:'
230 |     for (inp, target, sample_idx, meta) in metric_logger.log_every(val_loader, 20, header):
231 |         # move to gpu
232 |         inp = inp.cuda(non_blocking=True)
233 |         target = target.cuda(non_blocking=True)
234 | 
235 |         # forward
236 |         with torch.no_grad():
237 |             # intermediate_output = model.get_intermediate_layers(inp, n)
238 |             # output = [x[:, 0] for x in intermediate_output]
239 |             # if avgpool:
240 |             #     output.append(torch.mean(intermediate_output[-1][:, 1:], dim=1))
241 |             # output = torch.cat(output, dim=-1)
242 |             output = model(inp)
243 |         output = linear_classifier(output)
244 |         loss = nn.CrossEntropyLoss()(output, target)
245 | 
246 |         if linear_classifier.module.num_labels >= 5:
247 |             acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
248 |         else:
249 |             acc1, = utils.accuracy(output, target, topk=(1,))
250 | 
251 |         batch_size = inp.shape[0]
252 |         metric_logger.update(loss=loss.item())
253 |         metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
254 |         if linear_classifier.module.num_labels >= 5:
255 |             metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
256 |     if linear_classifier.module.num_labels >= 5:
257 |         print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
258 |               .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
259 |     else:
260 |         print('* Acc@1 {top1.global_avg:.3f} loss {losses.global_avg:.3f}'
261 |               .format(top1=metric_logger.acc1, losses=metric_logger.loss))
262 |     return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
263 | 
264 | 
265 | @torch.no_grad()
266 | def validate_network_multi_view(val_loader, model, linear_classifier, n, avgpool, cfg):
267 |     linear_classifier.eval()
268 |     test_meter = TestMeter(
269 |         len(val_loader.dataset)
270 |         // (cfg.TEST.NUM_ENSEMBLE_VIEWS * cfg.TEST.NUM_SPATIAL_CROPS),
271 |         cfg.TEST.NUM_ENSEMBLE_VIEWS * cfg.TEST.NUM_SPATIAL_CROPS,
272 |         args.num_labels,
273 |         len(val_loader),
274 |         cfg.DATA.MULTI_LABEL,
275 |         cfg.DATA.ENSEMBLE_METHOD,
276 |     )
277 |     test_meter.iter_tic()
278 | 
279 |     for cur_iter, (inp, target, sample_idx, meta) in tqdm(enumerate(val_loader), total=len(val_loader)):
280 |         # move to gpu
281 |         inp = inp.cuda(non_blocking=True)
282 |         # target = target.cuda(non_blocking=True)
283 |         test_meter.data_toc()
284 | 
285 |         # forward
286 |         with torch.no_grad():
287 |             output = model(inp)
288 |             output = linear_classifier(output)
289 | 
290 |         output = output.cpu()
291 |         target = target.cpu()
292 |         sample_idx = sample_idx.cpu()
293 | 
294 |         test_meter.iter_toc()
295 |         # Update and log stats.
296 |         test_meter.update_stats(
297 |             output.detach(), target.detach(), sample_idx.detach()
298 |         )
299 |         test_meter.log_iter_stats(cur_iter)
300 | 
301 |         test_meter.iter_tic()
302 | 
303 |     test_meter.finalize_metrics()
304 |     return test_meter.stats
305 | 
306 | 
307 | class LinearClassifier(nn.Module):
308 |     """Linear layer to train on top of frozen features"""
309 |     def __init__(self, dim, num_labels=1000):
310 |         super(LinearClassifier, self).__init__()
311 |         self.num_labels = num_labels
312 |         self.linear = nn.Linear(dim, num_labels)
313 |         self.linear.weight.data.normal_(mean=0.0, std=0.01)
314 |         self.linear.bias.data.zero_()
315 | 
316 |     def forward(self, x):
317 |         # flatten
318 |         x = x.view(x.size(0), -1)
319 | 
320 |         # linear layer
321 |         return self.linear(x)
322 | 
323 | 
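A minimal usage sketch for the `LinearClassifier` defined above (the ViT-Base embedding width of 768 and the UCF101 label count of 101 are assumptions, mirroring how the script builds the head with `n_last_blocks` and `avgpool_patchtokens`):

import torch

embed_dim, n_last_blocks, avgpool = 768, 1, 0          # ViT-Base width; no avg-pooled patch tokens
clf = LinearClassifier(embed_dim * (n_last_blocks + avgpool), num_labels=101)

features = torch.randn(4, embed_dim)                   # frozen backbone output for a batch of 4 clips
logits = clf(features)
print(logits.shape)                                    # torch.Size([4, 101])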
324 | if __name__ == '__main__':
325 |     parser = argparse.ArgumentParser('Evaluation with linear classification on video datasets')
326 |     parser.add_argument('--n_last_blocks', default=4, type=int, help="""Concatenate [CLS] tokens
327 |         for the `n` last blocks. We use `n=4` when evaluating ViT-Small and `n=1` with ViT-Base.""")
328 |     parser.add_argument('--avgpool_patchtokens', default=False, type=utils.bool_flag,
329 |                         help="""Whether or not to concatenate the global average pooled features to the [CLS] token.
330 |         We typically set this to False for ViT-Small and to True with ViT-Base.""")
331 |     parser.add_argument('--arch', default='vit_small', type=str,
332 |                         choices=['vit_tiny', 'vit_small', 'vit_base', 'swin'],
333 |                         help='Architecture (only vit_base and swin are currently supported).')
334 |     parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
335 |     parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
336 |     parser.add_argument('--lc_pretrained_weights', default='', type=str, help="Path to pretrained linear classifier weights to evaluate.")
337 |     parser.add_argument("--checkpoint_key", default="teacher", type=str, help='Key to use in the checkpoint (example: "teacher")')
338 |     parser.add_argument('--epochs', default=100, type=int, help='Number of epochs of training.')
339 |     parser.add_argument("--lr", default=0.001, type=float, help="""Learning rate at the beginning of
340 |         training (highest LR used during training). The learning rate is linearly scaled
341 |         with the batch size, and specified here for a reference batch size of 256.
342 |         We recommend tweaking the LR depending on the checkpoint evaluated.""")
343 |     parser.add_argument('--batch_size_per_gpu', default=128, type=int, help='Per-GPU batch-size')
344 |     parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
345 |         distributed training; see https://pytorch.org/docs/stable/distributed.html""")
346 |     parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
347 |     parser.add_argument('--data_path', default='/path/to/imagenet/', type=str)
348 |     parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
349 |     parser.add_argument('--val_freq', default=1, type=int, help="Epoch frequency for validation.")
350 |     parser.add_argument('--output_dir', default=".", help='Path to save logs and checkpoints')
351 |     parser.add_argument('--num_labels', default=1000, type=int, help='Number of labels for linear classifier')
352 |     parser.add_argument('--dataset', default="ucf101", help='Dataset: ucf101 / hmdb51 / kinetics400')
353 |     parser.add_argument('--use_flow', default=False, type=utils.bool_flag, help="use flow teacher")
354 | 
355 |     # config file
356 |     parser.add_argument("--cfg", dest="cfg_file", help="Path to the config file", type=str,
357 |                         default="models/configs/Kinetics/TimeSformer_divST_8x32_224.yaml")
358 |     parser.add_argument("--opts", help="See utils/defaults.py for all options", default=None, nargs=argparse.REMAINDER)
359 | 
360 |     args = parser.parse_args()
361 |     eval_linear(args)
362 | 
--------------------------------------------------------------------------------