├── .github
│   └── intro.png
├── models
│   ├── __init__.py
│   ├── configs
│   │   └── Kinetics
│   │       ├── TimeSformer_spaceOnly_8x32_224.yaml
│   │       ├── TimeSformer_divST_16x16_448.yaml
│   │       ├── TimeSformer_divST_8x32_224.yaml
│   │       ├── TimeSformer_divST_96x4_224.yaml
│   │       ├── TimeSformer_jointST_8x32_224.yaml
│   │       ├── TimeSformer_divST_8x32_224_4gpus.yaml
│   │       ├── TimeSformer_divST_8x32_224_TEST.yaml
│   │       ├── SLOWFAST_4x16_R50.yaml
│   │       ├── SLOWFAST_8x8_R50.yaml
│   │       └── SLOWFAST_8x8_R101.yaml
│   ├── vit_utils.py
│   ├── s3d.py
│   └── helpers.py
├── datasets
│   ├── __init__.py
│   ├── preprocessing
│   │   ├── kinetics_mini.py
│   │   ├── verify_file_list.py
│   │   ├── create_lists.py
│   │   ├── check_corrupt_videos.py
│   │   ├── downsample_kinetics.py
│   │   ├── resize_videos.py
│   │   └── flow_vis.py
│   ├── video_container.py
│   ├── build.py
│   ├── DATASET.md
│   ├── rand_conv.py
│   ├── multigrid_helper.py
│   ├── loader.py
│   ├── ssv2.py
│   ├── ucf101.py
│   ├── hmdb51.py
│   ├── data_utils.py
│   └── kinetics.py
├── requirements.txt
├── scripts
│   ├── eval_knn.sh
│   ├── train.sh
│   └── eval_linear.sh
├── utils
│   ├── logging.py
│   ├── parser.py
│   ├── metrics.py
│   └── meters.py
├── LICENSE
├── .gitignore
├── README.md
├── eval_knn.py
├── vision_transformer.py
└── eval_linear.py
/.github/intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kahnchana/svt/HEAD/.github/intro.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .timesformer import get_vit_base_patch16_224, get_aux_token_vit
2 | from .swin_transformer import SwinTransformer3D
3 | from .s3d import S3D
4 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 |
3 | from .kinetics import Kinetics # noqa
4 | from .ucf101 import UCF101
5 | from .hmdb51 import HMDB51
6 | # from .ssv2 import Ssv2 # noqa
7 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.18.4
2 | torch>=1.7.1
3 | torchvision>=0.2.1
4 | pillow>=5.4.1
5 | fvcore>=0.1.5
6 | scikit-learn>=0.23.2
7 | simplejson>=3.17.5
8 | einops>=0.3.0
9 | timm>=0.4.12
10 | tqdm>=4.29.1
11 | kornia>=0.5.8
12 | opencv-python>=3.4.3.18
13 | pandas>=0.23.4
14 | joblib>=0.16.0
15 | matplotlib>=3.1.1
16 | requests>=2.25.1
17 | scikit-image>=0.15.0
18 | av>=9.2.0
--------------------------------------------------------------------------------
/scripts/eval_knn.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | PROJECT_PATH="$HOME/repo/svt"
4 | CHECKPOINT="path/to/checkpoint.pth"
5 | DATASET="ucf101"
6 | DATA_PATH="${HOME}/repo/mmaction2/data/${DATASET}"
7 |
8 | cd "$PROJECT_PATH" || exit
9 |
10 | export CUDA_VISIBLE_DEVICES=0
11 | python -m torch.distributed.launch \
12 | --nproc_per_node=1 \
13 | --master_port="$RANDOM" \
14 | eval_knn.py \
15 | --arch "vit_base" \
16 | --pretrained_weights "$CHECKPOINT" \
17 | --batch_size_per_gpu 128 \
18 | --nb_knn 5 \
19 | --temperature 0.07 \
20 | --num_workers 4 \
21 | --dataset "$DATASET" \
22 | --opts \
23 | DATA.PATH_TO_DATA_DIR "${DATA_PATH}/knn_splits" \
24 |   DATA.PATH_PREFIX "${DATA_PATH}/videos"
25 |
--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | PROJECT_PATH="$HOME/repo/svt"
4 | DATA_PATH="$HOME/data/kinetics/400/annotations"
5 | EXP_NAME="svt_test"
6 |
7 | cd "$PROJECT_PATH" || exit
8 |
9 | if [ ! -d "checkpoints/$EXP_NAME" ]; then
10 | mkdir "checkpoints/$EXP_NAME"
11 | fi
12 |
13 | export CUDA_VISIBLE_DEVICES=0,1,2,3
14 |
15 | python -m torch.distributed.launch \
16 | --nproc_per_node=4 \
17 | --master_port="$RANDOM" \
18 | train_ssl.py \
19 | --arch "timesformer" \
20 | --batch_size_per_gpu 8 \
21 | --data_path "${DATA_PATH}" \
22 | --output_dir "checkpoints/$EXP_NAME" \
23 | --opts \
24 | MODEL.TWO_STREAM False \
25 | MODEL.TWO_TOKEN False \
26 | DATA.NO_FLOW_AUG False \
27 | DATA.USE_FLOW False \
28 | DATA.RAND_CONV False \
29 | DATA.NO_SPATIAL False
30 |
31 |
--------------------------------------------------------------------------------
/scripts/eval_linear.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | PROJECT_PATH="$HOME/repo/svt"
4 | EXP_NAME="le_001"
5 | DATASET="ucf101"
6 | DATA_PATH="${HOME}/repo/mmaction2/data/${DATASET}"
7 | CHECKPOINT="path/to/checkpoint.pth"
8 |
9 | cd "$PROJECT_PATH" || exit
10 |
11 | if [ ! -d "checkpoints/$EXP_NAME" ]; then
12 | mkdir "checkpoints/$EXP_NAME"
13 | fi
14 |
15 | export CUDA_VISIBLE_DEVICES=0
16 | python -m torch.distributed.launch \
17 | --nproc_per_node=1 \
18 | --master_port="$RANDOM" \
19 | eval_linear.py \
20 | --n_last_blocks 1 \
21 | --arch "vit_base" \
22 | --pretrained_weights "$CHECKPOINT" \
23 | --epochs 20 \
24 | --lr 0.001 \
25 | --batch_size_per_gpu 16 \
26 | --num_workers 4 \
27 | --num_labels 101 \
28 | --dataset "$DATASET" \
29 | --output_dir "checkpoints/eval/$EXP_NAME" \
30 | --opts \
31 | DATA.PATH_TO_DATA_DIR "${DATA_PATH}/splits" \
32 |   DATA.PATH_PREFIX "${DATA_PATH}/videos" \
33 | DATA.USE_FLOW False
34 |
--------------------------------------------------------------------------------
/utils/logging.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 |
3 | """Logging."""
4 |
5 | import decimal
6 |
7 | import simplejson
8 |
9 | import logging
10 |
11 |
12 | def get_logger(name):
13 | """
14 | Retrieve the logger with the specified name or, if name is None, return a
15 | logger which is the root logger of the hierarchy.
16 | Args:
17 | name (string): name of the logger.
18 | """
19 | return logging.getLogger(name)
20 |
21 |
22 | def log_json_stats(stats):
23 | """
24 | Logs json stats.
25 | Args:
26 | stats (dict): a dictionary of statistical information to log.
27 | """
28 | stats = {
29 | k: decimal.Decimal("{:.5f}".format(v)) if isinstance(v, float) else v
30 | for k, v in stats.items()
31 | }
32 | json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True)
33 | logger = get_logger(__name__)
34 | logger.info("json_stats: {:s}".format(json_stats))
35 |
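36 | # Illustrative usage: log_json_stats({"epoch": 1, "loss": 0.25}) logs
37 | # 'json_stats: {"epoch": 1, "loss": 0.25000}' through this module's logger.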
--------------------------------------------------------------------------------
/datasets/preprocessing/kinetics_mini.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 |
3 | # root_path = "/home/kanchanaranasinghe/data/kinetics400"
4 | # anno_folder = "k400-mini"
5 | #
6 | # for split in ["train", "val",]:
7 | #
8 | # save_path = f"{root_path}/{anno_folder}/{split}.csv"
9 | # file_list = pd.read_csv(save_path, sep=" ")
10 | #
11 | # if split == "train":
12 | # file_list = file_list.sample(n=60000, replace=False)
13 | # new_save_path = f"{root_path}/{anno_folder}/{split}_60k.csv"
14 | # file_list.to_csv(new_save_path, index=False, header=False, sep=" ")
15 | # elif split == "val":
16 | # file_list = file_list.sample(n=10000, replace=False)
17 | # new_save_path = f"{root_path}/{anno_folder}/{split}_10k.csv"
18 | # file_list.to_csv(new_save_path, index=False, header=False, sep=" ")
19 |
20 |
21 | file_path = "/home/kanchanaranasinghe/data/kinetics400/k400-mini/train_60k.csv"
22 | df = pd.read_csv(file_path, header=None, sep=" ")
23 | print(f"{len(df[1].unique())} unique classes")
24 | df[1].hist(bins=len(df[1].unique()))
25 |
--------------------------------------------------------------------------------
/models/configs/Kinetics/TimeSformer_spaceOnly_8x32_224.yaml:
--------------------------------------------------------------------------------
1 | TRAIN:
2 | ENABLE: True
3 | DATASET: kinetics
4 | BATCH_SIZE: 8
5 | EVAL_PERIOD: 5
6 | CHECKPOINT_PERIOD: 5
7 | AUTO_RESUME: True
8 | DATA:
9 | PATH_TO_DATA_DIR: /path/to/kinetics/
10 | NUM_FRAMES: 8
11 | SAMPLING_RATE: 32
12 | TRAIN_JITTER_SCALES: [256, 320]
13 | TRAIN_CROP_SIZE: 224
14 | TEST_CROP_SIZE: 224
15 | INPUT_CHANNEL_NUM: [3]
16 | TIMESFORMER:
17 | ATTENTION_TYPE: 'space_only'
18 | SOLVER:
19 | BASE_LR: 0.005
20 | LR_POLICY: steps_with_relative_lrs
21 | STEPS: [0, 11, 14]
22 | LRS: [1, 0.1, 0.01]
23 | MAX_EPOCH: 15
24 | MOMENTUM: 0.9
25 | WEIGHT_DECAY: 1e-4
26 | OPTIMIZING_METHOD: sgd
27 | MODEL:
28 | MODEL_NAME: vit_base_patch16_224
29 | NUM_CLASSES: 400
30 | ARCH: vit
31 | LOSS_FUNC: cross_entropy
32 | DROPOUT_RATE: 0.5
33 | TEST:
34 | ENABLE: True
35 | DATASET: kinetics
36 | BATCH_SIZE: 8
37 | NUM_ENSEMBLE_VIEWS: 1
38 | NUM_SPATIAL_CROPS: 3
39 | DATA_LOADER:
40 | NUM_WORKERS: 8
41 | PIN_MEMORY: True
42 | NUM_GPUS: 8
43 | NUM_SHARDS: 1
44 | RNG_SEED: 0
45 | OUTPUT_DIR: .
46 |
--------------------------------------------------------------------------------
/models/configs/Kinetics/TimeSformer_divST_16x16_448.yaml:
--------------------------------------------------------------------------------
1 | TRAIN:
2 | ENABLE: True
3 | DATASET: kinetics
4 | BATCH_SIZE: 8
5 | EVAL_PERIOD: 5
6 | CHECKPOINT_PERIOD: 5
7 | AUTO_RESUME: True
8 | DATA:
9 | PATH_TO_DATA_DIR: /path/to/kinetics/
10 | NUM_FRAMES: 16
11 | SAMPLING_RATE: 16
12 | TRAIN_JITTER_SCALES: [448, 512]
13 | TRAIN_CROP_SIZE: 448
14 | TEST_CROP_SIZE: 448
15 | INPUT_CHANNEL_NUM: [3]
16 | TIMESFORMER:
17 | ATTENTION_TYPE: 'divided_space_time'
18 | SOLVER:
19 | BASE_LR: 0.005
20 | LR_POLICY: steps_with_relative_lrs
21 | STEPS: [0, 11, 14]
22 | LRS: [1, 0.1, 0.01]
23 | MAX_EPOCH: 15
24 | MOMENTUM: 0.9
25 | WEIGHT_DECAY: 1e-4
26 | OPTIMIZING_METHOD: sgd
27 | MODEL:
28 | MODEL_NAME: vit_base_patch16_224
29 | NUM_CLASSES: 400
30 | ARCH: vit
31 | LOSS_FUNC: cross_entropy
32 | DROPOUT_RATE: 0.5
33 | TEST:
34 | ENABLE: True
35 | DATASET: kinetics
36 | BATCH_SIZE: 8
37 | NUM_ENSEMBLE_VIEWS: 1
38 | NUM_SPATIAL_CROPS: 3
39 | DATA_LOADER:
40 | NUM_WORKERS: 8
41 | PIN_MEMORY: True
42 | NUM_GPUS: 8
43 | NUM_SHARDS: 1
44 | RNG_SEED: 0
45 | OUTPUT_DIR: .
46 |
--------------------------------------------------------------------------------
/models/configs/Kinetics/TimeSformer_divST_8x32_224.yaml:
--------------------------------------------------------------------------------
1 | TRAIN:
2 | ENABLE: True
3 | DATASET: kinetics
4 | BATCH_SIZE: 8
5 | EVAL_PERIOD: 5
6 | CHECKPOINT_PERIOD: 5
7 | AUTO_RESUME: True
8 | DATA:
9 | PATH_TO_DATA_DIR: /path/to/kinetics/
10 | NUM_FRAMES: 8
11 | SAMPLING_RATE: 32
12 | TRAIN_JITTER_SCALES: [256, 320]
13 | TRAIN_CROP_SIZE: 224
14 | TEST_CROP_SIZE: 224
15 | INPUT_CHANNEL_NUM: [3]
16 | TIMESFORMER:
17 | ATTENTION_TYPE: 'divided_space_time'
18 | SOLVER:
19 | BASE_LR: 0.005
20 | LR_POLICY: steps_with_relative_lrs
21 | STEPS: [0, 11, 14]
22 | LRS: [1, 0.1, 0.01]
23 | MAX_EPOCH: 15
24 | MOMENTUM: 0.9
25 | WEIGHT_DECAY: 1e-4
26 | OPTIMIZING_METHOD: sgd
27 | MODEL:
28 | MODEL_NAME: vit_base_patch16_224
29 | NUM_CLASSES: 400
30 | ARCH: vit
31 | LOSS_FUNC: cross_entropy
32 | DROPOUT_RATE: 0.5
33 | TEST:
34 | ENABLE: True
35 | DATASET: kinetics
36 | BATCH_SIZE: 8
37 | NUM_ENSEMBLE_VIEWS: 1
38 | NUM_SPATIAL_CROPS: 3
39 | DATA_LOADER:
40 | NUM_WORKERS: 8
41 | PIN_MEMORY: True
42 | NUM_GPUS: 8
43 | NUM_SHARDS: 1
44 | RNG_SEED: 0
45 | OUTPUT_DIR: .
46 |
--------------------------------------------------------------------------------
/models/configs/Kinetics/TimeSformer_divST_96x4_224.yaml:
--------------------------------------------------------------------------------
1 | TRAIN:
2 | ENABLE: True
3 | DATASET: kinetics
4 | BATCH_SIZE: 8
5 | EVAL_PERIOD: 5
6 | CHECKPOINT_PERIOD: 5
7 | AUTO_RESUME: True
8 | DATA:
9 | PATH_TO_DATA_DIR: /path/to/kinetics/
10 | NUM_FRAMES: 96
11 | SAMPLING_RATE: 4
12 | TRAIN_JITTER_SCALES: [256, 320]
13 | TRAIN_CROP_SIZE: 224
14 | TEST_CROP_SIZE: 224
15 | INPUT_CHANNEL_NUM: [3]
16 | TIMESFORMER:
17 | ATTENTION_TYPE: 'divided_space_time'
18 | SOLVER:
19 | BASE_LR: 0.005
20 | LR_POLICY: steps_with_relative_lrs
21 | STEPS: [0, 11, 14]
22 | LRS: [1, 0.1, 0.01]
23 | MAX_EPOCH: 15
24 | MOMENTUM: 0.9
25 | WEIGHT_DECAY: 1e-4
26 | OPTIMIZING_METHOD: sgd
27 | MODEL:
28 | MODEL_NAME: vit_base_patch16_224
29 | NUM_CLASSES: 400
30 | ARCH: vit
31 | LOSS_FUNC: cross_entropy
32 | DROPOUT_RATE: 0.5
33 | TEST:
34 | ENABLE: True
35 | DATASET: kinetics
36 | BATCH_SIZE: 8
37 | NUM_ENSEMBLE_VIEWS: 1
38 | NUM_SPATIAL_CROPS: 3
39 | DATA_LOADER:
40 | NUM_WORKERS: 8
41 | PIN_MEMORY: True
42 | NUM_GPUS: 8
43 | NUM_SHARDS: 1
44 | RNG_SEED: 0
45 | OUTPUT_DIR: .
46 |
--------------------------------------------------------------------------------
/models/configs/Kinetics/TimeSformer_jointST_8x32_224.yaml:
--------------------------------------------------------------------------------
1 | TRAIN:
2 | ENABLE: True
3 | DATASET: kinetics
4 | BATCH_SIZE: 8
5 | EVAL_PERIOD: 5
6 | CHECKPOINT_PERIOD: 5
7 | AUTO_RESUME: True
8 | DATA:
9 | PATH_TO_DATA_DIR: /path/to/kinetics/
10 | NUM_FRAMES: 8
11 | SAMPLING_RATE: 32
12 | TRAIN_JITTER_SCALES: [256, 320]
13 | TRAIN_CROP_SIZE: 224
14 | TEST_CROP_SIZE: 224
15 | INPUT_CHANNEL_NUM: [3]
16 | TIMESFORMER:
17 | ATTENTION_TYPE: 'joint_space_time'
18 | SOLVER:
19 | BASE_LR: 0.005
20 | LR_POLICY: steps_with_relative_lrs
21 | STEPS: [0, 11, 14]
22 | LRS: [1, 0.1, 0.01]
23 | MAX_EPOCH: 15
24 | MOMENTUM: 0.9
25 | WEIGHT_DECAY: 1e-4
26 | OPTIMIZING_METHOD: sgd
27 | MODEL:
28 | MODEL_NAME: vit_base_patch16_224
29 | NUM_CLASSES: 400
30 | ARCH: vit
31 | LOSS_FUNC: cross_entropy
32 | DROPOUT_RATE: 0.5
33 | TEST:
34 | ENABLE: True
35 | DATASET: kinetics
36 | BATCH_SIZE: 8
37 | NUM_ENSEMBLE_VIEWS: 1
38 | NUM_SPATIAL_CROPS: 3
39 | DATA_LOADER:
40 | NUM_WORKERS: 8
41 | PIN_MEMORY: True
42 | NUM_GPUS: 8
43 | NUM_SHARDS: 1
44 | RNG_SEED: 0
45 | OUTPUT_DIR: .
46 |
--------------------------------------------------------------------------------
/models/configs/Kinetics/TimeSformer_divST_8x32_224_4gpus.yaml:
--------------------------------------------------------------------------------
1 | TRAIN:
2 | ENABLE: True
3 | DATASET: kinetics
4 | BATCH_SIZE: 4
5 | EVAL_PERIOD: 5
6 | CHECKPOINT_PERIOD: 5
7 | AUTO_RESUME: True
8 | DATA:
9 | PATH_TO_DATA_DIR: /path/to/kinetics/
10 | NUM_FRAMES: 8
11 | SAMPLING_RATE: 32
12 | TRAIN_JITTER_SCALES: [256, 320]
13 | TRAIN_CROP_SIZE: 224
14 | TEST_CROP_SIZE: 224
15 | INPUT_CHANNEL_NUM: [3]
16 | TIMESFORMER:
17 | ATTENTION_TYPE: 'divided_space_time'
18 | SOLVER:
19 | BASE_LR: 0.005
20 | LR_POLICY: steps_with_relative_lrs
21 | STEPS: [0, 11, 14]
22 | LRS: [1, 0.1, 0.01]
23 | MAX_EPOCH: 15
24 | MOMENTUM: 0.9
25 | WEIGHT_DECAY: 1e-4
26 | OPTIMIZING_METHOD: sgd
27 | MODEL:
28 | MODEL_NAME: vit_base_patch16_224
29 | NUM_CLASSES: 400
30 | ARCH: vit
31 | LOSS_FUNC: cross_entropy
32 | DROPOUT_RATE: 0.5
33 | TEST:
34 | ENABLE: True
35 | DATASET: kinetics
36 | BATCH_SIZE: 4
37 | NUM_ENSEMBLE_VIEWS: 1
38 | NUM_SPATIAL_CROPS: 3
39 | DATA_LOADER:
40 | NUM_WORKERS: 4
41 | PIN_MEMORY: True
42 | NUM_GPUS: 4
43 | NUM_SHARDS: 1
44 | RNG_SEED: 0
45 | OUTPUT_DIR: .
46 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Kanchana Ranasinghe
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/datasets/video_container.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 |
3 | import av
4 |
5 |
6 | def get_video_container(path_to_vid, multi_thread_decode=False, backend="pyav"):
7 | """
8 | Given the path to the video, return the pyav video container.
9 | Args:
10 | path_to_vid (str): path to the video.
11 | multi_thread_decode (bool): if True, perform multi-thread decoding.
12 | backend (str): decoder backend, options include `pyav` and
13 | `torchvision`, default is `pyav`.
14 | Returns:
15 | container (container): video container.
16 | """
17 | if backend == "torchvision":
18 | with open(path_to_vid, "rb") as fp:
19 | container = fp.read()
20 | return container
21 | elif backend == "pyav":
22 | #try:
23 | container = av.open(path_to_vid)
24 | if multi_thread_decode:
25 | # Enable multiple threads for decoding.
26 | container.streams.video[0].thread_type = "AUTO"
27 | #except:
28 | # container = None
29 | return container
30 | else:
31 | raise NotImplementedError("Unknown backend {}".format(backend))
32 |
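33 | # Illustrative usage (the path is hypothetical):
34 | #   container = get_video_container("/path/to/video.mp4", multi_thread_decode=True, backend="pyav")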
--------------------------------------------------------------------------------
/datasets/build.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 |
3 | from fvcore.common.registry import Registry
4 |
5 | DATASET_REGISTRY = Registry("DATASET")
6 | DATASET_REGISTRY.__doc__ = """
7 | Registry for dataset.
8 |
9 | The registered object will be called with `obj(cfg, split)`.
10 | The call should return a `torch.utils.data.Dataset` object.
11 | """
12 |
13 |
14 | def build_dataset(dataset_name, cfg, split):
15 | """
16 | Build a dataset, defined by `dataset_name`.
17 | Args:
18 | dataset_name (str): the name of the dataset to be constructed.
19 | cfg (CfgNode): configs. Details can be found in
20 | slowfast/config/defaults.py
21 | split (str): the split of the data loader. Options include `train`,
22 | `val`, and `test`.
23 | Returns:
24 | Dataset: a constructed dataset specified by dataset_name.
25 | """
26 |     # Capitalize the first letter of the dataset_name since the dataset_name
27 | # in configs may be in lowercase but the name of dataset class should always
28 | # start with an uppercase letter.
29 | name = dataset_name.capitalize()
30 | return DATASET_REGISTRY.get(name)(cfg, split)
31 |
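32 | # Illustrative usage (sketch): assuming the Kinetics dataset class in
33 | # datasets/kinetics.py is registered under the name "Kinetics" and `cfg` is a
34 | # loaded config node,
35 | #   dataset = build_dataset("kinetics", cfg, "train")
36 | # resolves to DATASET_REGISTRY.get("Kinetics")(cfg, "train").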
--------------------------------------------------------------------------------
/models/configs/Kinetics/TimeSformer_divST_8x32_224_TEST.yaml:
--------------------------------------------------------------------------------
1 | TRAIN:
2 | ENABLE: False
3 | DATASET: kinetics
4 | BATCH_SIZE: 8
5 | EVAL_PERIOD: 5
6 | CHECKPOINT_PERIOD: 5
7 | AUTO_RESUME: True
8 | DATA:
9 | PATH_TO_DATA_DIR: /path/to/kinetics/
10 | NUM_FRAMES: 8
11 | SAMPLING_RATE: 32
12 | TRAIN_JITTER_SCALES: [256, 320]
13 | TRAIN_CROP_SIZE: 224
14 | TEST_CROP_SIZE: 224
15 | INPUT_CHANNEL_NUM: [3]
16 | TIMESFORMER:
17 | ATTENTION_TYPE: 'divided_space_time'
18 | SOLVER:
19 | BASE_LR: 0.005
20 | LR_POLICY: steps_with_relative_lrs
21 | STEPS: [0, 11, 14]
22 | LRS: [1, 0.1, 0.01]
23 | MAX_EPOCH: 15
24 | MOMENTUM: 0.9
25 | WEIGHT_DECAY: 1e-4
26 | OPTIMIZING_METHOD: sgd
27 | MODEL:
28 | MODEL_NAME: vit_base_patch16_224
29 | NUM_CLASSES: 400
30 | ARCH: vit
31 | LOSS_FUNC: cross_entropy
32 | DROPOUT_RATE: 0.5
33 | TEST:
34 | ENABLE: True
35 | DATASET: kinetics
36 | BATCH_SIZE: 8
37 | NUM_ENSEMBLE_VIEWS: 1
38 | NUM_SPATIAL_CROPS: 3
39 | CHECKPOINT_FILE_PATH: '/checkpoint/gedas/jobs/timesformer/kinetics_400/TimeSformer_divST_8x32_224/checkpoints/checkpoint_epoch_00025.pyth'
40 | DATA_LOADER:
41 | NUM_WORKERS: 8
42 | PIN_MEMORY: True
43 | NUM_GPUS: 8
44 | NUM_SHARDS: 1
45 | RNG_SEED: 0
46 | OUTPUT_DIR: .
47 |
--------------------------------------------------------------------------------
/datasets/preprocessing/verify_file_list.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from tqdm import tqdm
3 |
4 | from datasets.decoder import decode
5 | from datasets.video_container import get_video_container
6 |
7 | root_path = "/home/kanchanaranasinghe/data/kinetics400"
8 |
9 | for split in ["train", "val", "test"]:
10 | anno_folder = "new_annotations"
11 |
12 | if split == "train":
13 | save_path = f"{root_path}/{anno_folder}/{split}_60k.csv"
14 | file_list = pd.read_csv(save_path, sep=" ")
15 | elif split in ["val", "test"]:
16 | save_path = f"{root_path}/{anno_folder}/{split}_10k.csv"
17 | file_list = pd.read_csv(save_path, sep=" ")
18 | else:
19 | raise NotImplementedError("invalid split")
20 |
21 | good_count = 0
22 | frames = "0"
23 | bad_list = []
24 | print(f"Processing files from: {save_path}")
25 | for file in tqdm(file_list.values[:, 0]):
26 | try:
27 | container = get_video_container(file, True)
28 | frames = decode(container, 32, 8)
29 | assert frames is not None, "frames is None"
30 | except Exception as e:
31 | print(e, file)
32 | bad_list.append(file)
33 | else:
34 | good_count += 1
35 |
36 | print(f"{len(file_list) - good_count} files bad. {good_count} / {len(file_list)} files are good.")
37 |
--------------------------------------------------------------------------------
/datasets/DATASET.md:
--------------------------------------------------------------------------------
1 | # Dataset Preparation
2 |
3 | ## Kinetics
4 |
5 | The Kinetics dataset can be downloaded from the following [link](https://github.com/cvdfoundation/kinetics-dataset).
6 | 
7 | After all the videos have been downloaded, resize them so that the short edge is 256 pixels, then prepare the csv files for the training, validation, and testing sets as `train.csv`, `val.csv`, and `test.csv`. The format of each csv file is:
8 |
9 | ```
10 | path_to_video_1 label_1
11 | path_to_video_2 label_2
12 | path_to_video_3 label_3
13 | ...
14 | path_to_video_N label_N
15 | ```
16 |
17 | ## Something-Something V2
18 | 1. Please download the dataset and annotations from [dataset provider](https://20bn.com/datasets/something-something).
19 |
20 | 2. Download the *frame list* from the following links: ([train](https://dl.fbaipublicfiles.com/pyslowfast/dataset/ssv2/frame_lists/train.csv), [val](https://dl.fbaipublicfiles.com/pyslowfast/dataset/ssv2/frame_lists/val.csv)).
21 |
22 | 3. Extract the frames at 30 FPS. (We used ffmpeg-4.1.3 with command
23 | `ffmpeg -i "${video}" -r 30 -q:v 1 "${out_name}"`
24 | in experiments.) Please put the frames in a structure consistent with the frame lists.
25 |
26 | Please put all annotation json files and the frame lists in the same folder, and set `DATA.PATH_TO_DATA_DIR` to the path. Set `DATA.PATH_PREFIX` to be the path to the folder containing extracted frames.
27 |
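28 | For example (the paths below are illustrative), these two options can be passed through the provided scripts via `--opts`:
29 | 
30 | ```
31 | --opts DATA.PATH_TO_DATA_DIR /path/to/ssv2/annotations DATA.PATH_PREFIX /path/to/ssv2/frames
32 | ```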
--------------------------------------------------------------------------------
/datasets/preprocessing/create_lists.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | from tqdm import tqdm
4 | import json
5 |
6 | root_path = "/home/kanchanaranasinghe/data/kinetics400"
7 | split = "val"
8 | anno_folder = "new_annotations"
9 |
10 | video_list = set(os.listdir(f"{root_path}/{split}_d256"))
11 | save_path = f"{root_path}/{anno_folder}/{split}.csv"
12 | os.makedirs(f"{root_path}/{anno_folder}", exist_ok=True)
13 |
14 | labels = pd.read_csv(f"{root_path}/annotations/{split}.csv")
15 | label_dict = {y: x for x, y in enumerate(sorted(labels.label.unique().tolist()))}
16 | json.dump(label_dict, open(f"{root_path}/new_annotations/{split}_label_dict.json", "w"))
17 |
18 | with open(f"{root_path}/bad_files_{split}.txt", "r") as fo:
19 | bad_videos = fo.readlines()
20 | bad_videos = [x.strip() for x in bad_videos]
21 |
22 | video_label = []
23 | for idx, row in tqdm(labels.iterrows()):
24 | video_name = f"{row.youtube_id}_{row.time_start:06d}_{row.time_end:06d}.mp4"
25 | if video_name in bad_videos:
26 | continue
27 | if video_name not in video_list:
28 | continue
29 | assert os.path.exists(f"{root_path}/{split}_d256/{video_name}"), "video not found"
30 | video_label.append((video_name, label_dict[row.label]))
31 |
32 | os.makedirs(os.path.dirname(save_path), exist_ok=True)
33 | pd.DataFrame(video_label).to_csv(save_path, index=False, header=False, sep=" ")
34 |
--------------------------------------------------------------------------------
/models/configs/Kinetics/SLOWFAST_4x16_R50.yaml:
--------------------------------------------------------------------------------
1 | TRAIN:
2 | ENABLE: True
3 | DATASET: kinetics
4 | BATCH_SIZE: 64
5 | EVAL_PERIOD: 10
6 | CHECKPOINT_PERIOD: 5
7 | AUTO_RESUME: True
8 | DATA:
9 | PATH_TO_DATA_DIR: /path/to/kinetics/
10 | NUM_FRAMES: 32
11 | SAMPLING_RATE: 2
12 | TRAIN_JITTER_SCALES: [256, 320]
13 | TRAIN_CROP_SIZE: 224
14 | TEST_CROP_SIZE: 256
15 | INPUT_CHANNEL_NUM: [3, 3]
16 | SLOWFAST:
17 | ALPHA: 8
18 | BETA_INV: 8
19 | FUSION_CONV_CHANNEL_RATIO: 2
20 | FUSION_KERNEL_SZ: 5
21 | RESNET:
22 | ZERO_INIT_FINAL_BN: True
23 | WIDTH_PER_GROUP: 64
24 | NUM_GROUPS: 1
25 | DEPTH: 50
26 | TRANS_FUNC: bottleneck_transform
27 | STRIDE_1X1: False
28 | NUM_BLOCK_TEMP_KERNEL: [[3, 3], [4, 4], [6, 6], [3, 3]]
29 | SPATIAL_STRIDES: [[1, 1], [2, 2], [2, 2], [2, 2]]
30 | SPATIAL_DILATIONS: [[1, 1], [1, 1], [1, 1], [1, 1]]
31 | NONLOCAL:
32 | LOCATION: [[[], []], [[], []], [[], []], [[], []]]
33 | GROUP: [[1, 1], [1, 1], [1, 1], [1, 1]]
34 | INSTANTIATION: dot_product
35 | BN:
36 | USE_PRECISE_STATS: True
37 | NUM_BATCHES_PRECISE: 200
38 | SOLVER:
39 | BASE_LR: 0.8
40 | LR_POLICY: cosine
41 | MAX_EPOCH: 196
42 | MOMENTUM: 0.9
43 | WEIGHT_DECAY: 1e-4
44 | WARMUP_EPOCHS: 34.0
45 | WARMUP_START_LR: 0.01
46 | OPTIMIZING_METHOD: sgd
47 | MODEL:
48 | NUM_CLASSES: 400
49 | ARCH: slowfast
50 | MODEL_NAME: SlowFast
51 | LOSS_FUNC: cross_entropy
52 | DROPOUT_RATE: 0.5
53 | TEST:
54 | ENABLE: True
55 | DATASET: kinetics
56 | BATCH_SIZE: 64
57 | DATA_LOADER:
58 | NUM_WORKERS: 8
59 | PIN_MEMORY: True
60 | NUM_GPUS: 8
61 | NUM_SHARDS: 1
62 | RNG_SEED: 0
63 | OUTPUT_DIR: .
64 |
--------------------------------------------------------------------------------
/models/configs/Kinetics/SLOWFAST_8x8_R50.yaml:
--------------------------------------------------------------------------------
1 | TRAIN:
2 | ENABLE: True
3 | DATASET: kinetics
4 | BATCH_SIZE: 64
5 | EVAL_PERIOD: 10
6 | CHECKPOINT_PERIOD: 5
7 | AUTO_RESUME: True
8 | DATA:
9 | PATH_TO_DATA_DIR: /path/to/kinetics/
10 | NUM_FRAMES: 32
11 | SAMPLING_RATE: 2
12 | TRAIN_JITTER_SCALES: [256, 320]
13 | TRAIN_CROP_SIZE: 224
14 | TEST_CROP_SIZE: 256
15 | INPUT_CHANNEL_NUM: [3, 3]
16 | SLOWFAST:
17 | ALPHA: 4
18 | BETA_INV: 8
19 | FUSION_CONV_CHANNEL_RATIO: 2
20 | FUSION_KERNEL_SZ: 7
21 | RESNET:
22 | ZERO_INIT_FINAL_BN: True
23 | WIDTH_PER_GROUP: 64
24 | NUM_GROUPS: 1
25 | DEPTH: 50
26 | TRANS_FUNC: bottleneck_transform
27 | STRIDE_1X1: False
28 | NUM_BLOCK_TEMP_KERNEL: [[3, 3], [4, 4], [6, 6], [3, 3]]
29 | SPATIAL_STRIDES: [[1, 1], [2, 2], [2, 2], [2, 2]]
30 | SPATIAL_DILATIONS: [[1, 1], [1, 1], [1, 1], [1, 1]]
31 | NONLOCAL:
32 | LOCATION: [[[], []], [[], []], [[], []], [[], []]]
33 | GROUP: [[1, 1], [1, 1], [1, 1], [1, 1]]
34 | INSTANTIATION: dot_product
35 | BN:
36 | USE_PRECISE_STATS: True
37 | NUM_BATCHES_PRECISE: 200
38 | SOLVER:
39 | BASE_LR: 0.8
40 | LR_POLICY: cosine
41 | MAX_EPOCH: 196
42 | MOMENTUM: 0.9
43 | WEIGHT_DECAY: 1e-4
44 | WARMUP_EPOCHS: 34.0
45 | WARMUP_START_LR: 0.01
46 | OPTIMIZING_METHOD: sgd
47 | MODEL:
48 | NUM_CLASSES: 400
49 | ARCH: slowfast
50 | MODEL_NAME: SlowFast
51 | LOSS_FUNC: cross_entropy
52 | DROPOUT_RATE: 0.5
53 | TEST:
54 | ENABLE: True
55 | DATASET: kinetics
56 | BATCH_SIZE: 64
57 | DATA_LOADER:
58 | NUM_WORKERS: 8
59 | PIN_MEMORY: True
60 | NUM_GPUS: 8
61 | NUM_SHARDS: 1
62 | RNG_SEED: 0
63 | OUTPUT_DIR: .
64 |
--------------------------------------------------------------------------------
/models/configs/Kinetics/SLOWFAST_8x8_R101.yaml:
--------------------------------------------------------------------------------
1 | TRAIN:
2 | ENABLE: True
3 | DATASET: kinetics
4 | BATCH_SIZE: 64
5 | EVAL_PERIOD: 10
6 | CHECKPOINT_PERIOD: 5
7 | AUTO_RESUME: True
8 | DATA:
9 | PATH_TO_DATA_DIR: /path/to/kinetics/
10 | NUM_FRAMES: 32
11 | SAMPLING_RATE: 2
12 | TRAIN_JITTER_SCALES: [256, 340]
13 | TRAIN_CROP_SIZE: 224
14 | TEST_CROP_SIZE: 256
15 | INPUT_CHANNEL_NUM: [3, 3]
16 | SLOWFAST:
17 | ALPHA: 4
18 | BETA_INV: 8
19 | FUSION_CONV_CHANNEL_RATIO: 2
20 | FUSION_KERNEL_SZ: 5
21 | RESNET:
22 | ZERO_INIT_FINAL_BN: True
23 | WIDTH_PER_GROUP: 64
24 | NUM_GROUPS: 1
25 | DEPTH: 101
26 | TRANS_FUNC: bottleneck_transform
27 | STRIDE_1X1: False
28 | NUM_BLOCK_TEMP_KERNEL: [[3, 3], [4, 4], [6, 6], [3, 3]]
29 | SPATIAL_STRIDES: [[1, 1], [2, 2], [2, 2], [2, 2]]
30 | SPATIAL_DILATIONS: [[1, 1], [1, 1], [1, 1], [1, 1]]
31 | NONLOCAL:
32 | LOCATION: [[[], []], [[], []], [[], []], [[], []]]
33 | GROUP: [[1, 1], [1, 1], [1, 1], [1, 1]]
34 | INSTANTIATION: dot_product
35 | BN:
36 | USE_PRECISE_STATS: True
37 | NUM_BATCHES_PRECISE: 200
38 | SOLVER:
39 | BASE_LR: 0.8 ## 8 nodes
40 | LR_POLICY: cosine
41 | MAX_EPOCH: 196
42 | MOMENTUM: 0.9
43 | WEIGHT_DECAY: 1e-4
44 | WARMUP_EPOCHS: 34.0
45 | WARMUP_START_LR: 0.01
46 | OPTIMIZING_METHOD: sgd
47 | MODEL:
48 | NUM_CLASSES: 400
49 | ARCH: slowfast
50 | MODEL_NAME: SlowFast
51 | LOSS_FUNC: cross_entropy
52 | DROPOUT_RATE: 0.5
53 | TEST:
54 | ENABLE: True
55 | DATASET: kinetics
56 | BATCH_SIZE: 64
57 | DATA_LOADER:
58 | NUM_WORKERS: 8
59 | PIN_MEMORY: True
60 | NUM_GPUS: 8
61 | NUM_SHARDS: 1
62 | RNG_SEED: 0
63 | OUTPUT_DIR: .
64 |
--------------------------------------------------------------------------------
/datasets/preprocessing/check_corrupt_videos.py:
--------------------------------------------------------------------------------
1 | import glob
2 | from datasets.video_container import get_video_container
3 | from datasets.decoder import decode
4 | from tqdm import tqdm
5 | import json
6 |
7 | root_path = "/home/kanchanaranasinghe/data/kinetics400"
8 | root_path = "/home/kanchanaranasinghe/repo/mmaction2/data/hmdb51/videos"
9 | split = "hmdb"
10 |
11 | file_list = glob.glob(f"{root_path}/{split}_256/*.mp4")
12 | file_list = glob.glob(f"{root_path}/*/*.avi")
13 | good_count = 0
14 | frames = "0"
15 | bad_list = []
16 | for file in tqdm(file_list):
17 | try:
18 | container = get_video_container(file, True)
19 | frames = decode(container, 32, 8)
20 | assert frames is not None, "frames is None"
21 | except Exception as e:
22 | print(e, file)
23 | bad_list.append(file)
24 | else:
25 | good_count += 1
26 |
27 | print(f"{len(file_list) - good_count} files bad. {good_count} / {len(file_list)} files are good.")
28 | json.dump({"bad": bad_list}, open(f"{root_path}/{split}_256_bad.json", "w"))
29 |
30 | # import os
31 | # import json
32 | # import shutil
33 | # import pandas as pd
34 | #
35 | # root_path = "/home/kanchanaranasinghe/repo/mmaction2/data/hmdb51/splits"
36 | # txt_file_name = "/home/kanchanaranasinghe/repo/mmaction2/data/hmdb51/splits/hmdb51_val_split_1_videos.txt"
37 | # files = json.load(open(f"{root_path}/{split}_256_bad.json", "r"))
38 | # bad_names = [x[59:] for x in files['bad']]
39 | # df = pd.read_csv(f"{txt_file_name}", sep=" ",
40 | # header=None)
41 | # df = df[df[0].isin(bad_names) == False]
42 | # df.to_csv(f"{txt_file_name}", sep=" ",
43 | # header=None, index=None)
44 |
45 | # for file in files["bad"]:
46 | # shutil.move(file, file.replace(f"/{split}_256/", f"/{split}_256_bad/"))
47 |
--------------------------------------------------------------------------------
/datasets/preprocessing/downsample_kinetics.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | from tqdm import tqdm
4 | from joblib import Parallel
5 | from joblib import delayed
6 |
7 |
8 | def downscale_clip(inname, outname):
9 | inname = '"%s"' % inname
10 | outname = '"%s"' % outname
11 | # command = "ffmpeg -loglevel panic -i {} -filter:v scale=\"trunc(oh*a/2)*2:256\" -q:v 1 -c:a copy {}".format(
12 | # inname, outname)
13 | command = f"ffmpeg -i {inname} -filter:v scale=\"trunc(oh*a/2)*2:256\" -q:v 1 -c:a copy {outname}"
14 | try:
15 | output = subprocess.check_output(command, shell=True, stderr=subprocess.STDOUT)
16 | except subprocess.CalledProcessError as err:
17 | print(err)
18 | return err.output
19 |
20 | return output
21 |
22 |
23 | def downscale_clip_wrapper(file_name):
24 | in_name = f"{folder_path}/{file_name}"
25 | out_name = f"{output_path}/{file_name}"
26 |
27 | log = downscale_clip(in_name, out_name)
28 | return file_name, log
29 |
30 |
31 | if __name__ == '__main__':
32 | root_path = "/home/salman/data/kinetics/400"
33 | split = "val"
34 |
35 | folder_path = f'{root_path}/{split}'
36 | output_path = f'{root_path}/{split}_256'
37 | os.makedirs(output_path, exist_ok=True)
38 |
39 | file_list = os.listdir(folder_path)
40 | completed_file_list = set(os.listdir(output_path))
41 | file_list = [x for x in file_list if x not in completed_file_list]
42 |
43 | # file_list = file_list[:100]
44 | print(f"Starting to downsample {len(file_list)} video files.")
45 |
46 | # split = len(file_list) // 10
47 | # list_of_lists = [file_list[x * split:(x + 1) * split] for x in range(10)]
48 | # list_of_lists[-1].extend(file_list[10 * split:])
49 |
50 | for file in tqdm(file_list):
51 | _, log = downscale_clip_wrapper(file)
52 |
53 | # status_lst = Parallel(n_jobs=16)(delayed(downscale_clip_wrapper)(row) for row in file_list)
54 | # status_lst = Parallel(n_jobs=16)(downscale_clip_wrapper(row) for row in file_list)
55 | # with open(f"{root_path}/downsample_{split}_logs.txt", "w") as fo:
56 | # fo.writelines([f"{x[0], x[1]}\n" for x in status_lst])
57 |
--------------------------------------------------------------------------------
/datasets/rand_conv.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from PIL import Image
5 | from einops import rearrange, repeat
6 |
7 |
8 | class RandConv(nn.Module):
9 |
10 | def __init__(self, kernel_size=3, alpha=0.7, temporal_input=False):
11 | super(RandConv, self).__init__()
12 | self.m = nn.Conv2d(3, 3, kernel_size, stride=1, padding=kernel_size // 2, bias=False)
13 | self.std_normal = 1 / (np.sqrt(3) * kernel_size)
14 | self.alpha = alpha
15 | self.temporal_input = temporal_input
16 |
17 | def forward(self, image):
18 | with torch.no_grad():
19 | self.m.weight = torch.nn.Parameter(torch.normal(mean=torch.zeros_like(self.m.weight),
20 | std=torch.ones_like(self.m.weight) * self.std_normal))
21 | if self.temporal_input:
22 | batch_dim = image.shape[0]
23 | filtered_im = rearrange(image, "b c t h w -> (b t) c h w")
24 | filtered_im = self.m(filtered_im)
25 | filtered_im = rearrange(filtered_im, "(b t) c h w -> b c t h w", b=batch_dim)
26 | else:
27 | filtered_im = self.m(image)
28 | return self.alpha * image + (1 - self.alpha) * filtered_im
29 |
30 |
31 | if __name__ == '__main__':
32 |
33 | filter_im = RandConv(temporal_input=True)
34 | mean = torch.Tensor([0.485, 0.456, 0.406])
35 | std = torch.Tensor([0.229, 0.224, 0.225])
36 |
37 |
38 | def whiten(input):
39 | return (input - mean.reshape(1, -1, 1, 1, 1)) / std.reshape(1, -1, 1, 1, 1)
40 |
41 |
42 | def dewhiten(input):
43 | return torch.clip(input * std.reshape(1, -1, 1, 1, 1) + mean.reshape(1, -1, 1, 1, 1), 0, 1)
44 |
45 |
46 | img = Image.open("/Users/kanchana/Documents/current/video_research/repo/dino/data/rand_conv/raw.jpg")
47 | # img_arr = torch.Tensor(np.array(img).transpose(2, 0, 1)).unsqueeze(0)
48 | img_arr = repeat(torch.Tensor(np.array(img)).unsqueeze(0), "b h w c -> b c t h w", t=8)
49 | img_arr = img_arr.float() / 255.
50 |
51 | for idx in range(10):
52 | with torch.no_grad():
53 | out = dewhiten(filter_im(whiten(img_arr)))
54 | # out = alpha * img_arr + (1 - alpha) * out
55 | im_vis = out[0, :, 0].permute(1, 2, 0).detach().numpy() * 255.
56 | im_vis = Image.fromarray(im_vis.astype(np.uint8))
57 | im_vis.save(f"/Users/kanchana/Documents/current/video_research/repo/dino/data/rand_conv/{idx + 1:05d}.jpg")
58 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Project specific
2 | .idea
3 | .DS_Store
4 | checkpoints
5 | /data
6 |
7 | # Byte-compiled / optimized / DLL files
8 | __pycache__/
9 | *.py[cod]
10 | *$py.class
11 |
12 | # C extensions
13 | *.so
14 |
15 | # Distribution / packaging
16 | .Python
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | eggs/
22 | .eggs/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | wheels/
29 | pip-wheel-metadata/
30 | share/python-wheels/
31 | *.egg-info/
32 | .installed.cfg
33 | *.egg
34 | MANIFEST
35 |
36 | # PyInstaller
37 | # Usually these files are written by a python script from a template
38 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
39 | *.manifest
40 | *.spec
41 |
42 | # Installer logs
43 | pip-log.txt
44 | pip-delete-this-directory.txt
45 |
46 | # Unit test / coverage reports
47 | htmlcov/
48 | .tox/
49 | .nox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | *.py,cover
57 | .hypothesis/
58 | .pytest_cache/
59 |
60 | # Translations
61 | *.mo
62 | *.pot
63 |
64 | # Django stuff:
65 | *.log
66 | local_settings.py
67 | db.sqlite3
68 | db.sqlite3-journal
69 |
70 | # Flask stuff:
71 | instance/
72 | .webassets-cache
73 |
74 | # Scrapy stuff:
75 | .scrapy
76 |
77 | # Sphinx documentation
78 | docs/_build/
79 |
80 | # PyBuilder
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # IPython
87 | profile_default/
88 | ipython_config.py
89 |
90 | # pyenv
91 | .python-version
92 |
93 | # pipenv
94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 |
100 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
101 | __pypackages__/
102 |
103 | # Celery stuff
104 | celerybeat-schedule
105 | celerybeat.pid
106 |
107 | # SageMath parsed files
108 | *.sage.py
109 |
110 | # Environments
111 | .env
112 | .venv
113 | env/
114 | venv/
115 | ENV/
116 | env.bak/
117 | venv.bak/
118 |
119 | # Spyder project settings
120 | .spyderproject
121 | .spyproject
122 |
123 | # Rope project settings
124 | .ropeproject
125 |
126 | # mkdocs documentation
127 | /site
128 |
129 | # mypy
130 | .mypy_cache/
131 | .dmypy.json
132 | dmypy.json
133 |
134 | # Pyre type checker
135 | .pyre/
136 |
--------------------------------------------------------------------------------
/datasets/multigrid_helper.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 |
3 | """Helper functions for multigrid training."""
4 |
5 | import numpy as np
6 | _int_classes = int  # stand-in for torch._six.int_classes, which was removed in newer PyTorch releases
7 | from torch.utils.data.sampler import Sampler
8 |
9 |
10 | class ShortCycleBatchSampler(Sampler):
11 | """
12 | Extend Sampler to support "short cycle" sampling.
13 | See paper "A Multigrid Method for Efficiently Training Video Models",
14 | Wu et al., 2019 (https://arxiv.org/abs/1912.00998) for details.
15 | """
16 |
17 | def __init__(self, sampler, batch_size, drop_last, cfg):
18 | if not isinstance(sampler, Sampler):
19 | raise ValueError(
20 | "sampler should be an instance of "
21 | "torch.utils.data.Sampler, but got sampler={}".format(sampler)
22 | )
23 | if (
24 | not isinstance(batch_size, _int_classes)
25 | or isinstance(batch_size, bool)
26 | or batch_size <= 0
27 | ):
28 | raise ValueError(
29 | "batch_size should be a positive integer value, "
30 | "but got batch_size={}".format(batch_size)
31 | )
32 | if not isinstance(drop_last, bool):
33 | raise ValueError(
34 | "drop_last should be a boolean value, but got "
35 | "drop_last={}".format(drop_last)
36 | )
37 | self.sampler = sampler
38 | self.drop_last = drop_last
39 |
40 | bs_factor = [
41 | int(
42 | round(
43 | (
44 | float(cfg.DATA.TRAIN_CROP_SIZE)
45 | / (s * cfg.MULTIGRID.DEFAULT_S)
46 | )
47 | ** 2
48 | )
49 | )
50 | for s in cfg.MULTIGRID.SHORT_CYCLE_FACTORS
51 | ]
52 |
53 | self.batch_sizes = [
54 | batch_size * bs_factor[0],
55 | batch_size * bs_factor[1],
56 | batch_size,
57 | ]
58 |
59 | def __iter__(self):
60 | counter = 0
61 | batch_size = self.batch_sizes[0]
62 | batch = []
63 | for idx in self.sampler:
64 | batch.append((idx, counter % 3))
65 | if len(batch) == batch_size:
66 | yield batch
67 | counter += 1
68 | batch_size = self.batch_sizes[counter % 3]
69 | batch = []
70 | if len(batch) > 0 and not self.drop_last:
71 | yield batch
72 |
73 | def __len__(self):
74 | avg_batch_size = sum(self.batch_sizes) / 3.0
75 | if self.drop_last:
76 | return int(np.floor(len(self.sampler) / avg_batch_size))
77 | else:
78 | return int(np.ceil(len(self.sampler) / avg_batch_size))
79 |
--------------------------------------------------------------------------------
/utils/parser.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 |
3 | """Argument parser functions."""
4 |
5 | import argparse
6 | import sys
7 |
8 | from utils.defaults import get_cfg
9 |
10 |
11 | def parse_args():
12 | """
13 | Parse the following arguments for a default parser for PySlowFast users.
14 | Args:
15 | shard_id (int): shard id for the current machine. Starts from 0 to
16 | num_shards - 1. If single machine is used, then set shard id to 0.
17 |         num_shards (int): number of shards used by the job.
18 | init_method (str): initialization method to launch the job with multiple
19 |             devices. Options include TCP or shared file-system for
20 |             initialization. Details can be found at
21 | https://pytorch.org/docs/stable/distributed.html#tcp-initialization
22 | cfg (str): path to the config file.
23 |         opts (argument): provide additional options from the command line;
24 |             these overwrite the config loaded from file.
25 | """
26 | parser = argparse.ArgumentParser(
27 | description="Provide SlowFast video training and testing pipeline."
28 | )
29 | parser.add_argument(
30 | "--shard_id",
31 |         help="The shard id of the current node; starts from 0 to num_shards - 1",
32 | default=0,
33 | type=int,
34 | )
35 | parser.add_argument(
36 | "--num_shards",
37 |         help="Number of shards used by the job",
38 | default=1,
39 | type=int,
40 | )
41 | parser.add_argument(
42 | "--init_method",
43 | help="Initialization method, includes TCP or shared file-system",
44 | default="tcp://localhost:9999",
45 | type=str,
46 | )
47 | parser.add_argument(
48 | "--cfg",
49 | dest="cfg_file",
50 | help="Path to the config file",
51 | default="configs/Kinetics/SLOWFAST_4x16_R50.yaml",
52 | type=str,
53 | )
54 | parser.add_argument(
55 | "opts",
56 | help="See slowfast/config/defaults.py for all options",
57 | default=None,
58 | nargs=argparse.REMAINDER,
59 | )
60 | if len(sys.argv) == 1:
61 | parser.print_help()
62 | return parser.parse_args()
63 |
64 |
65 | def load_config(args):
66 | """
67 |     Given the arguments, load and initialize the configs.
68 | Args:
69 |         args (argument): arguments include `shard_id`, `num_shards`,
70 | `init_method`, `cfg_file`, and `opts`.
71 | """
72 | # Setup cfg.
73 | cfg = get_cfg()
74 | # Load config from cfg.
75 | if args.cfg_file is not None:
76 | cfg.merge_from_file(args.cfg_file)
77 | # Load config from command line, overwrite config from opts.
78 | if args.opts is not None:
79 | cfg.merge_from_list(args.opts)
80 |
81 | # Inherit parameters from args.
82 | if hasattr(args, "num_shards") and hasattr(args, "shard_id"):
83 | cfg.NUM_SHARDS = args.num_shards
84 | cfg.SHARD_ID = args.shard_id
85 | if hasattr(args, "rng_seed"):
86 | cfg.RNG_SEED = args.rng_seed
87 | if hasattr(args, "output_dir"):
88 | cfg.OUTPUT_DIR = args.output_dir
89 |
90 | return cfg
91 |
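92 | # Illustrative usage (sketch):
93 | #   args = parse_args()
94 | #   cfg = load_config(args)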
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 |
3 | """Functions for computing metrics."""
4 |
5 | import torch
6 | import numpy as np
7 |
8 | def topks_correct(preds, labels, ks):
9 | """
10 | Given the predictions, labels, and a list of top-k values, compute the
11 | number of correct predictions for each top-k value.
12 |
13 | Args:
14 | preds (array): array of predictions. Dimension is batchsize
15 | N x ClassNum.
16 | labels (array): array of labels. Dimension is batchsize N.
17 |         ks (list): list of top-k values. For example, ks = [1, 5] corresponds
18 | to top-1 and top-5.
19 |
20 | Returns:
21 | topks_correct (list): list of numbers, where the `i`-th entry
22 | corresponds to the number of top-`ks[i]` correct predictions.
23 | """
24 | assert preds.size(0) == labels.size(
25 | 0
26 | ), "Batch dim of predictions and labels must match"
27 | # Find the top max_k predictions for each sample
28 | _top_max_k_vals, top_max_k_inds = torch.topk(
29 | preds, max(ks), dim=1, largest=True, sorted=True
30 | )
31 | # (batch_size, max_k) -> (max_k, batch_size).
32 | top_max_k_inds = top_max_k_inds.t()
33 | # (batch_size, ) -> (max_k, batch_size).
34 | rep_max_k_labels = labels.view(1, -1).expand_as(top_max_k_inds)
35 | # (i, j) = 1 if top i-th prediction for the j-th sample is correct.
36 | top_max_k_correct = top_max_k_inds.eq(rep_max_k_labels)
37 | # Compute the number of topk correct predictions for each k.
38 | topks_correct = [top_max_k_correct[:k, :].float().sum() for k in ks]
39 | return topks_correct
40 |
41 |
42 | def topk_errors(preds, labels, ks):
43 | """
44 | Computes the top-k error for each k.
45 | Args:
46 | preds (array): array of predictions. Dimension is N.
47 | labels (array): array of labels. Dimension is N.
48 | ks (list): list of ks to calculate the top accuracies.
49 | """
50 | num_topks_correct = topks_correct(preds, labels, ks)
51 | return [(1.0 - x / preds.size(0)) * 100.0 for x in num_topks_correct]
52 |
53 |
54 | def topk_accuracies(preds, labels, ks):
55 | """
56 | Computes the top-k accuracy for each k.
57 | Args:
58 | preds (array): array of predictions. Dimension is N.
59 | labels (array): array of labels. Dimension is N.
60 | ks (list): list of ks to calculate the top accuracies.
61 | """
62 | num_topks_correct = topks_correct(preds, labels, ks)
63 | return [(x / preds.size(0)) * 100.0 for x in num_topks_correct]
64 |
65 | def multitask_topks_correct(preds, labels, ks=(1,)):
66 | """
67 | Args:
68 | preds: tuple(torch.FloatTensor), each tensor should be of shape
69 | [batch_size, class_count], class_count can vary on a per task basis, i.e.
70 |             outputs[i].shape[1] can be different from outputs[j].shape[1].
71 | labels: tuple(torch.LongTensor), each tensor should be of shape [batch_size]
72 | ks: tuple(int), compute accuracy at top-k for the values of k specified
73 | in this parameter.
74 | Returns:
75 |         tuple(float), same length as `ks`, with the corresponding accuracy@k for each k.
76 | """
77 | max_k = int(np.max(ks))
78 | task_count = len(preds)
79 | batch_size = labels[0].size(0)
80 | all_correct = torch.zeros(max_k, batch_size).type(torch.ByteTensor)
81 | if torch.cuda.is_available():
82 | all_correct = all_correct.cuda()
83 | for output, label in zip(preds, labels):
84 | _, max_k_idx = output.topk(max_k, dim=1, largest=True, sorted=True)
85 | # Flip batch_size, class_count as .view doesn't work on non-contiguous
86 | max_k_idx = max_k_idx.t()
87 | correct_for_task = max_k_idx.eq(label.view(1, -1).expand_as(max_k_idx))
88 | all_correct.add_(correct_for_task)
89 |
90 | multitask_topks_correct = [
91 | torch.ge(all_correct[:k].float().sum(0), task_count).float().sum(0) for k in ks
92 | ]
93 |
94 | return multitask_topks_correct
95 |
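96 | # Illustrative usage (shapes are assumptions): for `preds` of shape [N, num_classes]
97 | # and integer `labels` of shape [N],
98 | #   top1_acc, top5_acc = topk_accuracies(preds, labels, (1, 5))
99 | # returns the top-1 and top-5 accuracies as percentages.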
--------------------------------------------------------------------------------
/datasets/preprocessing/resize_videos.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import os
4 | import os.path as osp
5 | import sys
6 | from multiprocessing import Pool
7 |
8 |
9 | def resize_videos(vid_item):
10 | """Generate resized video cache.
11 |
12 | Args:
13 | vid_item (list): Video item containing video full path,
14 | video relative path.
15 |
16 | Returns:
17 | bool: Whether generate video cache successfully.
18 | """
19 | full_path, vid_path = vid_item
20 | out_full_path = osp.join(args.out_dir, vid_path)
21 | dir_name = osp.dirname(vid_path)
22 | out_dir = osp.join(args.out_dir, dir_name)
23 | if not osp.exists(out_dir):
24 | os.makedirs(out_dir)
25 | result = os.popen(
26 | f'ffprobe -hide_banner -loglevel error -select_streams v:0 -show_entries stream=width,height -of csv=p=0 {full_path}' # noqa:E501
27 | )
28 | try:
29 | w, h = [int(d) for d in result.readline().rstrip().split(',')]
30 | if w > h:
31 | cmd = (f'ffmpeg -hide_banner -loglevel error -i {full_path} '
32 | f'-vf {"mpdecimate," if args.remove_dup else ""}'
33 | f'scale=-2:{args.scale} '
34 | f'{"-vsync vfr" if args.remove_dup else ""} '
35 | f'-c:v libx264 {"-g 16" if args.dense else ""} '
36 | f'-an {out_full_path} -y')
37 | else:
38 | cmd = (f'ffmpeg -hide_banner -loglevel error -i {full_path} '
39 | f'-vf {"mpdecimate," if args.remove_dup else ""}'
40 | f'scale={args.scale}:-2 '
41 | f'{"-vsync vfr" if args.remove_dup else ""} '
42 | f'-c:v libx264 {"-g 16" if args.dense else ""} '
43 | f'-an {out_full_path} -y')
44 | os.popen(cmd)
45 | print(f'{vid_path} done')
46 | sys.stdout.flush()
47 | except Exception as e:
48 | print(e)
49 | return False
50 |
51 | return True
52 |
53 |
54 | def parse_args():
55 | parser = argparse.ArgumentParser(
56 | description='Generate the resized cache of original videos')
57 | parser.add_argument('src_dir', type=str, help='source video directory')
58 | parser.add_argument('out_dir', type=str, help='output video directory')
59 | parser.add_argument('--save', type=str, help='path to save output')
60 | parser.add_argument(
61 | '--dense',
62 | action='store_true',
63 | help='whether to generate a faster cache')
64 | parser.add_argument(
65 | '--level',
66 | type=int,
67 | choices=[1, 2],
68 | default=2,
69 | help='directory level of data')
70 | parser.add_argument(
71 | '--remove-dup',
72 | action='store_true',
73 | help='whether to remove duplicated frames')
74 | parser.add_argument(
75 | '--ext',
76 | type=str,
77 | default='mp4',
78 | choices=['avi', 'mp4', 'webm', 'mkv'],
79 | help='video file extensions')
80 | parser.add_argument(
81 | '--scale',
82 | type=int,
83 | default=256,
84 | help='resize image short side length keeping ratio')
85 | parser.add_argument(
86 | '--num-worker', type=int, default=8, help='number of workers')
87 | args = parser.parse_args()
88 |
89 | return args
90 |
91 |
92 | if __name__ == '__main__':
93 | args = parse_args()
94 |
95 | if not osp.isdir(args.out_dir):
96 | print(f'Creating folder: {args.out_dir}')
97 | os.makedirs(args.out_dir)
98 |
99 | print('Reading videos from folder: ', args.src_dir)
100 | print('Extension of videos: ', args.ext)
101 | fullpath_list = glob.glob(args.src_dir + '/*' * args.level + '.' +
102 | args.ext)
103 |     done_fullpath_list = glob.glob(args.out_dir + '/*' * args.level + '.' + args.ext)
104 | print('Total number of videos found: ', len(fullpath_list))
105 | print('Total number of videos transfer finished: ',
106 | len(done_fullpath_list))
107 | if args.level == 2:
108 | vid_list = list(
109 | map(
110 | lambda p: osp.join(
111 | osp.basename(osp.dirname(p)), osp.basename(p)),
112 | fullpath_list))
113 | elif args.level == 1:
114 | vid_list = list(map(osp.basename, fullpath_list))
115 | pool = Pool(args.num_worker)
116 | out = pool.map(resize_videos, zip(fullpath_list, vid_list))
117 | bad_videos = [osp.basename(x) for x, y in zip(fullpath_list, out) if not y]
118 | with open(f"{args.save}", "w") as fo:
119 | fo.writelines([f"{x}\n" for x in bad_videos])
120 |
--------------------------------------------------------------------------------
/datasets/preprocessing/flow_vis.py:
--------------------------------------------------------------------------------
1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
2 |
3 |
4 | # MIT License
5 | #
6 | # Copyright (c) 2018 Tom Runia
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to conditions.
14 | #
15 | # Author: Tom Runia
16 | # Date Created: 2018-08-03
17 |
18 | import numpy as np
19 |
20 |
21 | def make_colorwheel():
22 | """
23 | Generates a color wheel for optical flow visualization as presented in:
24 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
25 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
26 |
27 | Code follows the original C++ source code of Daniel Scharstein.
28 |     Code follows the Matlab source code of Deqing Sun.
29 |
30 | Returns:
31 | np.ndarray: Color wheel
32 | """
33 |
34 | RY = 15
35 | YG = 6
36 | GC = 4
37 | CB = 11
38 | BM = 13
39 | MR = 6
40 |
41 | ncols = RY + YG + GC + CB + BM + MR
42 | colorwheel = np.zeros((ncols, 3))
43 | col = 0
44 |
45 | # RY
46 | colorwheel[0:RY, 0] = 255
47 | colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
48 | col = col + RY
49 | # YG
50 | colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
51 | colorwheel[col:col + YG, 1] = 255
52 | col = col + YG
53 | # GC
54 | colorwheel[col:col + GC, 1] = 255
55 | colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
56 | col = col + GC
57 | # CB
58 | colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB)
59 | colorwheel[col:col + CB, 2] = 255
60 | col = col + CB
61 | # BM
62 | colorwheel[col:col + BM, 2] = 255
63 | colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
64 | col = col + BM
65 | # MR
66 | colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR)
67 | colorwheel[col:col + MR, 0] = 255
68 | return colorwheel
69 |
70 |
71 | def flow_uv_to_colors(u, v, convert_to_bgr=False):
72 | """
73 | Applies the flow color wheel to (possibly clipped) flow components u and v.
74 |
75 | According to the C++ source code of Daniel Scharstein
76 | According to the Matlab source code of Deqing Sun
77 |
78 | Args:
79 | u (np.ndarray): Input horizontal flow of shape [H,W]
80 | v (np.ndarray): Input vertical flow of shape [H,W]
81 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
82 |
83 | Returns:
84 | np.ndarray: Flow visualization image of shape [H,W,3]
85 | """
86 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
87 | colorwheel = make_colorwheel() # shape [55x3]
88 | ncols = colorwheel.shape[0]
89 | rad = np.sqrt(np.square(u) + np.square(v))
90 | a = np.arctan2(-v, -u) / np.pi
91 | fk = (a + 1) / 2 * (ncols - 1)
92 | k0 = np.floor(fk).astype(np.int32)
93 | k1 = k0 + 1
94 | k1[k1 == ncols] = 0
95 | f = fk - k0
96 | for i in range(colorwheel.shape[1]):
97 | tmp = colorwheel[:, i]
98 | col0 = tmp[k0] / 255.0
99 | col1 = tmp[k1] / 255.0
100 | col = (1 - f) * col0 + f * col1
101 | idx = (rad <= 1)
102 | col[idx] = 1 - rad[idx] * (1 - col[idx])
103 | col[~idx] = col[~idx] * 0.75 # out of range
104 | # Note the 2-i => BGR instead of RGB
105 | ch_idx = 2 - i if convert_to_bgr else i
106 | flow_image[:, :, ch_idx] = np.floor(255 * col)
107 | return flow_image
108 |
109 |
110 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
111 | """
112 | Expects a two-dimensional flow image of shape [H,W,2].
113 |
114 | Args:
115 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
116 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
117 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
118 |
119 | Returns:
120 | np.ndarray: Flow visualization image of shape [H,W,3]
121 | """
122 | assert flow_uv.ndim == 3, 'input flow must have three dimensions'
123 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
124 | if clip_flow is not None:
125 | flow_uv = np.clip(flow_uv, 0, clip_flow)
126 | u = flow_uv[:, :, 0]
127 | v = flow_uv[:, :, 1]
128 | rad = np.sqrt(np.square(u) + np.square(v))
129 | rad_max = np.max(rad)
130 | epsilon = 1e-5
131 | u = u / (rad_max + epsilon)
132 | v = v / (rad_max + epsilon)
133 | return flow_uv_to_colors(u, v, convert_to_bgr)
134 |
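135 | 
136 | # Usage sketch (illustrative addition, not part of the original visualization
137 | # code): turn a random flow field into an [H, W, 3] uint8 color image.
138 | if __name__ == "__main__":
139 | example_flow = np.random.randn(64, 64, 2).astype(np.float32)
140 | flow_rgb = flow_to_image(example_flow)
141 | print(flow_rgb.shape, flow_rgb.dtype)  # expected: (64, 64, 3) uint8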
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Self-Supervised Video Transformer (CVPR'22-Oral)
2 |
3 | [](https://paperswithcode.com/sota/self-supervised-action-recognition-linear-on-3?p=self-supervised-video-transformer)
4 | [](https://paperswithcode.com/sota/self-supervised-action-recognition-linear-on?p=self-supervised-video-transformer)
5 | [](https://paperswithcode.com/sota/self-supervised-action-recognition-linear-on-1?p=self-supervised-video-transformer)
6 |
7 | [](https://paperswithcode.com/sota/action-recognition-in-videos-on-ucf101?p=self-supervised-video-transformer)
8 | [](https://paperswithcode.com/sota/action-recognition-in-videos-on-hmdb-51?p=self-supervised-video-transformer)
9 | [](https://paperswithcode.com/sota/action-recognition-in-videos-on-something?p=self-supervised-video-transformer)
10 | [](https://paperswithcode.com/sota/action-classification-on-kinetics-400?p=self-supervised-video-transformer)
11 |
12 | [Kanchana Ranasinghe](https://kahnchana.github.io),
13 | [Muzammal Naseer](https://muzammal-naseer.netlify.app/),
14 | [Salman Khan](https://salman-h-khan.github.io),
15 | [Fahad Shahbaz Khan](https://sites.google.com/view/fahadkhans/home),
16 | [Michael Ryoo](http://michaelryoo.com)
17 |
18 | **[Paper Link](https://arxiv.org/abs/2112.01514)** | **[Project Page](https://kahnchana.github.io/svt)**
19 |
20 |
21 | > **Abstract:**
22 | >*In this paper, we propose self-supervised training for video transformers using unlabelled video data. From a given video, we create local and global spatiotemporal views with varying spatial sizes and frame rates. Our self-supervised objective seeks to match the features of these different views representing the same video, to be invariant to spatiotemporal variations in actions. To the best of our knowledge, the proposed approach is the first to alleviate the dependency on negative samples or dedicated memory banks in Self-supervised Video Transformer (SVT). Further, owing to the flexibility of Transformer models, SVT supports slow-fast video processing within a single architecture using dynamically adjusted positional encodings and supports long-term relationship modeling along spatiotemporal dimensions. Our approach performs well on four action recognition benchmarks (Kinetics-400, UCF-101, HMDB-51, and SSv2) and converges faster with small batch sizes.*
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 | ## Usage & Data
32 | Refer to `requirements.txt` for installing all Python dependencies. We use Python 3.7 with PyTorch 1.7.1.
33 |
34 | We download the official version of Kinetics-400 from [here](https://github.com/cvdfoundation/kinetics-dataset) and resize the videos using the code [here](https://github.com/open-mmlab/mmaction2/tree/master/tools/data/kinetics).
35 |
36 |
37 | ## Self-supervised Training
38 | For self-supervised pre-training of models on the Kinetics-400 dataset, use the training script in the `scripts` directory as follows. Change the dataset paths as required.
39 |
40 | ```
41 | ./scripts/train.sh
42 | ```
43 |
44 |
45 | ## Downstream Evaluation
46 | Scripts to perform evaluation (linear or k-NN) on selected downstream tasks are given below. Paths to datasets and pre-trained models must be set appropriately. Note that for linear evaluation, a linear layer is fine-tuned on the new dataset, and this training can be time-consuming on a single GPU (a minimal sketch of this probing setup follows the commands below).
47 |
48 | ```
49 | ./scripts/eval_linear.sh
50 | ./scripts/eval_knn.sh
51 | ```
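For reference, linear evaluation amounts to probing frozen backbone features with a single linear layer. The sketch below is a generic illustration of that setup, not the exact logic of `eval_linear.py`; `backbone`, `loader`, `feature_dim`, and `num_classes` are placeholders, and a GPU is assumed. The actual scripts additionally handle distributed feature extraction and video-specific clip sampling on top of this basic recipe.

```python
import torch
import torch.nn as nn


def linear_probe(backbone, loader, feature_dim=768, num_classes=101, epochs=10, lr=1e-3):
    """Fit a linear classifier on top of frozen backbone features (generic sketch)."""
    backbone.eval()                          # backbone stays frozen throughout
    for p in backbone.parameters():
        p.requires_grad = False
    head = nn.Linear(feature_dim, num_classes).cuda()
    optimizer = torch.optim.SGD(head.parameters(), lr=lr, momentum=0.9)
    criterion = nn.CrossEntropyLoss()
    for _ in range(epochs):
        for clips, labels in loader:         # loader yields (video clips, integer labels)
            clips, labels = clips.cuda(), labels.cuda()
            with torch.no_grad():
                feats = backbone(clips)      # assumed shape: (batch, feature_dim)
            loss = criterion(head(feats), labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return head
```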
52 |
53 |
54 | ## Pretrained Models
55 | Our pre-trained models can be found under [releases](https://github.com/kahnchana/svt/releases/tag/v1.0).
56 |
57 |
58 | ## Citation
59 | If you find our work, this repository, or our pretrained models useful, please consider giving a star :star: and a citation.
60 | ```bibtex
61 | @inproceedings{ranasinghe2022selfsupervised,
62 | title={Self-supervised Video Transformer},
63 | author={Kanchana Ranasinghe and Muzammal Naseer and Salman Khan and Fahad Shahbaz Khan and Michael Ryoo},
64 | booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition},
65 | month = {June},
66 | year={2022}
67 | }
68 | ```
69 |
70 |
71 | ## Acknowledgements
72 | Our code is based on [DINO](https://github.com/facebookresearch/dino) and [TimeSformer](https://github.com/facebookresearch/TimeSformer) repositories. We thank the authors for releasing their code. If you use our model, please consider citing these works as well.
73 |
--------------------------------------------------------------------------------
/datasets/loader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 |
3 | """Data loader."""
4 |
5 | import itertools
6 | import numpy as np
7 | import torch
8 | from torch.utils.data._utils.collate import default_collate
9 | from torch.utils.data.distributed import DistributedSampler
10 | from torch.utils.data.sampler import RandomSampler
11 |
12 | from .multigrid_helper import ShortCycleBatchSampler
13 |
14 | from . import data_utils as utils
15 | from .build import build_dataset
16 |
17 |
18 | def detection_collate(batch):
19 | """
20 | Collate function for detection task. Concatenate bboxes, labels and
21 | metadata from different samples in the first dimension instead of
22 | stacking them to have a batch-size dimension.
23 | Args:
24 | batch (tuple or list): data batch to collate.
25 | Returns:
26 | (tuple): collated detection data batch.
27 | """
28 | inputs, labels, video_idx, extra_data = zip(*batch)
29 | inputs, video_idx = default_collate(inputs), default_collate(video_idx)
30 | labels = torch.tensor(np.concatenate(labels, axis=0)).float()
31 |
32 | collated_extra_data = {}
33 | for key in extra_data[0].keys():
34 | data = [d[key] for d in extra_data]
35 | if key == "boxes" or key == "ori_boxes":
36 | # Append idx info to the bboxes before concatenating them.
37 | bboxes = [
38 | np.concatenate(
39 | [np.full((data[i].shape[0], 1), float(i)), data[i]], axis=1
40 | )
41 | for i in range(len(data))
42 | ]
43 | bboxes = np.concatenate(bboxes, axis=0)
44 | collated_extra_data[key] = torch.tensor(bboxes).float()
45 | elif key == "metadata":
46 | collated_extra_data[key] = torch.tensor(
47 | list(itertools.chain(*data))
48 | ).view(-1, 2)
49 | else:
50 | collated_extra_data[key] = default_collate(data)
51 |
52 | return inputs, labels, video_idx, collated_extra_data
53 |
54 |
55 | def construct_loader(cfg, split, is_precise_bn=False):
56 | """
57 | Constructs the data loader for the given dataset.
58 | Args:
59 | cfg (CfgNode): configs. Details can be found in
60 | slowfast/config/defaults.py
61 | split (str): the split of the data loader. Options include `train`,
62 | `val`, and `test`.
63 | """
64 | assert split in ["train", "val", "test"]
65 | if split in ["train"]:
66 | dataset_name = cfg.TRAIN.DATASET
67 | batch_size = int(cfg.TRAIN.BATCH_SIZE / max(1, cfg.NUM_GPUS))
68 | shuffle = True
69 | drop_last = True
70 | elif split in ["val"]:
71 | dataset_name = cfg.TRAIN.DATASET
72 | batch_size = int(cfg.TRAIN.BATCH_SIZE / max(1, cfg.NUM_GPUS))
73 | shuffle = False
74 | drop_last = False
75 | elif split in ["test"]:
76 | dataset_name = cfg.TEST.DATASET
77 | batch_size = int(cfg.TEST.BATCH_SIZE / max(1, cfg.NUM_GPUS))
78 | shuffle = False
79 | drop_last = False
80 |
81 | # Construct the dataset
82 | dataset = build_dataset(dataset_name, cfg, split)
83 |
84 | if cfg.MULTIGRID.SHORT_CYCLE and split in ["train"] and not is_precise_bn:
85 | # Create a sampler for multi-process training
86 | sampler = utils.create_sampler(dataset, shuffle, cfg)
87 | batch_sampler = ShortCycleBatchSampler(
88 | sampler, batch_size=batch_size, drop_last=drop_last, cfg=cfg
89 | )
90 | # Create a loader
91 | loader = torch.utils.data.DataLoader(
92 | dataset,
93 | batch_sampler=batch_sampler,
94 | num_workers=cfg.DATA_LOADER.NUM_WORKERS,
95 | pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
96 | worker_init_fn=utils.loader_worker_init_fn(dataset),
97 | )
98 | else:
99 | # Create a sampler for multi-process training
100 | sampler = utils.create_sampler(dataset, shuffle, cfg)
101 | # Create a loader
102 | loader = torch.utils.data.DataLoader(
103 | dataset,
104 | batch_size=batch_size,
105 | shuffle=(False if sampler else shuffle),
106 | sampler=sampler,
107 | num_workers=cfg.DATA_LOADER.NUM_WORKERS,
108 | pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
109 | drop_last=drop_last,
110 | collate_fn=detection_collate if cfg.DETECTION.ENABLE else None,
111 | worker_init_fn=utils.loader_worker_init_fn(dataset),
112 | )
113 | return loader
114 |
115 |
116 | def shuffle_dataset(loader, cur_epoch):
117 | """
118 | Shuffles the data.
119 | Args:
120 | loader (loader): data loader to perform shuffle.
121 | cur_epoch (int): number of the current epoch.
122 | """
123 | sampler = (
124 | loader.batch_sampler.sampler
125 | if isinstance(loader.batch_sampler, ShortCycleBatchSampler)
126 | else loader.sampler
127 | )
128 | assert isinstance(
129 | sampler, (RandomSampler, DistributedSampler)
130 | ), "Sampler type '{}' not supported".format(type(sampler))
131 | # RandomSampler handles shuffling automatically
132 | if isinstance(sampler, DistributedSampler):
133 | # DistributedSampler shuffles data based on epoch
134 | sampler.set_epoch(cur_epoch)
135 |
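136 | 
137 | # Typical training-loop usage sketch (illustrative comment only; `cfg` must be a
138 | # fully populated CfgNode as described in construct_loader's docstring):
139 | #   train_loader = construct_loader(cfg, "train")
140 | #   for cur_epoch in range(num_epochs):
141 | #       shuffle_dataset(train_loader, cur_epoch)  # re-seed DistributedSampler per epoch
142 | #       for inputs, labels, video_idx, meta in train_loader:
143 | #           ...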
--------------------------------------------------------------------------------
/models/vit_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Ross Wightman
2 | # Various utility functions
3 |
4 | import math
5 | import warnings
6 | from itertools import repeat
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 | TORCH_MAJOR = int(torch.__version__.split('.')[0])
13 | TORCH_MINOR = int(torch.__version__.split('.')[1])
14 | if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
15 | from torch._six import container_abcs, int_classes
16 | else:
17 | import collections.abc as container_abcs
18 |
19 | DEFAULT_CROP_PCT = 0.875
20 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
21 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
22 | IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
23 | IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
24 | IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
25 | IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)
26 |
27 |
28 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
29 | def norm_cdf(x):
30 | # Computes standard normal cumulative distribution function
31 | return (1. + math.erf(x / math.sqrt(2.))) / 2.
32 |
33 | if (mean < a - 2 * std) or (mean > b + 2 * std):
34 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
35 | "The distribution of values may be incorrect.",
36 | stacklevel=2)
37 |
38 | with torch.no_grad():
39 | # Values are generated by using a truncated uniform distribution and
40 | # then using the inverse CDF for the normal distribution.
41 | # Get upper and lower cdf values
42 | l = norm_cdf((a - mean) / std)
43 | u = norm_cdf((b - mean) / std)
44 |
45 | # Uniformly fill tensor with values from [l, u], then translate to
46 | # [2l-1, 2u-1].
47 | tensor.uniform_(2 * l - 1, 2 * u - 1)
48 |
49 | # Use inverse cdf transform for normal distribution to get truncated
50 | # standard normal
51 | tensor.erfinv_()
52 |
53 | # Transform to proper mean, std
54 | tensor.mul_(std * math.sqrt(2.))
55 | tensor.add_(mean)
56 |
57 | # Clamp to ensure it's in the proper range
58 | tensor.clamp_(min=a, max=b)
59 | return tensor
60 |
61 |
62 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
63 | # type: (Tensor, float, float, float, float) -> Tensor
64 | r"""Fills the input Tensor with values drawn from a truncated
65 | normal distribution. The values are effectively drawn from the
66 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
67 | with values outside :math:`[a, b]` redrawn until they are within
68 | the bounds. The method used for generating the random values works
69 | best when :math:`a \leq \text{mean} \leq b`.
70 | Args:
71 | tensor: an n-dimensional `torch.Tensor`
72 | mean: the mean of the normal distribution
73 | std: the standard deviation of the normal distribution
74 | a: the minimum cutoff value
75 | b: the maximum cutoff value
76 | Examples:
77 | >>> w = torch.empty(3, 5)
78 | >>> nn.init.trunc_normal_(w)
79 | """
80 | return _no_grad_trunc_normal_(tensor, mean, std, a, b)
81 |
82 |
83 | # From PyTorch internals
84 | def _ntuple(n):
85 | def parse(x):
86 | if isinstance(x, container_abcs.Iterable):
87 | return x
88 | return tuple(repeat(x, n))
89 |
90 | return parse
91 |
92 |
93 | to_2tuple = _ntuple(2)
94 |
95 |
96 | # Calculate symmetric padding for a convolution
97 | def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
98 | padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
99 | return padding
100 |
101 |
102 | def get_padding_value(padding, kernel_size, **kwargs):
103 | dynamic = False
104 | if isinstance(padding, str):
105 | # for any string padding, the padding will be calculated for you, one of three ways
106 | padding = padding.lower()
107 | if padding == 'same':
108 | # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
109 | if is_static_pad(kernel_size, **kwargs):
110 | # static case, no extra overhead
111 | padding = get_padding(kernel_size, **kwargs)
112 | else:
113 | # dynamic 'SAME' padding, has runtime/GPU memory overhead
114 | padding = 0
115 | dynamic = True
116 | elif padding == 'valid':
117 | # 'VALID' padding, same as padding=0
118 | padding = 0
119 | else:
120 | # Default to PyTorch style 'same'-ish symmetric padding
121 | padding = get_padding(kernel_size, **kwargs)
122 | return padding, dynamic
123 |
124 |
125 | # Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution
126 | def get_same_padding(x: int, k: int, s: int, d: int):
127 | return max((int(math.ceil(x / s)) - 1) * s + (k - 1) * d + 1 - x, 0)
128 |
129 |
130 | # Can SAME padding for given args be done statically?
131 | def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
132 | return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
133 |
134 |
135 | # Dynamically pad input x with 'SAME' padding for conv with specified args
136 | # def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0):
137 | def pad_same(x, k, s, d=(1, 1), value=0):
138 | ih, iw = x.size()[-2:]
139 | pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1])
140 | if pad_h > 0 or pad_w > 0:
141 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value)
142 | return x
143 |
144 |
145 | def adaptive_pool_feat_mult(pool_type='avg'):
146 | if pool_type == 'catavgmax':
147 | return 2
148 | else:
149 | return 1
150 |
151 |
152 | def drop_path(x, drop_prob: float = 0., training: bool = False):
153 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
154 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
155 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
156 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
157 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
158 | 'survival rate' as the argument.
159 | """
160 | if drop_prob == 0. or not training:
161 | return x
162 | keep_prob = 1 - drop_prob
163 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
164 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
165 | random_tensor.floor_() # binarize
166 | output = x.div(keep_prob) * random_tensor
167 | return output
168 |
169 |
170 | class DropPath(nn.Module):
171 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
172 | """
173 |
174 | def __init__(self, drop_prob=None):
175 | super(DropPath, self).__init__()
176 | self.drop_prob = drop_prob
177 |
178 | def forward(self, x):
179 | return drop_path(x, self.drop_prob, self.training)
180 |
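181 | 
182 | # Usage sketch (illustrative addition): truncated-normal initialization plus a
183 | # stochastic-depth residual branch, as commonly wired inside ViT-style blocks.
184 | if __name__ == "__main__":
185 | w = torch.empty(3, 5)
186 | trunc_normal_(w, std=.02)
187 | branch = nn.Sequential(nn.Linear(5, 5), DropPath(drop_prob=0.1))
188 | x = torch.randn(4, 5)
189 | out = x + branch(x)  # in training mode DropPath zeroes the branch per sample with p=0.1
190 | print(w.std().item(), out.shape)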
--------------------------------------------------------------------------------
/utils/meters.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 |
3 | """Meters."""
4 |
5 | import datetime
6 |
7 | import numpy as np
8 | import torch
9 | from fvcore.common.timer import Timer
10 | from sklearn.metrics import average_precision_score
11 |
12 | import utils.logging as logging
13 | import utils.metrics as metrics
14 |
15 | logger = logging.get_logger(__name__)
16 |
17 |
18 | class TestMeter(object):
19 | """
20 | Perform the multi-view ensemble for testing: each video with a unique index
21 | will be sampled with multiple clips, and the predictions of the clips will
22 | be aggregated to produce the final prediction for the video.
23 | The accuracy is calculated with the given ground truth labels.
24 | """
25 |
26 | def __init__(
27 | self,
28 | num_videos,
29 | num_clips,
30 | num_cls,
31 | overall_iters,
32 | multi_label=False,
33 | ensemble_method="sum",
34 | ):
35 | """
36 | Construct tensors to store the predictions and labels. Expect to get
37 | num_clips predictions from each video, and calculate the metrics on
38 | num_videos videos.
39 | Args:
40 | num_videos (int): number of videos to test.
41 | num_clips (int): number of clips sampled from each video for
42 | aggregating the final prediction for the video.
43 | num_cls (int): number of classes for each prediction.
44 | overall_iters (int): overall iterations for testing.
45 | multi_label (bool): if True, use map as the metric.
46 | ensemble_method (str): method to perform the ensemble, options
47 | include "sum", and "max".
48 | """
49 |
50 | self.iter_timer = Timer()
51 | self.data_timer = Timer()
52 | self.net_timer = Timer()
53 | self.num_clips = num_clips
54 | self.overall_iters = overall_iters
55 | self.multi_label = multi_label
56 | self.ensemble_method = ensemble_method
57 | # Initialize tensors.
58 | self.video_preds = torch.zeros((num_videos, num_cls))
59 | if multi_label:
60 | self.video_preds -= 1e10
61 |
62 | self.video_labels = (
63 | torch.zeros((num_videos, num_cls))
64 | if multi_label
65 | else torch.zeros((num_videos)).long()
66 | )
67 | self.clip_count = torch.zeros((num_videos)).long()
68 | self.topk_accs = []
69 | self.stats = {}
70 |
71 | # Reset metric.
72 | self.reset()
73 |
74 | def reset(self):
75 | """
76 | Reset the metric.
77 | """
78 | self.clip_count.zero_()
79 | self.video_preds.zero_()
80 | if self.multi_label:
81 | self.video_preds -= 1e10
82 | self.video_labels.zero_()
83 |
84 | def update_stats(self, preds, labels, clip_ids):
85 | """
86 | Collect the predictions from the current batch and perform on-the-fly
87 | summation as the ensemble.
88 | Args:
89 | preds (tensor): predictions from the current batch. Dimension is
90 | N x C where N is the batch size and C is the channel size
91 | (num_cls).
92 | labels (tensor): the corresponding labels of the current batch.
93 | Dimension is N.
94 | clip_ids (tensor): clip indexes of the current batch, dimension is
95 | N.
96 | """
97 | for ind in range(preds.shape[0]):
98 | vid_id = int(clip_ids[ind]) // self.num_clips
99 | if self.video_labels[vid_id].sum() > 0:
100 | assert torch.equal(
101 | self.video_labels[vid_id].type(torch.FloatTensor),
102 | labels[ind].type(torch.FloatTensor),
103 | )
104 | self.video_labels[vid_id] = labels[ind]
105 | if self.ensemble_method == "sum":
106 | self.video_preds[vid_id] += preds[ind]
107 | elif self.ensemble_method == "max":
108 | self.video_preds[vid_id] = torch.max(
109 | self.video_preds[vid_id], preds[ind]
110 | )
111 | else:
112 | raise NotImplementedError(
113 | "Ensemble Method {} is not supported".format(
114 | self.ensemble_method
115 | )
116 | )
117 | self.clip_count[vid_id] += 1
118 |
119 | def log_iter_stats(self, cur_iter):
120 | """
121 | Log the stats.
122 | Args:
123 | cur_iter (int): the current iteration of testing.
124 | """
125 | eta_sec = self.iter_timer.seconds() * (self.overall_iters - cur_iter)
126 | eta = str(datetime.timedelta(seconds=int(eta_sec)))
127 | stats = {
128 | "split": "test_iter",
129 | "cur_iter": "{}".format(cur_iter + 1),
130 | "eta": eta,
131 | "time_diff": self.iter_timer.seconds(),
132 | }
133 | logging.log_json_stats(stats)
134 |
135 | def iter_tic(self):
136 | """
137 | Start to record time.
138 | """
139 | self.iter_timer.reset()
140 | self.data_timer.reset()
141 |
142 | def iter_toc(self):
143 | """
144 | Stop to record time.
145 | """
146 | self.iter_timer.pause()
147 | self.net_timer.pause()
148 |
149 | def data_toc(self):
150 | self.data_timer.pause()
151 | self.net_timer.reset()
152 |
153 | def finalize_metrics(self, ks=(1, 5)):
154 | """
155 | Calculate and log the final ensembled metrics.
156 | ks (tuple): list of top-k values for topk_accuracies. For example,
157 | ks = (1, 5) corresponds to top-1 and top-5 accuracy.
158 | """
159 | if not all(self.clip_count == self.num_clips):
160 | logger.warning(
161 | "clip count {} ~= num clips {}".format(
162 | ", ".join(
163 | [
164 | "{}: {}".format(i, k)
165 | for i, k in enumerate(self.clip_count.tolist())
166 | ]
167 | ),
168 | self.num_clips,
169 | )
170 | )
171 |
172 | self.stats = {"split": "test_final"}
173 | if self.multi_label:
174 | map = get_map(
175 | self.video_preds.cpu().numpy(), self.video_labels.cpu().numpy()
176 | )
177 | self.stats["map"] = map
178 | else:
179 | num_topks_correct = metrics.topks_correct(
180 | self.video_preds, self.video_labels, ks
181 | )
182 | topks = [
183 | (x / self.video_preds.size(0)) * 100.0
184 | for x in num_topks_correct
185 | ]
186 |
187 | assert len({len(ks), len(topks)}) == 1
188 | for k, topk in zip(ks, topks):
189 | self.stats["top{}_acc".format(k)] = "{:.{prec}f}".format(
190 | topk, prec=2
191 | )
192 | logging.log_json_stats(self.stats)
193 |
194 |
195 | def get_map(preds, labels):
196 | """
197 | Compute mAP for multi-label case.
198 | Args:
199 | preds (numpy tensor): num_examples x num_classes.
200 | labels (numpy tensor): num_examples x num_classes.
201 | Returns:
202 | mean_ap (float): final mAP score.
203 | """
204 |
205 | logger.info("Getting mAP for {} examples".format(preds.shape[0]))
206 |
207 | preds = preds[:, ~(np.all(labels == 0, axis=0))]
208 | labels = labels[:, ~(np.all(labels == 0, axis=0))]
209 | aps = [0]
210 | try:
211 | aps = average_precision_score(labels, preds, average=None)
212 | except ValueError:
213 | print(
214 | "Average precision requires a sufficient number of samples \
215 | in a batch which are missing in this sample."
216 | )
217 |
218 | mean_ap = np.mean(aps)
219 | return mean_ap
220 |
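221 | 
222 | # Usage sketch (illustrative addition): ensemble two clips per video for three
223 | # videos over 10 classes, then log the aggregated top-1/top-5 accuracy.
224 | if __name__ == "__main__":
225 | meter = TestMeter(num_videos=3, num_clips=2, num_cls=10, overall_iters=1)
226 | clip_preds = torch.rand(6, 10)  # clip-level scores, ordered by clip id
227 | clip_labels = torch.tensor([0, 0, 1, 1, 2, 2])  # same label within each video
228 | meter.update_stats(clip_preds, clip_labels, clip_ids=torch.arange(6))
229 | meter.finalize_metrics(ks=(1, 5))  # logs the ensembled JSON stats via utils.logging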
--------------------------------------------------------------------------------
/models/s3d.py:
--------------------------------------------------------------------------------
1 | # modified from https://raw.githubusercontent.com/qijiezhao/s3d.pytorch/master/S3DG_Pytorch.py
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | ## pytorch default: torch.nn.BatchNorm3d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
7 | ## tensorflow s3d code: torch.nn.BatchNorm3d(num_features, eps=1e-3, momentum=0.001, affine=True, track_running_stats=True)
8 |
9 | class BasicConv3d(nn.Module):
10 | def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
11 | super(BasicConv3d, self).__init__()
12 | self.conv = nn.Conv3d(in_planes, out_planes,
13 | kernel_size=kernel_size, stride=stride,
14 | padding=padding, bias=False)
15 |
16 | # self.bn = nn.BatchNorm3d(out_planes, eps=1e-3, momentum=0.001, affine=True)
17 | self.bn = nn.BatchNorm3d(out_planes)
18 | self.relu = nn.ReLU(inplace=True)
19 |
20 | # init
21 | self.conv.weight.data.normal_(mean=0, std=0.01) # original s3d is truncated normal within 2 std
22 | self.bn.weight.data.fill_(1)
23 | self.bn.bias.data.zero_()
24 |
25 | def forward(self, x):
26 | x = self.conv(x)
27 | x = self.bn(x)
28 | x = self.relu(x)
29 | return x
30 |
31 |
32 | class STConv3d(nn.Module):
33 | def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
34 | super(STConv3d, self).__init__()
35 | if isinstance(stride, tuple):
36 | t_stride = stride[0]
37 | stride = stride[-1]
38 | else: # int
39 | t_stride = stride
40 |
41 | self.conv1 = nn.Conv3d(in_planes, out_planes, kernel_size=(1, kernel_size, kernel_size),
42 | stride=(1, stride, stride), padding=(0, padding, padding), bias=False)
43 | self.conv2 = nn.Conv3d(out_planes, out_planes, kernel_size=(kernel_size, 1, 1),
44 | stride=(t_stride, 1, 1), padding=(padding, 0, 0), bias=False)
45 |
46 | # self.bn1=nn.BatchNorm3d(out_planes, eps=1e-3, momentum=0.001, affine=True)
47 | # self.bn2=nn.BatchNorm3d(out_planes, eps=1e-3, momentum=0.001, affine=True)
48 | self.bn1 = nn.BatchNorm3d(out_planes)
49 | self.bn2 = nn.BatchNorm3d(out_planes)
50 | self.relu = nn.ReLU(inplace=True)
51 |
52 | # init
53 | self.conv1.weight.data.normal_(mean=0, std=0.01) # original s3d is truncated normal within 2 std
54 | self.conv2.weight.data.normal_(mean=0, std=0.01) # original s3d is truncated normal within 2 std
55 | self.bn1.weight.data.fill_(1)
56 | self.bn1.bias.data.zero_()
57 | self.bn2.weight.data.fill_(1)
58 | self.bn2.bias.data.zero_()
59 |
60 | def forward(self, x):
61 | x = self.conv1(x)
62 | x = self.bn1(x)
63 | x = self.relu(x)
64 | x = self.conv2(x)
65 | x = self.bn2(x)
66 | x = self.relu(x)
67 | return x
68 |
69 |
70 | class SelfGating(nn.Module):
71 | def __init__(self, input_dim):
72 | super(SelfGating, self).__init__()
73 | self.fc = nn.Linear(input_dim, input_dim)
74 |
75 | def forward(self, input_tensor):
76 | """Feature gating as used in S3D-G"""
77 | spatiotemporal_average = torch.mean(input_tensor, dim=[2, 3, 4])
78 | weights = self.fc(spatiotemporal_average)
79 | weights = torch.sigmoid(weights)
80 | return weights[:, :, None, None, None] * input_tensor
81 |
82 |
83 | class SepInception(nn.Module):
84 | def __init__(self, in_planes, out_planes, gating=False):
85 | super(SepInception, self).__init__()
86 |
87 | assert len(out_planes) == 6
88 | assert isinstance(out_planes, list)
89 |
90 | [num_out_0_0a,
91 | num_out_1_0a, num_out_1_0b,
92 | num_out_2_0a, num_out_2_0b,
93 | num_out_3_0b] = out_planes
94 |
95 | self.branch0 = nn.Sequential(
96 | BasicConv3d(in_planes, num_out_0_0a, kernel_size=1, stride=1),
97 | )
98 | self.branch1 = nn.Sequential(
99 | BasicConv3d(in_planes, num_out_1_0a, kernel_size=1, stride=1),
100 | STConv3d(num_out_1_0a, num_out_1_0b, kernel_size=3, stride=1, padding=1),
101 | )
102 | self.branch2 = nn.Sequential(
103 | BasicConv3d(in_planes, num_out_2_0a, kernel_size=1, stride=1),
104 | STConv3d(num_out_2_0a, num_out_2_0b, kernel_size=3, stride=1, padding=1),
105 | )
106 | self.branch3 = nn.Sequential(
107 | nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1),
108 | BasicConv3d(in_planes, num_out_3_0b, kernel_size=1, stride=1),
109 | )
110 |
111 | self.out_channels = sum([num_out_0_0a, num_out_1_0b, num_out_2_0b, num_out_3_0b])
112 |
113 | self.gating = gating
114 | if gating:
115 | self.gating_b0 = SelfGating(num_out_0_0a)
116 | self.gating_b1 = SelfGating(num_out_1_0b)
117 | self.gating_b2 = SelfGating(num_out_2_0b)
118 | self.gating_b3 = SelfGating(num_out_3_0b)
119 |
120 | def forward(self, x):
121 | x0 = self.branch0(x)
122 | x1 = self.branch1(x)
123 | x2 = self.branch2(x)
124 | x3 = self.branch3(x)
125 | if self.gating:
126 | x0 = self.gating_b0(x0)
127 | x1 = self.gating_b1(x1)
128 | x2 = self.gating_b2(x2)
129 | x3 = self.gating_b3(x3)
130 |
131 | out = torch.cat((x0, x1, x2, x3), 1)
132 |
133 | return out
134 |
135 |
136 | class S3D(nn.Module):
137 |
138 | def __init__(self, input_channel=3, gating=False, slow=False):
139 | super(S3D, self).__init__()
140 | self.gating = gating
141 | self.slow = slow
142 |
143 | if slow:
144 | self.Conv_1a = STConv3d(input_channel, 64, kernel_size=7, stride=(1, 2, 2), padding=3)
145 | else: # normal
146 | self.Conv_1a = STConv3d(input_channel, 64, kernel_size=7, stride=2, padding=3)
147 |
148 | self.block1 = nn.Sequential(self.Conv_1a) # (64, 32, 112, 112)
149 |
150 | ###################################
151 |
152 | self.MaxPool_2a = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))
153 | self.Conv_2b = BasicConv3d(64, 64, kernel_size=1, stride=1)
154 | self.Conv_2c = STConv3d(64, 192, kernel_size=3, stride=1, padding=1)
155 |
156 | self.block2 = nn.Sequential(
157 | self.MaxPool_2a, # (64, 32, 56, 56)
158 | self.Conv_2b, # (64, 32, 56, 56)
159 | self.Conv_2c) # (192, 32, 56, 56)
160 |
161 | ###################################
162 |
163 | self.MaxPool_3a = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))
164 | self.Mixed_3b = SepInception(in_planes=192, out_planes=[64, 96, 128, 16, 32, 32], gating=gating)
165 | self.Mixed_3c = SepInception(in_planes=256, out_planes=[128, 128, 192, 32, 96, 64], gating=gating)
166 |
167 | self.block3 = nn.Sequential(
168 | self.MaxPool_3a, # (192, 32, 28, 28)
169 | self.Mixed_3b, # (256, 32, 28, 28)
170 | self.Mixed_3c) # (480, 32, 28, 28)
171 |
172 | ###################################
173 |
174 | self.MaxPool_4a = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
175 | self.Mixed_4b = SepInception(in_planes=480, out_planes=[192, 96, 208, 16, 48, 64], gating=gating)
176 | self.Mixed_4c = SepInception(in_planes=512, out_planes=[160, 112, 224, 24, 64, 64], gating=gating)
177 | self.Mixed_4d = SepInception(in_planes=512, out_planes=[128, 128, 256, 24, 64, 64], gating=gating)
178 | self.Mixed_4e = SepInception(in_planes=512, out_planes=[112, 144, 288, 32, 64, 64], gating=gating)
179 | self.Mixed_4f = SepInception(in_planes=528, out_planes=[256, 160, 320, 32, 128, 128], gating=gating)
180 |
181 | self.block4 = nn.Sequential(
182 | self.MaxPool_4a, # (480, 16, 14, 14)
183 | self.Mixed_4b, # (512, 16, 14, 14)
184 | self.Mixed_4c, # (512, 16, 14, 14)
185 | self.Mixed_4d, # (512, 16, 14, 14)
186 | self.Mixed_4e, # (528, 16, 14, 14)
187 | self.Mixed_4f) # (832, 16, 14, 14)
188 |
189 | ###################################
190 |
191 | self.MaxPool_5a = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 0, 0))
192 | self.Mixed_5b = SepInception(in_planes=832, out_planes=[256, 160, 320, 32, 128, 128], gating=gating)
193 | self.Mixed_5c = SepInception(in_planes=832, out_planes=[384, 192, 384, 48, 128, 128], gating=gating)
194 |
195 | self.block5 = nn.Sequential(
196 | self.MaxPool_5a, # (832, 8, 7, 7)
197 | self.Mixed_5b, # (832, 8, 7, 7)
198 | self.Mixed_5c) # (1024, 8, 7, 7)
199 |
200 | ###################################
201 |
202 | # self.AvgPool_0a = nn.AvgPool3d(kernel_size=(2, 7, 7), stride=1)
203 | # self.Dropout_0b = nn.Dropout3d(dropout_keep_prob)
204 | # self.Conv_0c = nn.Conv3d(1024, num_classes, kernel_size=1, stride=1, bias=True)
205 |
206 | # self.classifier = nn.Sequential(
207 | # self.AvgPool_0a,
208 | # self.Dropout_0b,
209 | # self.Conv_0c)
210 |
211 | def forward(self, x):
212 | x = self.block1(x)
213 | x = self.block2(x)
214 | x = self.block3(x)
215 | x = self.block4(x)
216 | x = self.block5(x)
217 | return x
218 |
219 |
220 | if __name__ == '__main__':
221 | model = S3D()  # this implementation has no classifier head (see the commented-out classifier above)
222 |
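223 | # Shape sketch (illustrative addition): for a 64-frame 224x224 RGB clip the
224 | # backbone yields the (1024, 8, 7, 7) feature map annotated next to block5.
225 | if __name__ == '__main__':
226 | clip = torch.randn(1, 3, 64, 224, 224)
227 | print(S3D()(clip).shape)  # expected: torch.Size([1, 1024, 8, 7, 7])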
--------------------------------------------------------------------------------
/models/helpers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | # Copyright 2020 Ross Wightman
3 | # Modified model creation / weight loading / state_dict helpers
4 |
5 | import logging
6 | import math
7 | import os
8 | from collections import OrderedDict
9 |
10 | import torch
11 | import torch.nn.functional as F
12 | import torch.utils.model_zoo as model_zoo
13 |
14 | _logger = logging.getLogger(__name__)
15 |
16 |
17 | def load_state_dict(checkpoint_path, use_ema=False):
18 | if checkpoint_path and os.path.isfile(checkpoint_path):
19 | checkpoint = torch.load(checkpoint_path, map_location='cpu')
20 | elif checkpoint_path.startswith("https://"):
21 | checkpoint = torch.hub.load_state_dict_from_url(checkpoint_path, map_location='cpu')
22 | else:
23 | _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
24 | raise FileNotFoundError()
25 |
26 | state_dict_key = 'state_dict'
27 | if isinstance(checkpoint, dict):
28 | if use_ema and 'state_dict_ema' in checkpoint:
29 | state_dict_key = 'state_dict_ema'
30 | if state_dict_key and state_dict_key in checkpoint:
31 | new_state_dict = OrderedDict()
32 | for k, v in checkpoint[state_dict_key].items():
33 | # strip `module.` prefix
34 | name = k[7:] if k.startswith('module') else k
35 | new_state_dict[name] = v
36 | state_dict = new_state_dict
37 | elif 'model_state' in checkpoint:
38 | state_dict_key = 'model_state'
39 | new_state_dict = OrderedDict()
40 | for k, v in checkpoint[state_dict_key].items():
41 | # strip `model.` prefix
42 | name = k[6:] if k.startswith('model') else k
43 | new_state_dict[name] = v
44 | state_dict = new_state_dict
45 | else:
46 | state_dict = checkpoint
47 | _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
48 | return state_dict
49 |
50 |
51 | def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
52 | state_dict = load_state_dict(checkpoint_path, use_ema)
53 | model.load_state_dict(state_dict, strict=strict)
54 |
55 |
56 | def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
57 | resume_epoch = None
58 | if os.path.isfile(checkpoint_path):
59 | checkpoint = torch.load(checkpoint_path, map_location='cpu')
60 | if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
61 | if log_info:
62 | _logger.info('Restoring model state from checkpoint...')
63 | new_state_dict = OrderedDict()
64 | for k, v in checkpoint['state_dict'].items():
65 | name = k[7:] if k.startswith('module') else k
66 | new_state_dict[name] = v
67 | model.load_state_dict(new_state_dict)
68 |
69 | if optimizer is not None and 'optimizer' in checkpoint:
70 | if log_info:
71 | _logger.info('Restoring optimizer state from checkpoint...')
72 | optimizer.load_state_dict(checkpoint['optimizer'])
73 |
74 | if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
75 | if log_info:
76 | _logger.info('Restoring AMP loss scaler state from checkpoint...')
77 | loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])
78 |
79 | if 'epoch' in checkpoint:
80 | resume_epoch = checkpoint['epoch']
81 | if 'version' in checkpoint and checkpoint['version'] > 1:
82 | resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save
83 |
84 | if log_info:
85 | _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
86 | else:
87 | model.load_state_dict(checkpoint)
88 | if log_info:
89 | _logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
90 | return resume_epoch
91 | else:
92 | _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
93 | raise FileNotFoundError()
94 |
95 |
96 | def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, img_size=224, num_frames=8,
97 | num_patches=196, attention_type='divided_space_time', pretrained_model="", strict=True):
98 | if cfg is None:
99 | cfg = getattr(model, 'default_cfg')
100 | if cfg is None or 'url' not in cfg or not cfg['url']:
101 | _logger.warning("Pretrained model URL is invalid, using random initialization.")
102 | return
103 |
104 | if len(pretrained_model) == 0:
105 | state_dict = model_zoo.load_url(cfg['url'], progress=False, map_location='cpu')
106 | else:
107 | try:
108 | state_dict = load_state_dict(pretrained_model)['model']
109 | except Exception:  # checkpoint without a nested 'model' entry
110 | state_dict = load_state_dict(pretrained_model)
111 |
112 | if filter_fn is not None:
113 | state_dict = filter_fn(state_dict)
114 |
115 | if in_chans == 1:
116 | conv1_name = cfg['first_conv']
117 | _logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel' % conv1_name)
118 | conv1_weight = state_dict[conv1_name + '.weight']
119 | conv1_type = conv1_weight.dtype
120 | conv1_weight = conv1_weight.float()
121 | O, I, J, K = conv1_weight.shape
122 | if I > 3:
123 | assert conv1_weight.shape[1] % 3 == 0
124 | # For models with space2depth stems
125 | conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K)
126 | conv1_weight = conv1_weight.sum(dim=2, keepdim=False)
127 | else:
128 | conv1_weight = conv1_weight.sum(dim=1, keepdim=True)
129 | conv1_weight = conv1_weight.to(conv1_type)
130 | state_dict[conv1_name + '.weight'] = conv1_weight
131 | elif in_chans != 3:
132 | conv1_name = cfg['first_conv']
133 | conv1_weight = state_dict[conv1_name + '.weight']
134 | conv1_type = conv1_weight.dtype
135 | conv1_weight = conv1_weight.float()
136 | O, I, J, K = conv1_weight.shape
137 | if I != 3:
138 | _logger.warning('Deleting first conv (%s) from pretrained weights.' % conv1_name)
139 | del state_dict[conv1_name + '.weight']
140 | strict = False
141 | else:
142 | _logger.info('Repeating first conv (%s) weights in channel dim.' % conv1_name)
143 | repeat = int(math.ceil(in_chans / 3))
144 | conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
145 | conv1_weight *= (3 / float(in_chans))
146 | conv1_weight = conv1_weight.to(conv1_type)
147 | state_dict[conv1_name + '.weight'] = conv1_weight
148 |
149 | classifier_name = cfg['classifier']
150 | if num_classes == 1000 and cfg['num_classes'] == 1001:
151 | # special case for imagenet trained models with extra background class in pretrained weights
152 | classifier_weight = state_dict[classifier_name + '.weight']
153 | state_dict[classifier_name + '.weight'] = classifier_weight[1:]
154 | classifier_bias = state_dict[classifier_name + '.bias']
155 | state_dict[classifier_name + '.bias'] = classifier_bias[1:]
156 | elif not (classifier_name + '.weight') in state_dict:
157 | pass
158 | elif num_classes != state_dict[classifier_name + '.weight'].size(0):
159 | # print('Removing the last fully connected layer due to dimensions mismatch ('+str(num_classes)+ ' != '+str(state_dict[classifier_name + '.weight'].size(0))+').', flush=True)
160 | # completely discard fully connected for all other differences between pretrained and created model
161 | del state_dict[classifier_name + '.weight']
162 | del state_dict[classifier_name + '.bias']
163 | strict = False
164 |
165 | ## Resizing the positional embeddings in case they don't match
166 | if num_patches + 1 != state_dict['pos_embed'].size(1):
167 | pos_embed = state_dict['pos_embed']
168 | cls_pos_embed = pos_embed[0, 0, :].unsqueeze(0).unsqueeze(1)
169 | other_pos_embed = pos_embed[0, 1:, :].unsqueeze(0).transpose(1, 2)
170 | new_pos_embed = F.interpolate(other_pos_embed, size=(num_patches), mode='nearest')
171 | new_pos_embed = new_pos_embed.transpose(1, 2)
172 | new_pos_embed = torch.cat((cls_pos_embed, new_pos_embed), 1)
173 | state_dict['pos_embed'] = new_pos_embed
174 |
175 | ## Resizing time embeddings in case they don't match
176 | if 'time_embed' in state_dict and num_frames != state_dict['time_embed'].size(1):
177 | time_embed = state_dict['time_embed'].transpose(1, 2)
178 | new_time_embed = F.interpolate(time_embed, size=(num_frames), mode='nearest')
179 | state_dict['time_embed'] = new_time_embed.transpose(1, 2)
180 |
181 | ## Initializing temporal attention
182 | if attention_type == 'divided_space_time':
183 | new_state_dict = state_dict.copy()
184 | for key in state_dict:
185 | if 'blocks' in key and 'attn' in key:
186 | new_key = key.replace('attn', 'temporal_attn')
187 | if not new_key in state_dict:
188 | new_state_dict[new_key] = state_dict[key]
189 | else:
190 | new_state_dict[new_key] = state_dict[new_key]
191 | if 'blocks' in key and 'norm1' in key:
192 | new_key = key.replace('norm1', 'temporal_norm1')
193 | if not new_key in state_dict:
194 | new_state_dict[new_key] = state_dict[key]
195 | else:
196 | new_state_dict[new_key] = state_dict[new_key]
197 | state_dict = new_state_dict
198 |
199 | ## Loading the weights
200 | msg = model.load_state_dict(state_dict, strict=False)
201 | print(f"Loaded model inside with msg: {msg}")
202 |
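203 | 
204 | # Usage sketch (illustrative addition): round-trip a toy model through
205 | # load_checkpoint. Real checkpoints in this repo would typically go through
206 | # load_pretrained above, which also resizes positional/time embeddings.
207 | if __name__ == "__main__":
208 | import tempfile
209 | from torch import nn
210 | toy = nn.Linear(4, 2)
211 | ckpt_path = os.path.join(tempfile.mkdtemp(), "toy.pth")
212 | torch.save({"state_dict": toy.state_dict()}, ckpt_path)
213 | load_checkpoint(nn.Linear(4, 2), ckpt_path, strict=True)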
--------------------------------------------------------------------------------
/datasets/ssv2.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 |
3 | import json
4 | import numpy as np
5 | import os
6 | import random
7 | from itertools import chain as chain
8 | import torch
9 | import torch.utils.data
10 | from fvcore.common.file_io import PathManager
11 |
12 | import utils.logging as logging
13 |
14 | from . import data_utils as utils
15 | from .build import DATASET_REGISTRY
16 |
17 | logger = logging.get_logger(__name__)
18 |
19 |
20 | @DATASET_REGISTRY.register()
21 | class Ssv2(torch.utils.data.Dataset):
22 | """
23 | Something-Something v2 (SSV2) video loader. Construct the SSV2 video loader,
24 | then sample clips from the videos. For training and validation, a single
25 | clip is randomly sampled from every video with random cropping, scaling, and
26 | flipping. For testing, multiple clips are uniformly sampled from every
27 | video with uniform cropping. For uniform cropping, we take the left, center,
28 | and right crop if the width is larger than height, or take top, center, and
29 | bottom crop if the height is larger than the width.
30 | """
31 |
32 | def __init__(self, cfg, mode, num_retries=10):
33 | """
34 | Load Something-Something V2 data (frame paths, labels, etc. ) to a given
35 | Dataset object. The dataset could be downloaded from Something-Something
36 | official website (https://20bn.com/datasets/something-something).
37 | Please see datasets/DATASET.md for more information about the data format.
38 | Args:
39 | cfg (CfgNode): configs.
40 | mode (string): Options includes `train`, `val`, or `test` mode.
41 | For the train and val mode, the data loader will take data
42 | from the train or val set, and sample one clip per video.
43 | For the test mode, the data loader will take data from test set,
44 | and sample multiple clips per video.
45 | num_retries (int): number of retries for reading frames from disk.
46 | """
47 | # Only support train, val, and test mode.
48 | assert mode in [
49 | "train",
50 | "val",
51 | "test",
52 | ], "Split '{}' not supported for Something-Something V2".format(mode)
53 | self.mode = mode
54 | self.cfg = cfg
55 |
56 | self._video_meta = {}
57 | self._num_retries = num_retries
58 | # For training or validation mode, one single clip is sampled from every
59 | # video. For testing, NUM_ENSEMBLE_VIEWS clips are sampled from every
60 | # video. For every clip, NUM_SPATIAL_CROPS is cropped spatially from
61 | # the frames.
62 | if self.mode in ["train", "val"]:
63 | self._num_clips = 1
64 | elif self.mode in ["test"]:
65 | self._num_clips = (
66 | cfg.TEST.NUM_ENSEMBLE_VIEWS * cfg.TEST.NUM_SPATIAL_CROPS
67 | )
68 |
69 | logger.info("Constructing Something-Something V2 {}...".format(mode))
70 | self._construct_loader()
71 |
72 | def _construct_loader(self):
73 | """
74 | Construct the video loader.
75 | """
76 | # Loading label names.
77 | with PathManager.open(
78 | os.path.join(
79 | self.cfg.DATA.PATH_TO_DATA_DIR,
80 | "something-something-v2-labels.json",
81 | ),
82 | "r",
83 | ) as f:
84 | label_dict = json.load(f)
85 |
86 | # Loading labels.
87 | label_file = os.path.join(
88 | self.cfg.DATA.PATH_TO_DATA_DIR,
89 | "something-something-v2-{}.json".format(
90 | "train" if self.mode == "train" else "validation"
91 | ),
92 | )
93 | with PathManager.open(label_file, "r") as f:
94 | label_json = json.load(f)
95 |
96 | self._video_names = []
97 | self._labels = []
98 | for video in label_json:
99 | video_name = video["id"]
100 | template = video["template"]
101 | template = template.replace("[", "")
102 | template = template.replace("]", "")
103 | label = int(label_dict[template])
104 | self._video_names.append(video_name)
105 | self._labels.append(label)
106 |
107 | path_to_file = os.path.join(
108 | self.cfg.DATA.PATH_TO_DATA_DIR,
109 | "{}.csv".format("train" if self.mode == "train" else "val"),
110 | )
111 | assert PathManager.exists(path_to_file), "{} dir not found".format(
112 | path_to_file
113 | )
114 |
115 | self._path_to_videos, _ = utils.load_image_lists(
116 | path_to_file, self.cfg.DATA.PATH_PREFIX
117 | )
118 |
119 | assert len(self._path_to_videos) == len(self._video_names), (
120 | len(self._path_to_videos),
121 | len(self._video_names),
122 | )
123 |
124 |
125 | # From dict to list.
126 | new_paths, new_labels = [], []
127 | for index in range(len(self._video_names)):
128 | if self._video_names[index] in self._path_to_videos:
129 | new_paths.append(self._path_to_videos[self._video_names[index]])
130 | new_labels.append(self._labels[index])
131 |
132 | self._labels = new_labels
133 | self._path_to_videos = new_paths
134 |
135 | # Extend self when self._num_clips > 1 (during testing).
136 | self._path_to_videos = list(
137 | chain.from_iterable(
138 | [[x] * self._num_clips for x in self._path_to_videos]
139 | )
140 | )
141 | self._labels = list(
142 | chain.from_iterable([[x] * self._num_clips for x in self._labels])
143 | )
144 | self._spatial_temporal_idx = list(
145 | chain.from_iterable(
146 | [
147 | range(self._num_clips)
148 | for _ in range(len(self._path_to_videos))
149 | ]
150 | )
151 | )
152 | logger.info(
153 | "Something-Something V2 dataloader constructed "
154 | " (size: {}) from {}".format(
155 | len(self._path_to_videos), path_to_file
156 | )
157 | )
158 |
159 | def __getitem__(self, index):
160 | """
161 | Given the video index, return the list of frames, label, and video
162 | index if the video frames can be fetched.
163 | Args:
164 | index (int): the video index provided by the pytorch sampler.
165 | Returns:
166 | frames (tensor): the frames sampled from the video. The dimension
167 | is `channel` x `num frames` x `height` x `width`.
168 | label (int): the label of the current video.
169 | index (int): the index of the video.
170 | """
171 | short_cycle_idx = None
172 | # When short cycle is used, input index is a tuple.
173 | if isinstance(index, tuple):
174 | index, short_cycle_idx = index
175 |
176 | if self.mode in ["train", "val"]: #or self.cfg.MODEL.ARCH in ['resformer', 'vit']:
177 | # -1 indicates random sampling.
178 | spatial_sample_index = -1
179 | min_scale = self.cfg.DATA.TRAIN_JITTER_SCALES[0]
180 | max_scale = self.cfg.DATA.TRAIN_JITTER_SCALES[1]
181 | crop_size = self.cfg.DATA.TRAIN_CROP_SIZE
182 | if short_cycle_idx in [0, 1]:
183 | crop_size = int(
184 | round(
185 | self.cfg.MULTIGRID.SHORT_CYCLE_FACTORS[short_cycle_idx]
186 | * self.cfg.MULTIGRID.DEFAULT_S
187 | )
188 | )
189 | if self.cfg.MULTIGRID.DEFAULT_S > 0:
190 | # Decreasing the scale is equivalent to using a larger "span"
191 | # in a sampling grid.
192 | min_scale = int(
193 | round(
194 | float(min_scale)
195 | * crop_size
196 | / self.cfg.MULTIGRID.DEFAULT_S
197 | )
198 | )
199 | elif self.mode in ["test"]:
200 | # spatial_sample_index is in [0, 1, 2]. Corresponding to left,
201 | # center, or right if width is larger than height, and top, middle,
202 | # or bottom if height is larger than width.
203 | spatial_sample_index = (
204 | self._spatial_temporal_idx[index]
205 | % self.cfg.TEST.NUM_SPATIAL_CROPS
206 | )
207 | if self.cfg.TEST.NUM_SPATIAL_CROPS == 1:
208 | spatial_sample_index = 1
209 |
210 | min_scale, max_scale, crop_size = [self.cfg.DATA.TEST_CROP_SIZE] * 3
211 | # The testing is deterministic and no jitter should be performed.
212 | # min_scale, max_scale, and crop_size are expected to be the same.
213 | assert len({min_scale, max_scale, crop_size}) == 1
214 | else:
215 | raise NotImplementedError(
216 | "Does not support {} mode".format(self.mode)
217 | )
218 |
219 | label = self._labels[index]
220 |
221 | num_frames = self.cfg.DATA.NUM_FRAMES
222 | video_length = len(self._path_to_videos[index])
223 |
224 |
225 | seg_size = float(video_length - 1) / num_frames
226 | seq = []
227 | for i in range(num_frames):
228 | start = int(np.round(seg_size * i))
229 | end = int(np.round(seg_size * (i + 1)))
230 | if self.mode == "train":
231 | seq.append(random.randint(start, end))
232 | else:
233 | seq.append((start + end) // 2)
234 |
235 | frames = torch.as_tensor(
236 | utils.retry_load_images(
237 | [self._path_to_videos[index][frame] for frame in seq],
238 | self._num_retries,
239 | )
240 | )
241 |
242 | # Perform color normalization.
243 | frames = utils.tensor_normalize(
244 | frames, self.cfg.DATA.MEAN, self.cfg.DATA.STD
245 | )
246 |
247 | # T H W C -> C T H W.
248 | frames = frames.permute(3, 0, 1, 2)
249 | frames = utils.spatial_sampling(
250 | frames,
251 | spatial_idx=spatial_sample_index,
252 | min_scale=min_scale,
253 | max_scale=max_scale,
254 | crop_size=crop_size,
255 | random_horizontal_flip=self.cfg.DATA.RANDOM_FLIP,
256 | inverse_uniform_sampling=self.cfg.DATA.INV_UNIFORM_SAMPLE,
257 | )
258 | #if not self.cfg.RESFORMER.ACTIVE:
259 | if not self.cfg.MODEL.ARCH in ['vit']:
260 | frames = utils.pack_pathway_output(self.cfg, frames)
261 | else:
262 | # Perform temporal sampling from the fast pathway.
263 | frames = torch.index_select(
264 | frames,
265 | 1,
266 | torch.linspace(
267 | 0, frames.shape[1] - 1, self.cfg.DATA.NUM_FRAMES
268 |
269 | ).long(),
270 | )
271 | return frames, label, index, {}
272 |
273 | def __len__(self):
274 | """
275 | Returns:
276 | (int): the number of videos in the dataset.
277 | """
278 | return len(self._path_to_videos)
279 |
--------------------------------------------------------------------------------
/eval_knn.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import argparse
16 | import os
17 | import torch
18 | import torch.backends.cudnn as cudnn
19 | import torch.distributed as dist
20 | import torch.utils.data
21 | from torch import nn
22 |
23 | from datasets.hmdb51 import HMDB51
24 | from datasets.ucf101 import UCF101
25 | from models import get_vit_base_patch16_224
26 | from utils import utils
27 | from utils.parser import load_config
28 |
29 |
30 | def extract_feature_pipeline(args):
31 | # ============ preparing data ... ============
32 | config = load_config(args)
33 | # config.DATA.PATH_TO_DATA_DIR = f"{os.path.expanduser('~')}/repo/mmaction2/data/{args.dataset}/knn_splits"
34 | # config.DATA.PATH_PREFIX = f"{os.path.expanduser('~')}/repo/mmaction2/data/{args.dataset}/videos"
35 | config.TEST.NUM_SPATIAL_CROPS = 1
36 | if args.dataset == "ucf101":
37 | dataset_train = UCFReturnIndexDataset(cfg=config, mode="train", num_retries=10)
38 | dataset_val = UCFReturnIndexDataset(cfg=config, mode="val", num_retries=10)
39 | elif args.dataset == "hmdb51":
40 | dataset_train = HMDBReturnIndexDataset(cfg=config, mode="train", num_retries=10)
41 | dataset_val = HMDBReturnIndexDataset(cfg=config, mode="val", num_retries=10)
42 | else:
43 | raise NotImplementedError(f"invalid dataset: {args.dataset}")
44 |
45 | sampler = torch.utils.data.DistributedSampler(dataset_train, shuffle=False)
46 | data_loader_train = torch.utils.data.DataLoader(
47 | dataset_train,
48 | sampler=sampler,
49 | batch_size=args.batch_size_per_gpu,
50 | num_workers=args.num_workers,
51 | pin_memory=True,
52 | drop_last=False,
53 | )
54 | data_loader_val = torch.utils.data.DataLoader(
55 | dataset_val,
56 | batch_size=args.batch_size_per_gpu,
57 | num_workers=args.num_workers,
58 | pin_memory=True,
59 | drop_last=False,
60 | )
61 | print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.")
62 |
63 | # ============ building network ... ============
64 | model = get_vit_base_patch16_224(cfg=config, no_head=True)
65 | ckpt = torch.load(args.pretrained_weights)
66 | # select_ckpt = "teacher"
67 | renamed_checkpoint = {x[len("backbone."):]: y for x, y in ckpt.items() if x.startswith("backbone.")}
68 | msg = model.load_state_dict(renamed_checkpoint, strict=False)
69 | print(f"Loaded model with msg: {msg}")
70 | model.cuda()
71 | model.eval()
72 |
73 | # ============ extract features ... ============
74 | print("Extracting features for train set...")
75 | train_features = extract_features(model, data_loader_train)
76 | print("Extracting features for val set...")
77 | test_features = extract_features(model, data_loader_val)
78 |
79 | if utils.get_rank() == 0:
80 | train_features = nn.functional.normalize(train_features, dim=1, p=2)
81 | test_features = nn.functional.normalize(test_features, dim=1, p=2)
82 |
83 | train_labels = torch.tensor([s for s in dataset_train._labels]).long()
84 | test_labels = torch.tensor([s for s in dataset_val._labels]).long()
85 | # save features and labels
86 | if args.dump_features and dist.get_rank() == 0:
87 | torch.save(train_features.cpu(), os.path.join(args.dump_features, "trainfeat.pth"))
88 | torch.save(test_features.cpu(), os.path.join(args.dump_features, "testfeat.pth"))
89 | torch.save(train_labels.cpu(), os.path.join(args.dump_features, "trainlabels.pth"))
90 | torch.save(test_labels.cpu(), os.path.join(args.dump_features, "testlabels.pth"))
91 | return train_features, test_features, train_labels, test_labels
92 |
93 |
94 | @torch.no_grad()
95 | def extract_features(model, data_loader):
96 | metric_logger = utils.MetricLogger(delimiter=" ")
97 | features = None
98 | for samples, index in metric_logger.log_every(data_loader, 10):
99 | samples = samples.cuda(non_blocking=True)
100 | index = index.cuda(non_blocking=True)
101 | feats = model(samples).clone()
102 |
103 | # init storage feature matrix
104 | if dist.get_rank() == 0 and features is None:
105 | features = torch.zeros(len(data_loader.dataset), feats.shape[-1])
106 | # if args.use_cuda:
107 | features = features.cuda(non_blocking=True)
108 | print(f"Storing features into tensor of shape {features.shape}")
109 |
110 | # get indexes from all processes
111 | y_all = torch.empty(dist.get_world_size(), index.size(0), dtype=index.dtype, device=index.device)
112 | y_l = list(y_all.unbind(0))
113 | y_all_reduce = torch.distributed.all_gather(y_l, index, async_op=True)
114 | y_all_reduce.wait()
115 | index_all = torch.cat(y_l)
116 |
117 | # share features between processes
118 | feats_all = torch.empty(
119 | dist.get_world_size(),
120 | feats.size(0),
121 | feats.size(1),
122 | dtype=feats.dtype,
123 | device=feats.device,
124 | )
125 | output_l = list(feats_all.unbind(0))
126 | output_all_reduce = torch.distributed.all_gather(output_l, feats, async_op=True)
127 | output_all_reduce.wait()
128 |
129 | # update storage feature matrix
130 | if dist.get_rank() == 0:
131 | # if args.use_cuda:
132 | features.index_copy_(0, index_all, torch.cat(output_l))
133 | # else:
134 | # features.index_copy_(0, index_all.cpu(), torch.cat(output_l).cpu())
135 | return features
136 |
137 |
138 | @torch.no_grad()
139 | def knn_classifier(train_features, train_labels, test_features, test_labels, k, T, num_classes=1000):
140 | top1, top5, total = 0.0, 0.0, 0
141 | train_features = train_features.t()
142 | num_test_images, num_chunks = test_labels.shape[0], 100
143 |     imgs_per_chunk = max(num_test_images // num_chunks, 1)  # avoid a zero step when there are fewer than 100 test clips
144 | retrieval_one_hot = torch.zeros(k, num_classes).cuda()
145 | for idx in range(0, num_test_images, imgs_per_chunk):
146 | # get the features for test images
147 | features = test_features[
148 | idx : min((idx + imgs_per_chunk), num_test_images), :
149 | ]
150 | targets = test_labels[idx : min((idx + imgs_per_chunk), num_test_images)]
151 | batch_size = targets.shape[0]
152 |
153 | # calculate the dot product and compute top-k neighbors
154 | similarity = torch.mm(features, train_features)
155 | distances, indices = similarity.topk(k, largest=True, sorted=True)
156 | candidates = train_labels.view(1, -1).expand(batch_size, -1)
157 | retrieved_neighbors = torch.gather(candidates, 1, indices)
158 |
159 | retrieval_one_hot.resize_(batch_size * k, num_classes).zero_()
160 | retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1)
161 | distances_transform = distances.clone().div_(T).exp_()
162 | probs = torch.sum(
163 | torch.mul(
164 | retrieval_one_hot.view(batch_size, -1, num_classes),
165 | distances_transform.view(batch_size, -1, 1),
166 | ),
167 | 1,
168 | )
169 | _, predictions = probs.sort(1, True)
170 |
171 | # find the predictions that match the target
172 | correct = predictions.eq(targets.data.view(-1, 1))
173 | top1 = top1 + correct.narrow(1, 0, 1).sum().item()
174 |         top5 = top5 + correct.narrow(1, 0, min(5, k)).sum().item()  # top-5 is only meaningful when k >= 5
175 | total += targets.size(0)
176 | top1 = top1 * 100.0 / total
177 | top5 = top5 * 100.0 / total
178 | return top1, top5
179 |
180 |
181 | class UCFReturnIndexDataset(UCF101):
182 | def __getitem__(self, idx):
183 | img, _, _, _ = super(UCFReturnIndexDataset, self).__getitem__(idx)
184 | return img, idx
185 |
186 |
187 | class HMDBReturnIndexDataset(HMDB51):
188 | def __getitem__(self, idx):
189 | img, _, _, _ = super(HMDBReturnIndexDataset, self).__getitem__(idx)
190 | return img, idx
191 |
192 | if __name__ == '__main__':
193 |     parser = argparse.ArgumentParser('Evaluation with weighted k-NN on video datasets (UCF101 / HMDB51)')
194 | parser.add_argument('--batch_size_per_gpu', default=128, type=int, help='Per-GPU batch-size')
195 | parser.add_argument('--nb_knn', default=[10, 20, 100, 200], nargs='+', type=int,
196 |         help='Number of nearest neighbours to use; 20 usually works best.')
197 | parser.add_argument('--temperature', default=0.07, type=float,
198 | help='Temperature used in the voting coefficient')
199 | parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
200 | parser.add_argument('--use_cuda', default=True, type=utils.bool_flag,
201 | help="Should we store the features on GPU? We recommend setting this to False if you encounter OOM")
202 | parser.add_argument('--arch', default='vit_small', type=str,
203 |         choices=['vit_tiny', 'vit_small', 'vit_base', 'timesformer'], help='Architecture (only ViT variants are supported at the moment).')
204 | parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
205 | parser.add_argument("--checkpoint_key", default="teacher", type=str,
206 | help='Key to use in the checkpoint (example: "teacher")')
207 | parser.add_argument('--dump_features', default=None,
208 | help='Path where to save computed features, empty for no saving')
209 | parser.add_argument('--load_features', default=None, help="""If the features have
210 | already been computed, where to find them.""")
211 | parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
212 | parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
213 | distributed training; see https://pytorch.org/docs/stable/distributed.html""")
214 | parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
215 | parser.add_argument('--data_path', default='/path/to/imagenet/', type=str)
216 |
217 | parser.add_argument('--dataset', default="ucf101", help='Dataset: ucf101 / hmdb51')
218 | parser.add_argument("--cfg", dest="cfg_file", help="Path to the config file", type=str,
219 | default="models/configs/Kinetics/TimeSformer_divST_8x32_224.yaml")
220 | parser.add_argument("--opts", help="See utils/defaults.py for all options", default=None, nargs=argparse.REMAINDER)
221 |
222 | args = parser.parse_args()
223 |
224 | utils.init_distributed_mode(args)
225 | print("git:\n {}\n".format(utils.get_sha()))
226 | print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
227 | cudnn.benchmark = True
228 |
229 | if args.load_features:
230 | train_features = torch.load(os.path.join(args.load_features, "trainfeat.pth"))
231 | test_features = torch.load(os.path.join(args.load_features, "testfeat.pth"))
232 | train_labels = torch.load(os.path.join(args.load_features, "trainlabels.pth"))
233 | test_labels = torch.load(os.path.join(args.load_features, "testlabels.pth"))
234 | else:
235 | # need to extract features !
236 | train_features, test_features, train_labels, test_labels = extract_feature_pipeline(args)
237 |
238 | if utils.get_rank() == 0:
239 | if args.use_cuda:
240 | train_features = train_features.cuda()
241 | test_features = test_features.cuda()
242 | train_labels = train_labels.cuda()
243 | test_labels = test_labels.cuda()
244 |
245 | print("Features are ready!\nStart the k-NN classification.")
246 | for k in args.nb_knn:
247 | top1, top5 = knn_classifier(train_features, train_labels,
248 | test_features, test_labels, k, args.temperature)
249 | print(f"{k}-NN classifier result: Top1: {top1}, Top5: {top5}")
250 | dist.barrier()
251 |
--------------------------------------------------------------------------------
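
A note on the weighted k-NN voting in `knn_classifier` above: the chunking and the `resize_`/`scatter_` bookkeeping can obscure the underlying idea. Below is a minimal, self-contained sketch of the same temperature-weighted soft voting on toy tensors; the feature dimension, class count, k, and temperature are illustrative assumptions, not values from this repository.

# Minimal sketch of the soft-voting scheme in knn_classifier (toy data only).
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_train, num_test, dim, num_classes, k, T = 50, 4, 8, 5, 3, 0.07  # assumed toy sizes

train_features = F.normalize(torch.randn(num_train, dim), dim=1)
test_features = F.normalize(torch.randn(num_test, dim), dim=1)
train_labels = torch.randint(0, num_classes, (num_train,))

similarity = test_features @ train_features.t()        # cosine similarity, (num_test, num_train)
distances, indices = similarity.topk(k, dim=1)         # k nearest training clips per test clip
neighbor_labels = train_labels[indices]                # (num_test, k)

weights = (distances / T).exp()                        # temperature-scaled similarities as vote weights
votes = F.one_hot(neighbor_labels, num_classes).float() * weights.unsqueeze(-1)
probs = votes.sum(dim=1)                               # (num_test, num_classes)
predictions = probs.argmax(dim=1)
print(predictions)
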
/datasets/ucf101.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import warnings
4 |
5 | import torch
6 | import torch.utils.data
7 |
8 | from datasets.data_utils import get_random_sampling_rate, tensor_normalize, spatial_sampling, pack_pathway_output
9 | from datasets.decoder import decode
10 | from datasets.video_container import get_video_container
11 | from datasets.transform import VideoDataAugmentationDINO
12 | from einops import rearrange
13 |
14 |
15 | class UCF101(torch.utils.data.Dataset):
16 | """
17 | UCF101 video loader. Construct the UCF101 video loader, then sample
18 | clips from the videos. For training and validation, a single clip is
19 | randomly sampled from every video with random cropping, scaling, and
20 |     flipping. For testing, multiple clips are uniformly sampled from every
21 | video with uniform cropping. For uniform cropping, we take the left, center,
22 | and right crop if the width is larger than height, or take top, center, and
23 | bottom crop if the height is larger than the width.
24 | """
25 |
26 | def __init__(self, cfg, mode, num_retries=10):
27 | """
28 | Construct the UCF101 video loader with a given csv file. The format of
29 | the csv file is:
30 | ```
31 | path_to_video_1 label_1
32 | path_to_video_2 label_2
33 | ...
34 | path_to_video_N label_N
35 | ```
36 | Args:
37 | cfg (CfgNode): configs.
38 |             mode (string): Options include `train`, `val`, or `test` mode.
39 | For the train mode, the data loader will take data from the
40 | train set, and sample one clip per video. For the val and
41 |                 test mode, the data loader will take data from the relevant set,
42 | and sample multiple clips per video.
43 | num_retries (int): number of retries.
44 | """
45 | # Only support train, val, and test mode.
46 | assert mode in ["train", "val", "test"], "Split '{}' not supported for UCF101".format(mode)
47 | self.mode = mode
48 | self.cfg = cfg
49 |
50 | self._video_meta = {}
51 | self._num_retries = num_retries
52 | self._split_idx = mode
53 | # For training mode, one single clip is sampled from every video. For validation or testing, NUM_ENSEMBLE_VIEWS
54 | # clips are sampled from every video. For every clip, NUM_SPATIAL_CROPS is cropped spatially from the frames.
55 | if self.mode in ["train"]:
56 | self._num_clips = 1
57 | elif self.mode in ["val", "test"]:
58 | self._num_clips = (
59 | cfg.TEST.NUM_ENSEMBLE_VIEWS * cfg.TEST.NUM_SPATIAL_CROPS
60 | )
61 |
62 | print("Constructing UCF101 {}...".format(mode))
63 | self._construct_loader()
64 |
65 | def _construct_loader(self):
66 | """
67 | Construct the video loader.
68 | """
69 | path_to_file = os.path.join(
70 | self.cfg.DATA.PATH_TO_DATA_DIR, "ucf101_{}_split_1_videos.txt".format(self.mode)
71 | )
72 | assert os.path.exists(path_to_file), "{} dir not found".format(
73 | path_to_file
74 | )
75 |
76 | self._path_to_videos = []
77 | self._labels = []
78 | self._spatial_temporal_idx = []
79 | with open(path_to_file, "r") as f:
80 | for clip_idx, path_label in enumerate(f.read().splitlines()):
81 | assert (
82 | len(path_label.split(self.cfg.DATA.PATH_LABEL_SEPARATOR))
83 | == 2
84 | )
85 | path, label = path_label.split(
86 | self.cfg.DATA.PATH_LABEL_SEPARATOR
87 | )
88 | for idx in range(self._num_clips):
89 | self._path_to_videos.append(
90 | os.path.join(self.cfg.DATA.PATH_PREFIX, path)
91 | )
92 | self._labels.append(int(label))
93 | self._spatial_temporal_idx.append(idx)
94 | self._video_meta[clip_idx * self._num_clips + idx] = {}
95 | assert (len(self._path_to_videos) > 0), f"Failed to load UCF101 split {self._split_idx} from {path_to_file}"
96 | print(f"Constructing UCF101 dataloader (size: {len(self._path_to_videos)}) from {path_to_file}")
97 |
98 | def __getitem__(self, index):
99 | """
100 | Given the video index, return the list of frames, label, and video
101 | index if the video can be fetched and decoded successfully, otherwise
102 |         repeatedly find a random video that can be decoded as a replacement.
103 | Args:
104 | index (int): the video index provided by the pytorch sampler.
105 | Returns:
106 |             frames (tensor): the frames sampled from the video. The dimension
107 | is `channel` x `num frames` x `height` x `width`.
108 | label (int): the label of the current video.
109 | index (int): if the video provided by pytorch sampler can be
110 | decoded, then return the index of the video. If not, return the
111 | index of the video replacement that can be decoded.
112 | """
113 | short_cycle_idx = None
114 |         # When short cycle is used, input index is a tuple.
115 | if isinstance(index, tuple):
116 | index, short_cycle_idx = index
117 |
118 | if self.mode in ["train"]:
119 | # -1 indicates random sampling.
120 | temporal_sample_index = -1
121 | spatial_sample_index = -1
122 | min_scale = self.cfg.DATA.TRAIN_JITTER_SCALES[0]
123 | max_scale = self.cfg.DATA.TRAIN_JITTER_SCALES[1]
124 | crop_size = self.cfg.DATA.TRAIN_CROP_SIZE
125 | if short_cycle_idx in [0, 1]:
126 | crop_size = int(
127 | round(
128 | self.cfg.MULTIGRID.SHORT_CYCLE_FACTORS[short_cycle_idx]
129 | * self.cfg.MULTIGRID.DEFAULT_S
130 | )
131 | )
132 | if self.cfg.MULTIGRID.DEFAULT_S > 0:
133 | # Decreasing the scale is equivalent to using a larger "span"
134 | # in a sampling grid.
135 | min_scale = int(
136 | round(
137 | float(min_scale)
138 | * crop_size
139 | / self.cfg.MULTIGRID.DEFAULT_S
140 | )
141 | )
142 | elif self.mode in ["val", "test"]:
143 | temporal_sample_index = (self._spatial_temporal_idx[index] // self.cfg.TEST.NUM_SPATIAL_CROPS)
144 | # spatial_sample_index is in [0, 1, 2]. Corresponding to left,
145 | # center, or right if width is larger than height, and top, middle,
146 | # or bottom if height is larger than width.
147 | spatial_sample_index = (
148 | (self._spatial_temporal_idx[index] % self.cfg.TEST.NUM_SPATIAL_CROPS)
149 | if self.cfg.TEST.NUM_SPATIAL_CROPS > 1 else 1
150 | )
151 | min_scale, max_scale, crop_size = (
152 | [self.cfg.DATA.TEST_CROP_SIZE] * 3 if self.cfg.TEST.NUM_SPATIAL_CROPS > 1
153 | else [self.cfg.DATA.TRAIN_JITTER_SCALES[0]] * 2 + [self.cfg.DATA.TEST_CROP_SIZE]
154 | )
155 | # The testing is deterministic and no jitter should be performed.
156 |             # min_scale, max_scale, and crop_size are expected to be the same.
157 | assert len({min_scale, max_scale}) == 1
158 | else:
159 | raise NotImplementedError(
160 | "Does not support {} mode".format(self.mode)
161 | )
162 | sampling_rate = get_random_sampling_rate(
163 | self.cfg.MULTIGRID.LONG_CYCLE_SAMPLING_RATE,
164 | self.cfg.DATA.SAMPLING_RATE,
165 | )
166 | # Try to decode and sample a clip from a video. If the video can not be
167 | # decoded, repeatedly find a random video replacement that can be decoded.
168 | for i_try in range(self._num_retries):
169 | video_container = None
170 | try:
171 | video_container = get_video_container(
172 | self._path_to_videos[index],
173 | self.cfg.DATA_LOADER.ENABLE_MULTI_THREAD_DECODE,
174 | self.cfg.DATA.DECODING_BACKEND,
175 | )
176 | except Exception as e:
177 | print(
178 | "Failed to load video from {} with error {}".format(
179 | self._path_to_videos[index], e
180 | )
181 | )
182 |             # Select a random video if the current video could not be accessed.
183 | if video_container is None:
184 | warnings.warn(
185 | "Failed to meta load video idx {} from {}; trial {}".format(
186 | index, self._path_to_videos[index], i_try
187 | )
188 | )
189 | if self.mode not in ["val", "test"] and i_try > self._num_retries // 2:
190 | # let's try another one
191 | index = random.randint(0, len(self._path_to_videos) - 1)
192 | continue
193 |
194 | # Decode video. Meta info is used to perform selective decoding.
195 | frames = decode(
196 | container=video_container,
197 | sampling_rate=sampling_rate,
198 | num_frames=self.cfg.DATA.NUM_FRAMES,
199 | clip_idx=temporal_sample_index,
200 | num_clips=self.cfg.TEST.NUM_ENSEMBLE_VIEWS,
201 | video_meta=self._video_meta[index],
202 | target_fps=self.cfg.DATA.TARGET_FPS,
203 | backend=self.cfg.DATA.DECODING_BACKEND,
204 | max_spatial_scale=min_scale,
205 | )
206 |
207 |             # If decoding failed (wrong format, video is too short, etc.),
208 | # select another video.
209 | if frames is None:
210 | warnings.warn(
211 | "Failed to decode video idx {} from {}; trial {}".format(
212 | index, self._path_to_videos[index], i_try
213 | )
214 | )
215 | if self.mode not in ["test"] and i_try > self._num_retries // 2:
216 | # let's try another one
217 | index = random.randint(0, len(self._path_to_videos) - 1)
218 | continue
219 |
220 | label = self._labels[index]
221 |
222 | # Perform color normalization.
223 | frames = tensor_normalize(
224 | frames, self.cfg.DATA.MEAN, self.cfg.DATA.STD
225 | )
226 | frames = frames.permute(3, 0, 1, 2)
227 |
228 | # Perform data augmentation.
229 | frames = spatial_sampling(
230 | frames,
231 | spatial_idx=spatial_sample_index,
232 | min_scale=min_scale,
233 | max_scale=max_scale,
234 | crop_size=crop_size,
235 | random_horizontal_flip=self.cfg.DATA.RANDOM_FLIP,
236 | inverse_uniform_sampling=self.cfg.DATA.INV_UNIFORM_SAMPLE,
237 | )
238 |
239 | # if not self.cfg.MODEL.ARCH in ['vit']:
240 | # frames = pack_pathway_output(self.cfg, frames)
241 | # else:
242 | # Perform temporal sampling from the fast pathway.
243 | # frames = [torch.index_select(
244 | # x,
245 | # 1,
246 | # torch.linspace(
247 | # 0, x.shape[1] - 1, self.cfg.DATA.NUM_FRAMES
248 | # ).long(),
249 | # ) for x in frames]
250 |
251 | return frames, label, index, {}
252 | else:
253 | raise RuntimeError(
254 | "Failed to fetch video after {} retries.".format(
255 | self._num_retries
256 | )
257 | )
258 |
259 | def __len__(self):
260 | """
261 | Returns:
262 | (int): the number of videos in the dataset.
263 | """
264 | return len(self._path_to_videos)
265 |
266 |
267 | if __name__ == '__main__':
268 |
269 | from utils.parser import parse_args, load_config
270 | from tqdm import tqdm
271 |
272 | args = parse_args()
273 | args.cfg_file = "models/configs/Kinetics/TimeSformer_divST_8x32_224.yaml"
274 | config = load_config(args)
275 | config.DATA.PATH_TO_DATA_DIR = "/home/kanchanaranasinghe/repo/mmaction2/data/ucf101/splits"
276 | config.DATA.PATH_PREFIX = "/home/kanchanaranasinghe/repo/mmaction2/data/ucf101/videos"
277 | dataset = UCF101(cfg=config, mode="train", num_retries=10)
278 | dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=4)
279 | print(f"Loaded train dataset of length: {len(dataset)}")
280 | for idx, i in enumerate(dataloader):
281 | print(idx, i[0].shape, i[1:])
282 | if idx > 2:
283 | break
284 |
285 | test_dataset = UCF101(cfg=config, mode="val", num_retries=10)
286 | test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=4)
287 | print(f"Loaded test dataset of length: {len(test_dataset)}")
288 | for idx, i in enumerate(test_dataloader):
289 | print(idx, i[0].shape, i[1:])
290 | if idx > 2:
291 | break
292 |
--------------------------------------------------------------------------------
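
In val/test mode, `UCF101` above duplicates every video `NUM_ENSEMBLE_VIEWS * NUM_SPATIAL_CROPS` times and later decodes `_spatial_temporal_idx` back into a temporal view and a spatial crop inside `__getitem__`. The short sketch below replays that index arithmetic; the view and crop counts are example values, not taken from any config in this repository.

# How a flat clip index maps to (temporal view, spatial crop) in val/test mode.
NUM_ENSEMBLE_VIEWS = 2   # assumed number of temporal clips per video
NUM_SPATIAL_CROPS = 3    # assumed crops: left/centre/right or top/centre/bottom

num_clips = NUM_ENSEMBLE_VIEWS * NUM_SPATIAL_CROPS
for spatial_temporal_idx in range(num_clips):
    temporal_sample_index = spatial_temporal_idx // NUM_SPATIAL_CROPS
    spatial_sample_index = spatial_temporal_idx % NUM_SPATIAL_CROPS
    print(spatial_temporal_idx, "->", (temporal_sample_index, spatial_sample_index))
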
/datasets/hmdb51.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import warnings
4 |
5 | import torch
6 | import torch.utils.data
7 |
8 | from datasets.data_utils import get_random_sampling_rate, tensor_normalize, spatial_sampling, pack_pathway_output
9 | from datasets.decoder import decode
10 | from datasets.video_container import get_video_container
11 | from datasets.transform import VideoDataAugmentationDINO
12 | from einops import rearrange
13 |
14 |
15 | class HMDB51(torch.utils.data.Dataset):
16 | """
17 |     HMDB51 video loader. Construct the HMDB51 video loader, then sample
18 | clips from the videos. For training and validation, a single clip is
19 | randomly sampled from every video with random cropping, scaling, and
20 |     flipping. For testing, multiple clips are uniformly sampled from every
21 | video with uniform cropping. For uniform cropping, we take the left, center,
22 | and right crop if the width is larger than height, or take top, center, and
23 | bottom crop if the height is larger than the width.
24 | """
25 |
26 | def __init__(self, cfg, mode, num_retries=10):
27 | """
28 |         Construct the HMDB51 video loader with a given csv file. The format of
29 | the csv file is:
30 | ```
31 | path_to_video_1 label_1
32 | path_to_video_2 label_2
33 | ...
34 | path_to_video_N label_N
35 | ```
36 | Args:
37 | cfg (CfgNode): configs.
38 |             mode (string): Options include `train`, `val`, or `test` mode.
39 | For the train mode, the data loader will take data from the
40 | train set, and sample one clip per video. For the val and
41 |                 test mode, the data loader will take data from the relevant set,
42 | and sample multiple clips per video.
43 | num_retries (int): number of retries.
44 | """
45 | # Only support train, val, and test mode.
46 |         assert mode in ["train", "val", "test"], "Split '{}' not supported for HMDB51".format(mode)
47 | self.mode = mode
48 | self.cfg = cfg
49 |
50 | self._video_meta = {}
51 | self._num_retries = num_retries
52 | self._split_idx = mode
53 | # For training mode, one single clip is sampled from every video. For validation or testing, NUM_ENSEMBLE_VIEWS
54 | # clips are sampled from every video. For every clip, NUM_SPATIAL_CROPS is cropped spatially from the frames.
55 | if self.mode in ["train"]:
56 | self._num_clips = 1
57 | elif self.mode in ["val", "test"]:
58 | self._num_clips = (
59 | cfg.TEST.NUM_ENSEMBLE_VIEWS * cfg.TEST.NUM_SPATIAL_CROPS
60 | )
61 |
62 | print("Constructing HMDB51 {}...".format(mode))
63 | self._construct_loader()
64 |
65 | def _construct_loader(self):
66 | """
67 | Construct the video loader.
68 | """
69 | path_to_file = os.path.join(
70 | self.cfg.DATA.PATH_TO_DATA_DIR, "hmdb51_{}_split_1_videos.txt".format(self.mode)
71 | )
72 | assert os.path.exists(path_to_file), "{} dir not found".format(
73 | path_to_file
74 | )
75 |
76 | self._path_to_videos = []
77 | self._labels = []
78 | self._spatial_temporal_idx = []
79 | with open(path_to_file, "r") as f:
80 | for clip_idx, path_label in enumerate(f.read().splitlines()):
81 | assert (
82 | len(path_label.split(self.cfg.DATA.PATH_LABEL_SEPARATOR))
83 | == 2
84 | )
85 | path, label = path_label.split(
86 | self.cfg.DATA.PATH_LABEL_SEPARATOR
87 | )
88 | for idx in range(self._num_clips):
89 | self._path_to_videos.append(
90 | os.path.join(self.cfg.DATA.PATH_PREFIX, path)
91 | )
92 | self._labels.append(int(label))
93 | self._spatial_temporal_idx.append(idx)
94 | self._video_meta[clip_idx * self._num_clips + idx] = {}
95 |         assert (len(self._path_to_videos) > 0), f"Failed to load HMDB51 split {self._split_idx} from {path_to_file}"
96 | print(f"Constructing HMDB51 dataloader (size: {len(self._path_to_videos)}) from {path_to_file}")
97 |
98 | def __getitem__(self, index):
99 | """
100 | Given the video index, return the list of frames, label, and video
101 | index if the video can be fetched and decoded successfully, otherwise
102 |         repeatedly find a random video that can be decoded as a replacement.
103 | Args:
104 | index (int): the video index provided by the pytorch sampler.
105 | Returns:
106 |             frames (tensor): the frames sampled from the video. The dimension
107 | is `channel` x `num frames` x `height` x `width`.
108 | label (int): the label of the current video.
109 | index (int): if the video provided by pytorch sampler can be
110 | decoded, then return the index of the video. If not, return the
111 | index of the video replacement that can be decoded.
112 | """
113 | short_cycle_idx = None
114 |         # When short cycle is used, input index is a tuple.
115 | if isinstance(index, tuple):
116 | index, short_cycle_idx = index
117 |
118 | if self.mode in ["train"]:
119 | # -1 indicates random sampling.
120 | temporal_sample_index = -1
121 | spatial_sample_index = -1
122 | min_scale = self.cfg.DATA.TRAIN_JITTER_SCALES[0]
123 | max_scale = self.cfg.DATA.TRAIN_JITTER_SCALES[1]
124 | crop_size = self.cfg.DATA.TRAIN_CROP_SIZE
125 | if short_cycle_idx in [0, 1]:
126 | crop_size = int(
127 | round(
128 | self.cfg.MULTIGRID.SHORT_CYCLE_FACTORS[short_cycle_idx]
129 | * self.cfg.MULTIGRID.DEFAULT_S
130 | )
131 | )
132 | if self.cfg.MULTIGRID.DEFAULT_S > 0:
133 | # Decreasing the scale is equivalent to using a larger "span"
134 | # in a sampling grid.
135 | min_scale = int(
136 | round(
137 | float(min_scale)
138 | * crop_size
139 | / self.cfg.MULTIGRID.DEFAULT_S
140 | )
141 | )
142 | elif self.mode in ["val", "test"]:
143 | temporal_sample_index = (self._spatial_temporal_idx[index] // self.cfg.TEST.NUM_SPATIAL_CROPS)
144 | # spatial_sample_index is in [0, 1, 2]. Corresponding to left,
145 | # center, or right if width is larger than height, and top, middle,
146 | # or bottom if height is larger than width.
147 | spatial_sample_index = (
148 | (self._spatial_temporal_idx[index] % self.cfg.TEST.NUM_SPATIAL_CROPS)
149 | if self.cfg.TEST.NUM_SPATIAL_CROPS > 1 else 1
150 | )
151 | min_scale, max_scale, crop_size = (
152 | [self.cfg.DATA.TEST_CROP_SIZE] * 3 if self.cfg.TEST.NUM_SPATIAL_CROPS > 1
153 | else [self.cfg.DATA.TRAIN_JITTER_SCALES[0]] * 2 + [self.cfg.DATA.TEST_CROP_SIZE]
154 | )
155 | # The testing is deterministic and no jitter should be performed.
156 |             # min_scale, max_scale, and crop_size are expected to be the same.
157 | assert len({min_scale, max_scale}) == 1
158 | else:
159 | raise NotImplementedError(
160 | "Does not support {} mode".format(self.mode)
161 | )
162 | sampling_rate = get_random_sampling_rate(
163 | self.cfg.MULTIGRID.LONG_CYCLE_SAMPLING_RATE,
164 | self.cfg.DATA.SAMPLING_RATE,
165 | )
166 | # Try to decode and sample a clip from a video. If the video can not be
167 | # decoded, repeatedly find a random video replacement that can be decoded.
168 | for i_try in range(self._num_retries):
169 | video_container = None
170 | try:
171 | video_container = get_video_container(
172 | self._path_to_videos[index],
173 | self.cfg.DATA_LOADER.ENABLE_MULTI_THREAD_DECODE,
174 | self.cfg.DATA.DECODING_BACKEND,
175 | )
176 | except Exception as e:
177 | print(
178 | "Failed to load video from {} with error {}".format(
179 | self._path_to_videos[index], e
180 | )
181 | )
182 |             # Select a random video if the current video could not be accessed.
183 | if video_container is None:
184 | warnings.warn(
185 | "Failed to meta load video idx {} from {}; trial {}".format(
186 | index, self._path_to_videos[index], i_try
187 | )
188 | )
189 | if self.mode not in ["val", "test"] and i_try > self._num_retries // 2:
190 | # let's try another one
191 | index = random.randint(0, len(self._path_to_videos) - 1)
192 | continue
193 |
194 | # Decode video. Meta info is used to perform selective decoding.
195 | frames = decode(
196 | container=video_container,
197 | sampling_rate=sampling_rate,
198 | num_frames=self.cfg.DATA.NUM_FRAMES,
199 | clip_idx=temporal_sample_index,
200 | num_clips=self.cfg.TEST.NUM_ENSEMBLE_VIEWS,
201 | video_meta=self._video_meta[index],
202 | target_fps=self.cfg.DATA.TARGET_FPS,
203 | backend=self.cfg.DATA.DECODING_BACKEND,
204 | max_spatial_scale=min_scale,
205 | )
206 |
207 |             # If decoding failed (wrong format, video is too short, etc.),
208 | # select another video.
209 | if frames is None:
210 | warnings.warn(
211 | "Failed to decode video idx {} from {}; trial {}".format(
212 | index, self._path_to_videos[index], i_try
213 | )
214 | )
215 | if self.mode not in ["test"] and i_try > self._num_retries // 2:
216 | # let's try another one
217 | index = random.randint(0, len(self._path_to_videos) - 1)
218 | continue
219 |
220 | label = self._labels[index]
221 |
222 | # Perform color normalization.
223 | frames = tensor_normalize(
224 | frames, self.cfg.DATA.MEAN, self.cfg.DATA.STD
225 | )
226 | frames = frames.permute(3, 0, 1, 2)
227 |
228 | # Perform data augmentation.
229 | frames = spatial_sampling(
230 | frames,
231 | spatial_idx=spatial_sample_index,
232 | min_scale=min_scale,
233 | max_scale=max_scale,
234 | crop_size=crop_size,
235 | random_horizontal_flip=self.cfg.DATA.RANDOM_FLIP,
236 | inverse_uniform_sampling=self.cfg.DATA.INV_UNIFORM_SAMPLE,
237 | )
238 |
239 | # if not self.cfg.MODEL.ARCH in ['vit']:
240 | # frames = pack_pathway_output(self.cfg, frames)
241 | # else:
242 | # Perform temporal sampling from the fast pathway.
243 | # frames = [torch.index_select(
244 | # x,
245 | # 1,
246 | # torch.linspace(
247 | # 0, x.shape[1] - 1, self.cfg.DATA.NUM_FRAMES
248 | # ).long(),
249 | # ) for x in frames]
250 |
251 | return frames, label, index, {}
252 | else:
253 | raise RuntimeError(
254 | "Failed to fetch video after {} retries.".format(
255 | self._num_retries
256 | )
257 | )
258 |
259 | def __len__(self):
260 | """
261 | Returns:
262 | (int): the number of videos in the dataset.
263 | """
264 | return len(self._path_to_videos)
265 |
266 |
267 | if __name__ == '__main__':
268 |
269 | from utils.parser import parse_args, load_config
270 | from tqdm import tqdm
271 |
272 | args = parse_args()
273 | args.cfg_file = "models/configs/Kinetics/TimeSformer_divST_8x32_224.yaml"
274 | config = load_config(args)
275 | config.DATA.PATH_TO_DATA_DIR = "/home/kanchanaranasinghe/repo/mmaction2/data/hmdb51/splits"
276 | config.DATA.PATH_PREFIX = "/home/kanchanaranasinghe/repo/mmaction2/data/hmdb51/videos"
277 | dataset = HMDB51(cfg=config, mode="train", num_retries=10)
278 | dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=4)
279 | print(f"Loaded train dataset of length: {len(dataset)}")
280 | for idx, i in tqdm(enumerate(dataloader)):
281 | # continue
282 | print(idx, i[0].shape, i[1:])
283 | if idx > 2:
284 | break
285 |
286 | test_dataset = HMDB51(cfg=config, mode="val", num_retries=10)
287 | test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=4)
288 | print(f"Loaded test dataset of length: {len(test_dataset)}")
289 | for idx, i in tqdm(enumerate(test_dataloader)):
290 | # continue
291 | print(idx, i[0].shape, i[1:])
292 | if idx > 2:
293 | break
294 |
--------------------------------------------------------------------------------
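
`UCF101` and `HMDB51` share the same retry-with-replacement logic in `__getitem__`: once more than half of the allowed retries have failed, the loader abandons the requested video and samples a random replacement instead. The stripped-down sketch below isolates that control flow; `try_decode` is a hypothetical stand-in for the `get_video_container` + `decode` calls.

# Stand-alone sketch of the retry-with-replacement pattern used by both loaders.
import random

def load_with_retries(paths, index, num_retries=10, try_decode=lambda path: None):
    for i_try in range(num_retries):
        frames = try_decode(paths[index])      # returns None on decode failure
        if frames is not None:
            return frames, index               # index may differ from the requested one
        if i_try > num_retries // 2:
            # Give up on this video and pick a random replacement.
            index = random.randint(0, len(paths) - 1)
    raise RuntimeError(f"Failed to fetch video after {num_retries} retries.")
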
/vision_transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Mostly copy-paste from timm library.
16 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
17 | """
18 | import math
19 | from functools import partial
20 |
21 | import torch
22 | import torch.nn as nn
23 |
24 | from utils.utils import trunc_normal_
25 |
26 |
27 | def drop_path(x, drop_prob: float = 0., training: bool = False):
28 | if drop_prob == 0. or not training:
29 | return x
30 | keep_prob = 1 - drop_prob
31 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
32 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
33 | random_tensor.floor_() # binarize
34 | output = x.div(keep_prob) * random_tensor
35 | return output
36 |
37 |
38 | class DropPath(nn.Module):
39 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
40 | """
41 | def __init__(self, drop_prob=None):
42 | super(DropPath, self).__init__()
43 | self.drop_prob = drop_prob
44 |
45 | def forward(self, x):
46 | return drop_path(x, self.drop_prob, self.training)
47 |
48 |
49 | class Mlp(nn.Module):
50 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
51 | super().__init__()
52 | out_features = out_features or in_features
53 | hidden_features = hidden_features or in_features
54 | self.fc1 = nn.Linear(in_features, hidden_features)
55 | self.act = act_layer()
56 | self.fc2 = nn.Linear(hidden_features, out_features)
57 | self.drop = nn.Dropout(drop)
58 |
59 | def forward(self, x):
60 | x = self.fc1(x)
61 | x = self.act(x)
62 | x = self.drop(x)
63 | x = self.fc2(x)
64 | x = self.drop(x)
65 | return x
66 |
67 |
68 | class Attention(nn.Module):
69 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
70 | super().__init__()
71 | self.num_heads = num_heads
72 | head_dim = dim // num_heads
73 | self.scale = qk_scale or head_dim ** -0.5
74 |
75 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
76 | self.attn_drop = nn.Dropout(attn_drop)
77 | self.proj = nn.Linear(dim, dim)
78 | self.proj_drop = nn.Dropout(proj_drop)
79 |
80 | def forward(self, x):
81 | B, N, C = x.shape
82 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
83 | q, k, v = qkv[0], qkv[1], qkv[2]
84 |
85 | attn = (q @ k.transpose(-2, -1)) * self.scale
86 | attn = attn.softmax(dim=-1)
87 | attn = self.attn_drop(attn)
88 |
89 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
90 | x = self.proj(x)
91 | x = self.proj_drop(x)
92 | return x, attn
93 |
94 |
95 | class Block(nn.Module):
96 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
97 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
98 | super().__init__()
99 | self.norm1 = norm_layer(dim)
100 | self.attn = Attention(
101 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
102 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
103 | self.norm2 = norm_layer(dim)
104 | mlp_hidden_dim = int(dim * mlp_ratio)
105 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
106 |
107 | def forward(self, x, return_attention=False):
108 | y, attn = self.attn(self.norm1(x))
109 | if return_attention:
110 | return attn
111 | x = x + self.drop_path(y)
112 | x = x + self.drop_path(self.mlp(self.norm2(x)))
113 | return x
114 |
115 |
116 | class PatchEmbed(nn.Module):
117 | """ Image to Patch Embedding
118 | """
119 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
120 | super().__init__()
121 | num_patches = (img_size // patch_size) * (img_size // patch_size)
122 | self.img_size = img_size
123 | self.patch_size = patch_size
124 | self.num_patches = num_patches
125 |
126 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
127 |
128 | def forward(self, x):
129 | B, C, H, W = x.shape
130 | x = self.proj(x).flatten(2).transpose(1, 2)
131 | return x
132 |
133 |
134 | class VisionTransformer(nn.Module):
135 | """ Vision Transformer """
136 | def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
137 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
138 | drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
139 | super().__init__()
140 | self.num_features = self.embed_dim = embed_dim
141 |
142 | self.patch_embed = PatchEmbed(
143 | img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
144 | num_patches = self.patch_embed.num_patches
145 |
146 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
147 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
148 | self.pos_drop = nn.Dropout(p=drop_rate)
149 |
150 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
151 | self.blocks = nn.ModuleList([
152 | Block(
153 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
154 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
155 | for i in range(depth)])
156 | self.norm = norm_layer(embed_dim)
157 |
158 | # Classifier head
159 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
160 |
161 | trunc_normal_(self.pos_embed, std=.02)
162 | trunc_normal_(self.cls_token, std=.02)
163 | self.apply(self._init_weights)
164 |
165 | def _init_weights(self, m):
166 | if isinstance(m, nn.Linear):
167 | trunc_normal_(m.weight, std=.02)
168 | if isinstance(m, nn.Linear) and m.bias is not None:
169 | nn.init.constant_(m.bias, 0)
170 | elif isinstance(m, nn.LayerNorm):
171 | nn.init.constant_(m.bias, 0)
172 | nn.init.constant_(m.weight, 1.0)
173 |
174 | def interpolate_pos_encoding(self, x, w, h):
175 | npatch = x.shape[1] - 1
176 | N = self.pos_embed.shape[1] - 1
177 | if npatch == N and w == h:
178 | return self.pos_embed
179 | class_pos_embed = self.pos_embed[:, 0]
180 | patch_pos_embed = self.pos_embed[:, 1:]
181 | dim = x.shape[-1]
182 | w0 = w // self.patch_embed.patch_size
183 | h0 = h // self.patch_embed.patch_size
184 | # we add a small number to avoid floating point error in the interpolation
185 | # see discussion at https://github.com/facebookresearch/dino/issues/8
186 | w0, h0 = w0 + 0.1, h0 + 0.1
187 | patch_pos_embed = nn.functional.interpolate(
188 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
189 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
190 | mode='bicubic',
191 | )
192 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
193 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
194 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
195 |
196 | def prepare_tokens(self, x):
197 | B, nc, w, h = x.shape
198 | x = self.patch_embed(x) # patch linear embedding
199 |
200 | # add the [CLS] token to the embed patch tokens
201 | cls_tokens = self.cls_token.expand(B, -1, -1)
202 | x = torch.cat((cls_tokens, x), dim=1)
203 |
204 | # add positional encoding to each token
205 | x = x + self.interpolate_pos_encoding(x, w, h)
206 |
207 | return self.pos_drop(x)
208 |
209 | def forward(self, x):
210 | x = self.prepare_tokens(x)
211 | for blk in self.blocks:
212 | x = blk(x)
213 | x = self.norm(x)
214 | return x[:, 0]
215 |
216 | def get_intermediate_layers(self, x, n=1):
217 | x = self.prepare_tokens(x)
218 | # we return the output tokens from the `n` last blocks
219 | output = []
220 | for i, blk in enumerate(self.blocks):
221 | x = blk(x)
222 | if len(self.blocks) - i <= n:
223 | output.append(self.norm(x))
224 | return output
225 |
226 |
227 | def vit_tiny(patch_size=16, **kwargs):
228 | model = VisionTransformer(
229 | patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
230 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
231 | return model
232 |
233 |
234 | def vit_small(patch_size=16, **kwargs):
235 | model = VisionTransformer(
236 | patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
237 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
238 | return model
239 |
240 |
241 | def vit_base(patch_size=16, **kwargs):
242 | model = VisionTransformer(
243 | patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
244 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
245 | return model
246 |
247 |
248 | class DINOHead(nn.Module):
249 | def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
250 | super().__init__()
251 | nlayers = max(nlayers, 1)
252 | if nlayers == 1:
253 | self.mlp = nn.Linear(in_dim, bottleneck_dim)
254 | else:
255 | layers = [nn.Linear(in_dim, hidden_dim)]
256 | if use_bn:
257 | layers.append(nn.BatchNorm1d(hidden_dim))
258 | layers.append(nn.GELU())
259 | for _ in range(nlayers - 2):
260 | layers.append(nn.Linear(hidden_dim, hidden_dim))
261 | if use_bn:
262 | layers.append(nn.BatchNorm1d(hidden_dim))
263 | layers.append(nn.GELU())
264 | layers.append(nn.Linear(hidden_dim, bottleneck_dim))
265 | self.mlp = nn.Sequential(*layers)
266 | self.apply(self._init_weights)
267 | self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
268 | self.last_layer.weight_g.data.fill_(1)
269 | if norm_last_layer:
270 | self.last_layer.weight_g.requires_grad = False
271 |
272 | def _init_weights(self, m):
273 | if isinstance(m, nn.Linear):
274 | trunc_normal_(m.weight, std=.02)
275 | if isinstance(m, nn.Linear) and m.bias is not None:
276 | nn.init.constant_(m.bias, 0)
277 |
278 | def forward(self, x):
279 | x = self.mlp(x)
280 | x = nn.functional.normalize(x, dim=-1, p=2)
281 | x = self.last_layer(x)
282 | return x
283 |
284 |
285 | class MultiDINOHead(nn.Module):
286 | def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048,
287 | bottleneck_dim=256):
288 | super().__init__()
289 | nlayers = max(nlayers, 1)
290 | if nlayers == 1:
291 | self.mlp = nn.Linear(in_dim, bottleneck_dim)
292 | self.aux_mlp = nn.Linear(in_dim, bottleneck_dim)
293 | else:
294 | layers = [nn.Linear(in_dim, hidden_dim)]
295 | if use_bn:
296 | layers.append(nn.BatchNorm1d(hidden_dim))
297 | layers.append(nn.GELU())
298 | for _ in range(nlayers - 2):
299 | layers.append(nn.Linear(hidden_dim, hidden_dim))
300 | if use_bn:
301 | layers.append(nn.BatchNorm1d(hidden_dim))
302 | layers.append(nn.GELU())
303 | layers.append(nn.Linear(hidden_dim, bottleneck_dim))
304 | self.mlp = nn.Sequential(*layers)
305 |
306 | aux_layers = [nn.Linear(in_dim, hidden_dim)]
307 | if use_bn:
308 | aux_layers.append(nn.BatchNorm1d(hidden_dim))
309 | aux_layers.append(nn.GELU())
310 | for _ in range(nlayers - 2):
311 | aux_layers.append(nn.Linear(hidden_dim, hidden_dim))
312 | if use_bn:
313 | aux_layers.append(nn.BatchNorm1d(hidden_dim))
314 | aux_layers.append(nn.GELU())
315 | aux_layers.append(nn.Linear(hidden_dim, bottleneck_dim))
316 | self.aux_mlp = nn.Sequential(*aux_layers)
317 |
318 | self.apply(self._init_weights)
319 | self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
320 | self.last_layer.weight_g.data.fill_(1)
321 | if norm_last_layer:
322 | self.last_layer.weight_g.requires_grad = False
323 |
324 | self.aux_last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
325 | self.aux_last_layer.weight_g.data.fill_(1)
326 | if norm_last_layer:
327 | self.aux_last_layer.weight_g.requires_grad = False
328 |
329 | def _init_weights(self, m):
330 | if isinstance(m, nn.Linear):
331 | trunc_normal_(m.weight, std=.02)
332 | if isinstance(m, nn.Linear) and m.bias is not None:
333 | nn.init.constant_(m.bias, 0)
334 |
335 | def forward(self, x):
336 | rgb_x, aux_x = x[0], x[1]
337 |
338 | rgb_x = self.mlp(rgb_x)
339 | rgb_x = nn.functional.normalize(rgb_x, dim=-1, p=2)
340 | rgb_x = self.last_layer(rgb_x)
341 |
342 | aux_x = self.aux_mlp(aux_x)
343 | aux_x = nn.functional.normalize(aux_x, dim=-1, p=2)
344 | aux_x = self.aux_last_layer(aux_x)
345 | return rgb_x, aux_x
346 |
--------------------------------------------------------------------------------
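
A quick usage sketch for the backbone factories and `DINOHead` defined above, assuming this file is importable as `vision_transformer` from the repository root; the batch, image size, and `out_dim` below are arbitrary example values.

# Build a ViT backbone, run a dummy batch, and project the [CLS] embedding.
import torch
from vision_transformer import vit_small, DINOHead

model = vit_small(patch_size=16)
x = torch.randn(2, 3, 224, 224)                   # dummy batch of two RGB images
cls_embedding = model(x)                          # (2, 384): [CLS] token after the final norm
tokens = model.get_intermediate_layers(x, n=2)    # tokens from the last two blocks, (2, 197, 384) each

head = DINOHead(in_dim=384, out_dim=65536)        # out_dim is an example value
logits = head(cls_embedding)                      # (2, 65536)
print(cls_embedding.shape, tokens[-1].shape, logits.shape)
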
/datasets/data_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import logging
4 | import numpy as np
5 | import os
6 | import random
7 | import time
8 | from collections import defaultdict
9 | import cv2
10 | import torch
11 | from fvcore.common.file_io import PathManager
12 | from torch.utils.data.distributed import DistributedSampler
13 |
14 | import datasets.transform as transform
15 |
16 | logger = logging.getLogger(__name__)
17 |
18 |
19 | def retry_load_images(image_paths, retry=10, backend="pytorch"):
20 | """
21 | This function is to load images with support of retrying for failed load.
22 |
23 | Args:
24 | image_paths (list): paths of images needed to be loaded.
25 | retry (int, optional): maximum time of loading retrying. Defaults to 10.
26 | backend (str): `pytorch` or `cv2`.
27 |
28 | Returns:
29 | imgs (list): list of loaded images.
30 | """
31 | for i in range(retry):
32 | imgs = []
33 | for image_path in image_paths:
34 | with PathManager.open(image_path, "rb") as f:
35 | img_str = np.frombuffer(f.read(), np.uint8)
36 | img = cv2.imdecode(img_str, flags=cv2.IMREAD_COLOR)
37 | imgs.append(img)
38 |
39 | if all(img is not None for img in imgs):
40 | if backend == "pytorch":
41 | imgs = torch.as_tensor(np.stack(imgs))
42 | return imgs
43 | else:
44 |             logger.warning("Reading failed. Will retry.")
45 | time.sleep(1.0)
46 | if i == retry - 1:
47 | raise Exception("Failed to load images {}".format(image_paths))
48 |
49 |
50 | def get_sequence(center_idx, half_len, sample_rate, num_frames):
51 | """
52 | Sample frames among the corresponding clip.
53 |
54 | Args:
55 | center_idx (int): center frame idx for current clip
56 | half_len (int): half of the clip length
57 | sample_rate (int): sampling rate for sampling frames inside of the clip
58 | num_frames (int): number of expected sampled frames
59 |
60 | Returns:
61 | seq (list): list of indexes of sampled frames in this clip.
62 | """
63 | seq = list(range(center_idx - half_len, center_idx + half_len, sample_rate))
64 |
65 | for seq_idx in range(len(seq)):
66 | if seq[seq_idx] < 0:
67 | seq[seq_idx] = 0
68 | elif seq[seq_idx] >= num_frames:
69 | seq[seq_idx] = num_frames - 1
70 | return seq
71 |
72 |
73 | def pack_pathway_output(cfg, frames):
74 | """
75 | Prepare output as a list of tensors. Each tensor corresponding to a
76 | unique pathway.
77 | Args:
78 | frames (tensor): frames of images sampled from the video. The
79 | dimension is `channel` x `num frames` x `height` x `width`.
80 | Returns:
81 | frame_list (list): list of tensors with the dimension of
82 | `channel` x `num frames` x `height` x `width`.
83 | """
84 | if cfg.DATA.REVERSE_INPUT_CHANNEL:
85 | frames = frames[[2, 1, 0], :, :, :]
86 | if cfg.MODEL.ARCH in cfg.MODEL.SINGLE_PATHWAY_ARCH:
87 | frame_list = [frames]
88 | elif cfg.MODEL.ARCH in cfg.MODEL.MULTI_PATHWAY_ARCH:
89 | fast_pathway = frames
90 | # Perform temporal sampling from the fast pathway.
91 | slow_pathway = torch.index_select(
92 | frames,
93 | 1,
94 | torch.linspace(
95 | 0, frames.shape[1] - 1, frames.shape[1] // cfg.SLOWFAST.ALPHA
96 | ).long(),
97 | )
98 | frame_list = [slow_pathway, fast_pathway]
99 | else:
100 | raise NotImplementedError(
101 | "Model arch {} is not in {}".format(
102 | cfg.MODEL.ARCH,
103 | cfg.MODEL.SINGLE_PATHWAY_ARCH + cfg.MODEL.MULTI_PATHWAY_ARCH,
104 | )
105 | )
106 | return frame_list
107 |
108 |
109 | def spatial_sampling(
110 | frames,
111 | spatial_idx=-1,
112 | min_scale=256,
113 | max_scale=320,
114 | crop_size=224,
115 | random_horizontal_flip=True,
116 | inverse_uniform_sampling=False,
117 | ):
118 | """
119 | Perform spatial sampling on the given video frames. If spatial_idx is
120 | -1, perform random scale, random crop, and random flip on the given
121 | frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling
122 | with the given spatial_idx.
123 | Args:
124 | frames (tensor): frames of images sampled from the video. The
125 | dimension is `num frames` x `height` x `width` x `channel`.
126 | spatial_idx (int): if -1, perform random spatial sampling. If 0, 1,
127 | or 2, perform left, center, right crop if width is larger than
128 |             height, and perform top, center, bottom crop if height is larger
129 | than width.
130 | min_scale (int): the minimal size of scaling.
131 | max_scale (int): the maximal size of scaling.
132 | crop_size (int): the size of height and width used to crop the
133 | frames.
134 | inverse_uniform_sampling (bool): if True, sample uniformly in
135 | [1 / max_scale, 1 / min_scale] and take a reciprocal to get the
136 | scale. If False, take a uniform sample from [min_scale,
137 | max_scale].
138 | Returns:
139 | frames (tensor): spatially sampled frames.
140 | """
141 | assert spatial_idx in [-1, 0, 1, 2]
142 | if spatial_idx == -1:
143 | frames, _ = transform.random_short_side_scale_jitter(
144 | images=frames,
145 | min_size=min_scale,
146 | max_size=max_scale,
147 | inverse_uniform_sampling=inverse_uniform_sampling,
148 | )
149 | frames, _ = transform.random_crop(frames, crop_size)
150 | if random_horizontal_flip:
151 | frames, _ = transform.horizontal_flip(0.5, frames)
152 | else:
153 | # The testing is deterministic and no jitter should be performed.
154 |         # min_scale, max_scale, and crop_size are expected to be the same.
155 | #assert len({min_scale, max_scale, crop_size}) == 1
156 | frames, _ = transform.random_short_side_scale_jitter(
157 | frames, min_scale, max_scale
158 | )
159 | frames, _ = transform.uniform_crop(frames, crop_size, spatial_idx)
160 | return frames
161 |
162 | def spatial_sampling_2crops(
163 | frames,
164 | spatial_idx=-1,
165 | min_scale=256,
166 | max_scale=320,
167 | crop_size=224,
168 | random_horizontal_flip=True,
169 | inverse_uniform_sampling=False,
170 | ):
171 | """
172 | Perform spatial sampling on the given video frames. If spatial_idx is
173 | -1, perform random scale, random crop, and random flip on the given
174 | frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling
175 | with the given spatial_idx.
176 | Args:
177 | frames (tensor): frames of images sampled from the video. The
178 | dimension is `num frames` x `height` x `width` x `channel`.
179 | spatial_idx (int): if -1, perform random spatial sampling. If 0, 1,
180 | or 2, perform left, center, right crop if width is larger than
181 |             height, and perform top, center, bottom crop if height is larger
182 | than width.
183 | min_scale (int): the minimal size of scaling.
184 | max_scale (int): the maximal size of scaling.
185 | crop_size (int): the size of height and width used to crop the
186 | frames.
187 | inverse_uniform_sampling (bool): if True, sample uniformly in
188 | [1 / max_scale, 1 / min_scale] and take a reciprocal to get the
189 | scale. If False, take a uniform sample from [min_scale,
190 | max_scale].
191 | Returns:
192 | frames (tensor): spatially sampled frames.
193 | """
194 | assert spatial_idx in [-1, 0, 1, 2]
195 | if spatial_idx == -1:
196 | frames, _ = transform.random_short_side_scale_jitter(
197 | images=frames,
198 | min_size=min_scale,
199 | max_size=max_scale,
200 | inverse_uniform_sampling=inverse_uniform_sampling,
201 | )
202 | frames, _ = transform.random_crop(frames, crop_size)
203 | if random_horizontal_flip:
204 | frames, _ = transform.horizontal_flip(0.5, frames)
205 | else:
206 | # The testing is deterministic and no jitter should be performed.
207 |         # min_scale, max_scale, and crop_size are expected to be the same.
208 | #assert len({min_scale, max_scale, crop_size}) == 1
209 | frames, _ = transform.random_short_side_scale_jitter(
210 | frames, min_scale, max_scale
211 | )
212 | frames, _ = transform.uniform_crop_2crops(frames, crop_size, spatial_idx)
213 | return frames
214 |
215 |
216 | def as_binary_vector(labels, num_classes):
217 | """
218 | Construct binary label vector given a list of label indices.
219 | Args:
220 | labels (list): The input label list.
221 | num_classes (int): Number of classes of the label vector.
222 | Returns:
223 | labels (numpy array): the resulting binary vector.
224 | """
225 | label_arr = np.zeros((num_classes,))
226 |
227 | for lbl in set(labels):
228 | label_arr[lbl] = 1.0
229 | return label_arr
230 |
231 |
232 | def aggregate_labels(label_list):
233 | """
234 |     Join a list of label lists.
235 |     Args:
236 |         label_list (list): The input label list.
237 | Returns:
238 | labels (list): The joint list of all lists in input.
239 | """
240 | all_labels = []
241 | for labels in label_list:
242 | for l in labels:
243 | all_labels.append(l)
244 | return list(set(all_labels))
245 |
246 |
247 | def convert_to_video_level_labels(labels):
248 | """
249 | Aggregate annotations from all frames of a video to form video-level labels.
250 | Args:
251 | labels (list): The input label list.
252 | Returns:
253 | labels (list): Same as input, but with each label replaced by
254 | a video-level one.
255 | """
256 | for video_id in range(len(labels)):
257 | video_level_labels = aggregate_labels(labels[video_id])
258 | for i in range(len(labels[video_id])):
259 | labels[video_id][i] = video_level_labels
260 | return labels
261 |
262 |
263 | def load_image_lists(frame_list_file, prefix="", return_list=False):
264 | """
265 | Load image paths and labels from a "frame list".
266 | Each line of the frame list contains:
267 | `original_vido_id video_id frame_id path labels`
268 | Args:
269 | frame_list_file (string): path to the frame list.
270 | prefix (str): the prefix for the path.
271 | return_list (bool): if True, return a list. If False, return a dict.
272 | Returns:
273 | image_paths (list or dict): list of list containing path to each frame.
274 | If return_list is False, then return in a dict form.
275 | labels (list or dict): list of list containing label of each frame.
276 | If return_list is False, then return in a dict form.
277 | """
278 | image_paths = defaultdict(list)
279 | labels = defaultdict(list)
280 | with PathManager.open(frame_list_file, "r") as f:
281 | assert f.readline().startswith("original_vido_id")
282 | for line in f:
283 | row = line.split()
284 | # original_vido_id video_id frame_id path labels
285 | assert len(row) == 5
286 | video_name = row[0]
287 | if prefix == "":
288 | path = row[3]
289 | else:
290 | path = os.path.join(prefix, row[3])
291 | image_paths[video_name].append(path)
292 | frame_labels = row[-1].replace('"', "")
293 | if frame_labels != "":
294 | labels[video_name].append(
295 | [int(x) for x in frame_labels.split(",")]
296 | )
297 | else:
298 | labels[video_name].append([])
299 |
300 | if return_list:
301 | keys = image_paths.keys()
302 | image_paths = [image_paths[key] for key in keys]
303 | labels = [labels[key] for key in keys]
304 | return image_paths, labels
305 | return dict(image_paths), dict(labels)
306 |
307 |
308 | def tensor_normalize(tensor, mean, std):
309 | """
310 | Normalize a given tensor by subtracting the mean and dividing the std.
311 | Args:
312 | tensor (tensor): tensor to normalize.
313 | mean (tensor or list): mean value to subtract.
314 | std (tensor or list): std to divide.
315 | """
316 | if tensor.dtype == torch.uint8:
317 | tensor = tensor.float()
318 | tensor = tensor / 255.0
319 | if type(mean) == list:
320 | mean = torch.tensor(mean)
321 | if type(std) == list:
322 | std = torch.tensor(std)
323 | tensor = tensor - mean
324 | tensor = tensor / std
325 | return tensor
326 |
327 |
328 | def get_random_sampling_rate(long_cycle_sampling_rate, sampling_rate):
329 | """
330 | When multigrid training uses a fewer number of frames, we randomly
331 | increase the sampling rate so that some clips cover the original span.
332 | """
333 | if long_cycle_sampling_rate > 0:
334 | assert long_cycle_sampling_rate >= sampling_rate
335 | return random.randint(sampling_rate, long_cycle_sampling_rate)
336 | else:
337 | return sampling_rate
338 |
339 |
340 | def revert_tensor_normalize(tensor, mean, std):
341 | """
342 | Revert normalization for a given tensor by multiplying by the std and adding the mean.
343 | Args:
344 | tensor (tensor): tensor to revert normalization.
345 | mean (tensor or list): mean value to add.
346 | std (tensor or list): std to multiply.
347 | """
348 | if type(mean) == list:
349 | mean = torch.tensor(mean)
350 | if type(std) == list:
351 | std = torch.tensor(std)
352 | tensor = tensor * std
353 | tensor = tensor + mean
354 | return tensor
355 |
356 |
357 | def create_sampler(dataset, shuffle, cfg):
358 | """
359 | Create sampler for the given dataset.
360 | Args:
361 | dataset (torch.utils.data.Dataset): the given dataset.
362 | shuffle (bool): set to ``True`` to have the data reshuffled
363 | at every epoch.
364 | cfg (CfgNode): configs. Details can be found in
365 | slowfast/config/defaults.py
366 | Returns:
367 | sampler (Sampler): the created sampler.
368 | """
369 | sampler = DistributedSampler(dataset) if cfg.NUM_GPUS > 1 else None
370 |
371 | return sampler
372 |
373 |
374 | def loader_worker_init_fn(dataset):
375 | """
376 | Create init function passed to pytorch data loader.
377 | Args:
378 | dataset (torch.utils.data.Dataset): the given dataset.
379 | """
380 | return None
381 |
--------------------------------------------------------------------------------
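
The dataset classes above call these helpers in a fixed order for every training clip: `tensor_normalize` on the decoded `T x H x W x C` uint8 frames, a permute to `C x T x H x W`, then `spatial_sampling`. The sketch below mirrors that order on random data; the clip shape and mean/std values are assumed placeholders rather than values read from a config.

# Normalise, permute, and spatially sample one synthetic training clip.
import torch
from datasets.data_utils import tensor_normalize, spatial_sampling

frames = torch.randint(0, 256, (8, 256, 340, 3), dtype=torch.uint8)   # T x H x W x C
frames = tensor_normalize(frames, mean=[0.45, 0.45, 0.45], std=[0.225, 0.225, 0.225])
frames = frames.permute(3, 0, 1, 2)               # -> C x T x H x W
frames = spatial_sampling(
    frames,
    spatial_idx=-1,                               # -1: random scale, crop, and flip (train mode)
    min_scale=256,
    max_scale=320,
    crop_size=224,
    random_horizontal_flip=True,
)
print(frames.shape)                               # expected: torch.Size([3, 8, 224, 224])
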
/datasets/kinetics.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | import glob
3 | import os
4 | import random
5 | import warnings
6 | from PIL import Image
7 | import torch
8 | import torch.utils.data
9 | import torchvision
10 | import kornia
11 |
12 | from datasets.transform import resize
13 | from datasets.data_utils import get_random_sampling_rate, tensor_normalize, spatial_sampling, pack_pathway_output
14 | from datasets.decoder import decode
15 | from datasets.video_container import get_video_container
16 | from datasets.transform import VideoDataAugmentationDINO
17 | from einops import rearrange
18 |
19 |
20 | class Kinetics(torch.utils.data.Dataset):
21 | """
22 | Kinetics video loader. Construct the Kinetics video loader, then sample
23 | clips from the videos. For training and validation, a single clip is
24 | randomly sampled from every video with random cropping, scaling, and
25 |     flipping. For testing, multiple clips are uniformly sampled from every
26 | video with uniform cropping. For uniform cropping, we take the left, center,
27 | and right crop if the width is larger than height, or take top, center, and
28 | bottom crop if the height is larger than the width.
29 | """
30 |
31 | def __init__(self, cfg, mode, num_retries=10, get_flow=False):
32 | """
33 | Construct the Kinetics video loader with a given csv file. The format of
34 | the csv file is:
35 | ```
36 | path_to_video_1 label_1
37 | path_to_video_2 label_2
38 | ...
39 | path_to_video_N label_N
40 | ```
41 | Args:
42 | cfg (CfgNode): configs.
43 |             mode (string): Options include `train`, `val`, or `test` mode.
44 | For the train and val mode, the data loader will take data
45 | from the train or val set, and sample one clip per video.
46 | For the test mode, the data loader will take data from test set,
47 | and sample multiple clips per video.
48 | num_retries (int): number of retries.
49 | """
50 | # Only support train, val, and test mode.
51 | assert mode in [
52 | "train",
53 | "val",
54 | "test",
55 | ], "Split '{}' not supported for Kinetics".format(mode)
56 | self.mode = mode
57 | self.cfg = cfg
58 | if get_flow:
59 | assert mode == "train", "invalid: flow only for train mode"
60 | self.get_flow = get_flow
61 |
62 | self._video_meta = {}
63 | self._num_retries = num_retries
64 | # For training or validation mode, one single clip is sampled from every
65 | # video. For testing, NUM_ENSEMBLE_VIEWS clips are sampled from every
66 | # video. For every clip, NUM_SPATIAL_CROPS is cropped spatially from
67 | # the frames.
68 | if self.mode in ["train", "val"]:
69 | self._num_clips = 1
70 | elif self.mode in ["test"]:
71 | self._num_clips = (
72 | cfg.TEST.NUM_ENSEMBLE_VIEWS * cfg.TEST.NUM_SPATIAL_CROPS
73 | )
74 |
75 | print("Constructing Kinetics {}...".format(mode))
76 | self._construct_loader()
77 |
78 | def _construct_loader(self):
79 | """
80 | Construct the video loader.
81 | """
82 | path_to_file = os.path.join(
83 | self.cfg.DATA.PATH_TO_DATA_DIR, "{}.csv".format(self.mode)
84 | )
85 |         assert os.path.exists(path_to_file), "{} file not found".format(
86 | path_to_file
87 | )
88 |
89 | self._path_to_videos = []
90 | self._labels = []
91 | self._spatial_temporal_idx = []
92 | with open(path_to_file, "r") as f:
93 | for clip_idx, path_label in enumerate(f.read().splitlines()):
94 | assert (
95 | len(path_label.split(self.cfg.DATA.PATH_LABEL_SEPARATOR))
96 | == 2
97 | )
98 | path, label = path_label.split(
99 | self.cfg.DATA.PATH_LABEL_SEPARATOR
100 | )
101 | for idx in range(self._num_clips):
102 | self._path_to_videos.append(
103 | os.path.join(self.cfg.DATA.PATH_PREFIX, path)
104 | )
105 | self._labels.append(int(label))
106 | self._spatial_temporal_idx.append(idx)
107 | self._video_meta[clip_idx * self._num_clips + idx] = {}
108 | assert (
109 | len(self._path_to_videos) > 0
110 | ), "Failed to load Kinetics split {} from {}".format(
111 |             self.mode, path_to_file
112 | )
113 | print(
114 | "Constructing kinetics dataloader (size: {}) from {}".format(
115 | len(self._path_to_videos), path_to_file
116 | )
117 | )
118 |
119 | def __getitem__(self, index):
120 | """
121 | Given the video index, return the list of frames, label, and video
122 | index if the video can be fetched and decoded successfully, otherwise
123 |         repeatedly find a random video that can be decoded as a replacement.
124 | Args:
125 | index (int): the video index provided by the pytorch sampler.
126 | Returns:
127 |             frames (tensor): the frames sampled from the video. The dimension
128 | is `channel` x `num frames` x `height` x `width`.
129 | label (int): the label of the current video.
130 | index (int): if the video provided by pytorch sampler can be
131 | decoded, then return the index of the video. If not, return the
132 | index of the video replacement that can be decoded.
133 | """
134 | short_cycle_idx = None
135 |         # When short cycle is used, input index is a tuple.
136 | if isinstance(index, tuple):
137 | index, short_cycle_idx = index
138 |
139 | if self.mode in ["train", "val"]:
140 | # -1 indicates random sampling.
141 | temporal_sample_index = -1
142 | spatial_sample_index = -1
143 | min_scale = self.cfg.DATA.TRAIN_JITTER_SCALES[0]
144 | max_scale = self.cfg.DATA.TRAIN_JITTER_SCALES[1]
145 | crop_size = self.cfg.DATA.TRAIN_CROP_SIZE
146 | if short_cycle_idx in [0, 1]:
147 | crop_size = int(
148 | round(
149 | self.cfg.MULTIGRID.SHORT_CYCLE_FACTORS[short_cycle_idx]
150 | * self.cfg.MULTIGRID.DEFAULT_S
151 | )
152 | )
153 | if self.cfg.MULTIGRID.DEFAULT_S > 0:
154 | # Decreasing the scale is equivalent to using a larger "span"
155 | # in a sampling grid.
156 | min_scale = int(
157 | round(
158 | float(min_scale)
159 | * crop_size
160 | / self.cfg.MULTIGRID.DEFAULT_S
161 | )
162 | )
163 | elif self.mode in ["test"]:
164 | temporal_sample_index = (
165 | self._spatial_temporal_idx[index]
166 | // self.cfg.TEST.NUM_SPATIAL_CROPS
167 | )
168 | # spatial_sample_index is in [0, 1, 2]. Corresponding to left,
169 | # center, or right if width is larger than height, and top, middle,
170 | # or bottom if height is larger than width.
171 | spatial_sample_index = (
172 | (
173 | self._spatial_temporal_idx[index]
174 | % self.cfg.TEST.NUM_SPATIAL_CROPS
175 | )
176 | if self.cfg.TEST.NUM_SPATIAL_CROPS > 1
177 | else 1
178 | )
179 | min_scale, max_scale, crop_size = (
180 | [self.cfg.DATA.TEST_CROP_SIZE] * 3
181 | if self.cfg.TEST.NUM_SPATIAL_CROPS > 1
182 | else [self.cfg.DATA.TRAIN_JITTER_SCALES[0]] * 2
183 | + [self.cfg.DATA.TEST_CROP_SIZE]
184 | )
185 | # The testing is deterministic and no jitter should be performed.
186 |             # min_scale, max_scale, and crop_size are expected to be the same.
187 | assert len({min_scale, max_scale}) == 1
188 | else:
189 | raise NotImplementedError(
190 | "Does not support {} mode".format(self.mode)
191 | )
192 | sampling_rate = get_random_sampling_rate(
193 | self.cfg.MULTIGRID.LONG_CYCLE_SAMPLING_RATE,
194 | self.cfg.DATA.SAMPLING_RATE,
195 | )
196 | # Try to decode and sample a clip from a video. If the video can not be
197 |         # decoded, repeatedly find a random video replacement that can be decoded.
198 | for i_try in range(self._num_retries):
199 | video_container = None
200 | try:
201 | video_container = get_video_container(
202 | self._path_to_videos[index],
203 | self.cfg.DATA_LOADER.ENABLE_MULTI_THREAD_DECODE,
204 | self.cfg.DATA.DECODING_BACKEND,
205 | )
206 | except Exception as e:
207 | print(
208 | "Failed to load video from {} with error {}".format(
209 | self._path_to_videos[index], e
210 | )
211 | )
212 |             # Select a random video if the current video cannot be accessed.
213 | if video_container is None:
214 | warnings.warn(
215 | "Failed to meta load video idx {} from {}; trial {}".format(
216 | index, self._path_to_videos[index], i_try
217 | )
218 | )
219 | if self.mode not in ["test"] and i_try > self._num_retries // 2:
220 | # let's try another one
221 | index = random.randint(0, len(self._path_to_videos) - 1)
222 | continue
223 |
224 | # Decode video. Meta info is used to perform selective decoding.
225 | frames = decode(
226 | container=video_container,
227 | sampling_rate=sampling_rate,
228 | num_frames=self.cfg.DATA.NUM_FRAMES,
229 | clip_idx=temporal_sample_index,
230 | num_clips=self.cfg.TEST.NUM_ENSEMBLE_VIEWS,
231 | video_meta=self._video_meta[index],
232 | target_fps=self.cfg.DATA.TARGET_FPS,
233 | backend=self.cfg.DATA.DECODING_BACKEND,
234 | max_spatial_scale=min_scale,
235 | temporal_aug=self.mode == "train" and not self.cfg.DATA.NO_RGB_AUG,
236 | two_token=self.cfg.MODEL.TWO_TOKEN,
237 | rand_fr=self.cfg.DATA.RAND_FR
238 | )
239 |
240 |             # If decoding failed (wrong format, video too short, etc.),
241 | # select another video.
242 | if frames is None:
243 | warnings.warn(
244 | "Failed to decode video idx {} from {}; trial {}".format(
245 | index, self._path_to_videos[index], i_try
246 | )
247 | )
248 | if self.mode not in ["test"] and i_try > self._num_retries // 2:
249 | # let's try another one
250 | index = random.randint(0, len(self._path_to_videos) - 1)
251 | continue
252 |
253 | label = self._labels[index]
254 |
255 | if self.mode in ["test", "val"] or self.cfg.DATA.NO_RGB_AUG:
256 | # Perform color normalization.
257 | frames = tensor_normalize(
258 | frames, self.cfg.DATA.MEAN, self.cfg.DATA.STD
259 | )
260 |
261 | # T H W C -> C T H W.
262 | frames = frames.permute(3, 0, 1, 2)
263 |
264 | # Perform data augmentation.
265 | frames = spatial_sampling(
266 | frames,
267 | spatial_idx=spatial_sample_index,
268 | min_scale=min_scale,
269 | max_scale=max_scale,
270 | crop_size=crop_size,
271 | random_horizontal_flip=self.cfg.DATA.RANDOM_FLIP,
272 | inverse_uniform_sampling=self.cfg.DATA.INV_UNIFORM_SAMPLE,
273 | )
274 |
275 |                 if self.cfg.MODEL.ARCH not in ['vit']:
276 | frames = pack_pathway_output(self.cfg, frames)
277 | else:
278 | # Perform temporal sampling from the fast pathway.
279 | frames = torch.index_select(
280 | frames,
281 | 1,
282 | torch.linspace(
283 | 0, frames.shape[1] - 1, self.cfg.DATA.NUM_FRAMES
284 |
285 | ).long(),
286 | )
287 |
288 | else:
289 | # T H W C -> T C H W.
290 | frames = [rearrange(x, "t h w c -> t c h w") for x in frames]
291 |
292 | # Perform data augmentation.
293 | augmentation = VideoDataAugmentationDINO()
294 | frames = augmentation(frames, from_list=True, no_aug=self.cfg.DATA.NO_SPATIAL,
295 | two_token=self.cfg.MODEL.TWO_TOKEN)
296 |
297 | # T C H W -> C T H W.
298 | frames = [rearrange(x, "t c h w -> c t h w") for x in frames]
299 |
300 | # Perform temporal sampling from the fast pathway.
301 | frames = [torch.index_select(
302 | x,
303 | 1,
304 | torch.linspace(
305 | 0, x.shape[1] - 1, x.shape[1] if self.cfg.DATA.RAND_FR else self.cfg.DATA.NUM_FRAMES
306 |
307 | ).long(),
308 | ) for x in frames]
309 |
310 | meta_data = {}
311 | if self.get_flow:
312 | assert self.mode == "train", "flow only for train"
313 | try:
314 | flow_path = self._path_to_videos[index].replace("train_d256", "train_flow")[:-4]
315 | flow_tensor = self.get_flow_from_folder(flow_path)
316 | flow_tensor = kornia.filters.sobel(flow_tensor)
317 | if self.cfg.DATA.NO_FLOW_AUG:
318 | flow_tensor = resize(flow_tensor, size=self.cfg.DATA.CROP_SIZE, mode="bicubic")
319 | flow_tensor = [x for x in flow_tensor]
320 | else:
321 |                         flow_tensor = augmentation(flow_tensor)  # reuses the RGB augmentation built above (assumes NO_RGB_AUG is False)
322 | flow_tensor = [rearrange(x, "t c h w -> c t h w") for x in flow_tensor]
323 | meta_data["flow"] = flow_tensor
324 | except Exception as e:
325 | print(e)
326 | continue
327 | return frames, label, index, meta_data
328 |
329 | else:
330 | raise RuntimeError(
331 | "Failed to fetch video after {} retries.".format(
332 | self._num_retries
333 | )
334 | )
335 |
336 | def __len__(self):
337 | """
338 | Returns:
339 | (int): the number of videos in the dataset.
340 | """
341 | return len(self._path_to_videos)
342 |
343 | @staticmethod
344 | def get_flow_from_folder(dir_path):
345 | flow_image_list = sorted(glob.glob(f"{dir_path}/*.jpg"))
346 | flow_image_list = [Image.open(im_path) for im_path in flow_image_list]
347 |         flow_image_list = [torchvision.transforms.functional.to_tensor(im) for im in flow_image_list]
348 | return torch.stack(flow_image_list, dim=0)
349 |
350 |
351 | if __name__ == '__main__':
352 |
353 | # import torch
354 | # from timesformer.datasets import Kinetics
355 | from utils.parser import parse_args, load_config
356 | from tqdm import tqdm
357 |
358 | args = parse_args()
359 | args.cfg_file = "/home/kanchanaranasinghe/repo/timesformer/configs/Kinetics/TimeSformer_divST_8x32_224.yaml"
360 | config = load_config(args)
361 | config.DATA.PATH_TO_DATA_DIR = "/home/kanchanaranasinghe/data/kinetics400/new_annotations"
362 | # config.DATA.PATH_TO_DATA_DIR = "/home/kanchanaranasinghe/data/kinetics400/k400-mini"
363 | config.DATA.PATH_PREFIX = "/home/kanchanaranasinghe/data/kinetics400"
364 | # dataset = Kinetics(cfg=config, mode="val", num_retries=10)
365 | dataset = Kinetics(cfg=config, mode="train", num_retries=10, get_flow=True)
366 | print(f"Loaded train dataset of length: {len(dataset)}")
367 | dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=4)
368 | for idx, i in enumerate(dataloader):
369 | print([x.shape for x in i[0]], i[1:3], [x.shape for x in i[3]['flow']])
370 | break
371 |
372 | do_vis = False
373 | if do_vis:
374 | from PIL import Image
375 | from transform import undo_normalize
376 |
377 | vis_path = "/home/kanchanaranasinghe/data/kinetics400/vis/spatial_aug"
378 |
379 | for aug_idx in range(len(i[0])):
380 | temp = i[0][aug_idx][3].permute(1, 2, 3, 0)
381 | temp = undo_normalize(temp, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
382 | for idx in range(temp.shape[0]):
383 | im = Image.fromarray(temp[idx].numpy())
384 | im.resize((224, 224)).save(f"{vis_path}/aug_{aug_idx}_fr_{idx:02d}.jpg")
385 |
--------------------------------------------------------------------------------
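For context, a hedged sketch of how the Kinetics loader above is typically constructed, mirroring the __main__ block at the end of the file: DATA.PATH_TO_DATA_DIR must contain {train,val,test}.csv files with one "path label" pair per line (split on DATA.PATH_LABEL_SEPARATOR), and DATA.PATH_PREFIX is prepended to each path. The annotation and video directories below are placeholders.

# Sketch only: build a training DataLoader over the Kinetics dataset (placeholder paths).
import torch

from datasets.kinetics import Kinetics
from utils.parser import load_config, parse_args

args = parse_args()
args.cfg_file = "models/configs/Kinetics/TimeSformer_divST_8x32_224.yaml"
cfg = load_config(args)
cfg.DATA.PATH_TO_DATA_DIR = "/path/to/kinetics400/annotations"  # holds train.csv / val.csv / test.csv
cfg.DATA.PATH_PREFIX = "/path/to/kinetics400/videos"            # prepended to each csv entry

dataset = Kinetics(cfg=cfg, mode="train", num_retries=10)
loader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=4)
frames, label, index, meta = next(iter(loader))  # in train mode, frames is a list of augmented clips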
/eval_linear.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import argparse
15 | import json
16 | import os
17 | import torch
18 | import torch.backends.cudnn as cudnn
19 | from pathlib import Path
20 | from torch import nn
21 | from tqdm import tqdm
22 |
23 | from datasets import UCF101, HMDB51, Kinetics
24 | from models import get_vit_base_patch16_224, get_aux_token_vit, SwinTransformer3D
25 | from utils import utils
26 | from utils.meters import TestMeter
27 | from utils.parser import load_config
28 |
29 |
30 | def eval_linear(args):
31 | utils.init_distributed_mode(args)
32 | print("git:\n {}\n".format(utils.get_sha()))
33 | print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
34 | cudnn.benchmark = True
35 | os.makedirs(args.output_dir, exist_ok=True)
36 | json.dump(vars(args), open(f"{args.output_dir}/config.json", "w"), indent=4)
37 |
38 | # ============ preparing data ... ============
39 | config = load_config(args)
40 | # config.DATA.PATH_TO_DATA_DIR = f"{os.path.expanduser('~')}/repo/mmaction2/data/{args.dataset}/splits"
41 | # config.DATA.PATH_PREFIX = f"{os.path.expanduser('~')}/repo/mmaction2/data/{args.dataset}/videos"
42 | config.TEST.NUM_SPATIAL_CROPS = 1
43 | if args.dataset == "ucf101":
44 | dataset_train = UCF101(cfg=config, mode="train", num_retries=10)
45 | dataset_val = UCF101(cfg=config, mode="val", num_retries=10)
46 | config.TEST.NUM_SPATIAL_CROPS = 3
47 | multi_crop_val = UCF101(cfg=config, mode="val", num_retries=10)
48 | elif args.dataset == "hmdb51":
49 | dataset_train = HMDB51(cfg=config, mode="train", num_retries=10)
50 | dataset_val = HMDB51(cfg=config, mode="val", num_retries=10)
51 | config.TEST.NUM_SPATIAL_CROPS = 3
52 | multi_crop_val = HMDB51(cfg=config, mode="val", num_retries=10)
53 | elif args.dataset == "kinetics400":
54 | dataset_train = Kinetics(cfg=config, mode="train", num_retries=10)
55 | dataset_val = Kinetics(cfg=config, mode="val", num_retries=10)
56 | config.TEST.NUM_SPATIAL_CROPS = 3
57 | multi_crop_val = Kinetics(cfg=config, mode="val", num_retries=10)
58 | else:
59 | raise NotImplementedError(f"invalid dataset: {args.dataset}")
60 |
61 | sampler = torch.utils.data.distributed.DistributedSampler(dataset_train)
62 | train_loader = torch.utils.data.DataLoader(
63 | dataset_train,
64 | sampler=sampler,
65 | batch_size=args.batch_size_per_gpu,
66 | num_workers=args.num_workers,
67 | pin_memory=True,
68 | )
69 | val_loader = torch.utils.data.DataLoader(
70 | dataset_val,
71 | batch_size=args.batch_size_per_gpu,
72 | num_workers=args.num_workers,
73 | pin_memory=True,
74 | )
75 |
76 | multi_crop_val_loader = torch.utils.data.DataLoader(
77 | multi_crop_val,
78 | batch_size=args.batch_size_per_gpu,
79 | num_workers=args.num_workers,
80 | pin_memory=True,
81 | )
82 |
83 |     print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val videos.")
84 |
85 | # ============ building network ... ============
86 | if config.DATA.USE_FLOW or config.MODEL.TWO_TOKEN:
87 | model = get_aux_token_vit(cfg=config, no_head=True)
88 | model_embed_dim = 2 * model.embed_dim
89 | else:
90 | if args.arch == "vit_base":
91 | model = get_vit_base_patch16_224(cfg=config, no_head=True)
92 | model_embed_dim = model.embed_dim
93 | elif args.arch == "swin":
94 | model = SwinTransformer3D(depths=[2, 2, 18, 2], embed_dim=128, num_heads=[4, 8, 16, 32])
95 | model_embed_dim = 1024
96 | else:
97 | raise Exception(f"invalid model: {args.arch}")
98 |
99 | ckpt = torch.load(args.pretrained_weights)
100 | # select_ckpt = 'motion_teacher' if args.use_flow else "teacher"
101 | if "teacher" in ckpt:
102 | ckpt = ckpt["teacher"]
103 | renamed_checkpoint = {x[len("backbone."):]: y for x, y in ckpt.items() if x.startswith("backbone.")}
104 | msg = model.load_state_dict(renamed_checkpoint, strict=False)
105 | print(f"Loaded model with msg: {msg}")
106 | model.cuda()
107 | model.eval()
108 | print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
109 | # load weights to evaluate
110 |
111 | linear_classifier = LinearClassifier(model_embed_dim * (args.n_last_blocks + int(args.avgpool_patchtokens)),
112 | num_labels=args.num_labels)
113 | linear_classifier = linear_classifier.cuda()
114 | linear_classifier = nn.parallel.DistributedDataParallel(linear_classifier, device_ids=[args.gpu])
115 |
116 | if args.lc_pretrained_weights:
117 | lc_ckpt = torch.load(args.lc_pretrained_weights)
118 | msg = linear_classifier.load_state_dict(lc_ckpt['state_dict'])
119 | print(f"Loaded linear classifier weights with msg: {msg}")
120 | test_stats = validate_network_multi_view(multi_crop_val_loader, model, linear_classifier, args.n_last_blocks,
121 | args.avgpool_patchtokens, config)
122 | # test_stats = validate_network(val_loader, model, linear_classifier, args.n_last_blocks, args.avgpool_patchtokens)
123 | print(test_stats)
124 | return True
125 |
126 |
127 | # set optimizer
128 | optimizer = torch.optim.SGD(
129 | linear_classifier.parameters(),
130 | args.lr * (args.batch_size_per_gpu * utils.get_world_size()) / 256., # linear scaling rule
131 | momentum=0.9,
132 | weight_decay=0, # we do not apply weight decay
133 | )
134 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=0)
135 |
136 | # Optionally resume from a checkpoint
137 | to_restore = {"epoch": 0, "best_acc": 0.}
138 | utils.restart_from_checkpoint(
139 | os.path.join(args.output_dir, "checkpoint.pth.tar"),
140 | run_variables=to_restore,
141 | state_dict=linear_classifier,
142 | optimizer=optimizer,
143 | scheduler=scheduler,
144 | )
145 | start_epoch = to_restore["epoch"]
146 | best_acc = to_restore["best_acc"]
147 |
148 | for epoch in range(start_epoch, args.epochs):
149 | train_loader.sampler.set_epoch(epoch)
150 |
151 | train_stats = train(model, linear_classifier, optimizer, train_loader, epoch, args.n_last_blocks, args.avgpool_patchtokens)
152 | scheduler.step()
153 |
154 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
155 | 'epoch': epoch}
156 | if epoch % args.val_freq == 0 or epoch == args.epochs - 1:
157 | test_stats = validate_network(val_loader, model, linear_classifier, args.n_last_blocks, args.avgpool_patchtokens)
158 |             print(f"Accuracy at epoch {epoch} of the network on the {len(dataset_val)} val videos: {test_stats['acc1']:.1f}%")
159 | best_acc = max(best_acc, test_stats["acc1"])
160 | print(f'Max accuracy so far: {best_acc:.2f}%')
161 | log_stats = {**{k: v for k, v in log_stats.items()},
162 | **{f'test_{k}': v for k, v in test_stats.items()}}
163 | if utils.is_main_process():
164 | with (Path(args.output_dir) / "log.txt").open("a") as f:
165 | f.write(json.dumps(log_stats) + "\n")
166 | save_dict = {
167 | "epoch": epoch + 1,
168 | "state_dict": linear_classifier.state_dict(),
169 | "optimizer": optimizer.state_dict(),
170 | "scheduler": scheduler.state_dict(),
171 | "best_acc": best_acc,
172 | }
173 | torch.save(save_dict, os.path.join(args.output_dir, "checkpoint.pth.tar"))
174 |
175 | test_stats = validate_network_multi_view(multi_crop_val_loader, model, linear_classifier, args.n_last_blocks,
176 | args.avgpool_patchtokens, config)
177 | print(test_stats)
178 |
179 | print("Training of the supervised linear classifier on frozen features completed.\n"
180 | "Top-1 test accuracy: {acc:.1f}".format(acc=best_acc))
181 |
182 |
183 | def train(model, linear_classifier, optimizer, loader, epoch, n, avgpool):
184 | linear_classifier.train()
185 | metric_logger = utils.MetricLogger(delimiter=" ")
186 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
187 | header = 'Epoch: [{}]'.format(epoch)
188 | for (inp, target, sample_idx, meta) in metric_logger.log_every(loader, 20, header):
189 | # move to gpu
190 | inp = inp.cuda(non_blocking=True)
191 | target = target.cuda(non_blocking=True)
192 |
193 | # forward
194 | with torch.no_grad():
195 | # intermediate_output = model.get_intermediate_layers(inp, n)
196 | # output = [x[:, 0] for x in intermediate_output]
197 | # if avgpool:
198 | # output.append(torch.mean(intermediate_output[-1][:, 1:], dim=1))
199 | # output = torch.cat(output, dim=-1)
200 |
201 | output = model(inp)
202 |
203 | output = linear_classifier(output)
204 |
205 | # compute cross entropy loss
206 | loss = nn.CrossEntropyLoss()(output, target)
207 |
208 | # compute the gradients
209 | optimizer.zero_grad()
210 | loss.backward()
211 |
212 | # step
213 | optimizer.step()
214 |
215 | # log
216 | torch.cuda.synchronize()
217 | metric_logger.update(loss=loss.item())
218 | metric_logger.update(lr=optimizer.param_groups[0]["lr"])
219 | # gather the stats from all processes
220 | metric_logger.synchronize_between_processes()
221 | print("Averaged stats:", metric_logger)
222 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
223 |
224 |
225 | @torch.no_grad()
226 | def validate_network(val_loader, model, linear_classifier, n, avgpool):
227 | linear_classifier.eval()
228 | metric_logger = utils.MetricLogger(delimiter=" ")
229 | header = 'Test:'
230 | for (inp, target, sample_idx, meta) in metric_logger.log_every(val_loader, 20, header):
231 | # move to gpu
232 | inp = inp.cuda(non_blocking=True)
233 | target = target.cuda(non_blocking=True)
234 |
235 | # forward
236 | with torch.no_grad():
237 | # intermediate_output = model.get_intermediate_layers(inp, n)
238 | # output = [x[:, 0] for x in intermediate_output]
239 | # if avgpool:
240 | # output.append(torch.mean(intermediate_output[-1][:, 1:], dim=1))
241 | # output = torch.cat(output, dim=-1)
242 | output = model(inp)
243 | output = linear_classifier(output)
244 | loss = nn.CrossEntropyLoss()(output, target)
245 |
246 | if linear_classifier.module.num_labels >= 5:
247 | acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
248 | else:
249 | acc1, = utils.accuracy(output, target, topk=(1,))
250 |
251 | batch_size = inp.shape[0]
252 | metric_logger.update(loss=loss.item())
253 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
254 | if linear_classifier.module.num_labels >= 5:
255 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
256 | if linear_classifier.module.num_labels >= 5:
257 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
258 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
259 | else:
260 | print('* Acc@1 {top1.global_avg:.3f} loss {losses.global_avg:.3f}'
261 | .format(top1=metric_logger.acc1, losses=metric_logger.loss))
262 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
263 |
264 |
265 | @torch.no_grad()
266 | def validate_network_multi_view(val_loader, model, linear_classifier, n, avgpool, cfg):
267 | linear_classifier.eval()
268 | test_meter = TestMeter(
269 | len(val_loader.dataset)
270 | // (cfg.TEST.NUM_ENSEMBLE_VIEWS * cfg.TEST.NUM_SPATIAL_CROPS),
271 | cfg.TEST.NUM_ENSEMBLE_VIEWS * cfg.TEST.NUM_SPATIAL_CROPS,
272 |         args.num_labels,  # args is the module-level namespace parsed under __main__
273 | len(val_loader),
274 | cfg.DATA.MULTI_LABEL,
275 | cfg.DATA.ENSEMBLE_METHOD,
276 | )
277 | test_meter.iter_tic()
278 |
279 | for cur_iter, (inp, target, sample_idx, meta) in tqdm(enumerate(val_loader), total=len(val_loader)):
280 | # move to gpu
281 | inp = inp.cuda(non_blocking=True)
282 | # target = target.cuda(non_blocking=True)
283 | test_meter.data_toc()
284 |
285 | # forward
286 | with torch.no_grad():
287 | output = model(inp)
288 | output = linear_classifier(output)
289 |
290 | output = output.cpu()
291 | target = target.cpu()
292 | sample_idx = sample_idx.cpu()
293 |
294 | test_meter.iter_toc()
295 | # Update and log stats.
296 | test_meter.update_stats(
297 | output.detach(), target.detach(), sample_idx.detach()
298 | )
299 | test_meter.log_iter_stats(cur_iter)
300 |
301 | test_meter.iter_tic()
302 |
303 | test_meter.finalize_metrics()
304 | return test_meter.stats
305 |
306 |
307 | class LinearClassifier(nn.Module):
308 | """Linear layer to train on top of frozen features"""
309 | def __init__(self, dim, num_labels=1000):
310 | super(LinearClassifier, self).__init__()
311 | self.num_labels = num_labels
312 | self.linear = nn.Linear(dim, num_labels)
313 | self.linear.weight.data.normal_(mean=0.0, std=0.01)
314 | self.linear.bias.data.zero_()
315 |
316 | def forward(self, x):
317 | # flatten
318 | x = x.view(x.size(0), -1)
319 |
320 | # linear layer
321 | return self.linear(x)
322 |
323 |
324 | if __name__ == '__main__':
325 |     parser = argparse.ArgumentParser('Evaluation with linear classification on video datasets')
326 | parser.add_argument('--n_last_blocks', default=4, type=int, help="""Concatenate [CLS] tokens
327 | for the `n` last blocks. We use `n=4` when evaluating ViT-Small and `n=1` with ViT-Base.""")
328 | parser.add_argument('--avgpool_patchtokens', default=False, type=utils.bool_flag,
329 |         help="""Whether or not to concatenate the global average pooled features to the [CLS] token.
330 | We typically set this to False for ViT-Small and to True with ViT-Base.""")
331 | parser.add_argument('--arch', default='vit_small', type=str,
332 | choices=['vit_tiny', 'vit_small', 'vit_base', 'swin'],
333 |         help='Architecture of the backbone (ViT variants or Swin).')
334 | parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
335 | parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
336 |     parser.add_argument('--lc_pretrained_weights', default='', type=str, help="Path to pretrained linear classifier weights; if given, training is skipped and only evaluation runs.")
337 | parser.add_argument("--checkpoint_key", default="teacher", type=str, help='Key to use in the checkpoint (example: "teacher")')
338 | parser.add_argument('--epochs', default=100, type=int, help='Number of epochs of training.')
339 | parser.add_argument("--lr", default=0.001, type=float, help="""Learning rate at the beginning of
340 | training (highest LR used during training). The learning rate is linearly scaled
341 | with the batch size, and specified here for a reference batch size of 256.
342 | We recommend tweaking the LR depending on the checkpoint evaluated.""")
343 | parser.add_argument('--batch_size_per_gpu', default=128, type=int, help='Per-GPU batch-size')
344 | parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
345 | distributed training; see https://pytorch.org/docs/stable/distributed.html""")
346 | parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
347 | parser.add_argument('--data_path', default='/path/to/imagenet/', type=str)
348 | parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
349 | parser.add_argument('--val_freq', default=1, type=int, help="Epoch frequency for validation.")
350 | parser.add_argument('--output_dir', default=".", help='Path to save logs and checkpoints')
351 | parser.add_argument('--num_labels', default=1000, type=int, help='Number of labels for linear classifier')
352 |     parser.add_argument('--dataset', default="ucf101", help='Dataset: ucf101 / hmdb51 / kinetics400')
353 | parser.add_argument('--use_flow', default=False, type=utils.bool_flag, help="use flow teacher")
354 |
355 | # config file
356 | parser.add_argument("--cfg", dest="cfg_file", help="Path to the config file", type=str,
357 | default="models/configs/Kinetics/TimeSformer_divST_8x32_224.yaml")
358 | parser.add_argument("--opts", help="See utils/defaults.py for all options", default=None, nargs=argparse.REMAINDER)
359 |
360 | args = parser.parse_args()
361 | eval_linear(args)
362 |
--------------------------------------------------------------------------------
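The linear probe in eval_linear.py above keeps the backbone frozen and trains only a single nn.Linear head, scaling the base learning rate linearly with the total batch size (lr * batch_size_per_gpu * world_size / 256). A minimal, self-contained sketch of that head and one optimization step follows; the random features stand in for frozen backbone output, and the 768-dimensional feature size (ViT-Base) and 101 labels (UCF101) are assumptions for illustration.

# Sketch only: one SGD step of a linear head on frozen (here: random) features.
import torch
from torch import nn

embed_dim, num_labels = 768, 101                       # assumed ViT-Base features, UCF101 classes
batch_size_per_gpu, world_size, base_lr = 32, 1, 0.001

head = nn.Linear(embed_dim, num_labels)
head.weight.data.normal_(mean=0.0, std=0.01)           # same init as LinearClassifier above
head.bias.data.zero_()

optimizer = torch.optim.SGD(
    head.parameters(),
    base_lr * (batch_size_per_gpu * world_size) / 256.0,  # linear scaling rule
    momentum=0.9,
    weight_decay=0,                                     # no weight decay, as in eval_linear.py
)

features = torch.randn(batch_size_per_gpu, embed_dim)   # stand-in for frozen backbone features
labels = torch.randint(0, num_labels, (batch_size_per_gpu,))

loss = nn.CrossEntropyLoss()(head(features), labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()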