├── slowfast ├── utils │ ├── ava_evaluation │ │ ├── __init__.py │ │ ├── README.md │ │ ├── np_box_mask_list.py │ │ ├── ava_action_list_v2.1_for_activitynet_2018.pbtxt.txt │ │ ├── np_box_ops.py │ │ ├── np_mask_ops.py │ │ ├── np_box_list.py │ │ ├── metrics.py │ │ ├── label_map_util.py │ │ └── standard_fields.py │ ├── __init__.py │ ├── env.py │ ├── weight_init_helper.py │ ├── multiprocessing.py │ ├── metrics.py │ ├── logging.py │ ├── bn_helper.py │ ├── lr_policy.py │ ├── parser.py │ ├── benchmark.py │ ├── c2_model_loading.py │ ├── multigrid.py │ └── distributed.py ├── config │ ├── __init__.py │ └── custom_config.py ├── visualization │ ├── __init__.py │ ├── demo_loader.py │ └── predictor.py ├── models │ ├── custom_video_model_builder.py │ ├── __init__.py │ ├── build.py │ ├── losses.py │ ├── bank_model_builder.py │ ├── optimizer.py │ ├── nonlocal_helper.py │ ├── stem_helper.py │ └── batchnorm_helper.py ├── __init__.py └── datasets │ ├── __init__.py │ ├── video_container.py │ ├── build.py │ ├── multigrid_helper.py │ ├── loader.py │ ├── ava_helper.py │ └── utils.py ├── linter.sh ├── tools ├── benchmark.py ├── run_net.py ├── demo_net.py ├── extract_feature.py ├── visualization.py └── test_net.py ├── setup.cfg ├── setup.py ├── INSTALL.md ├── configs ├── I3D_8x8_R50.yaml ├── SLOWFAST_32x2_BANK.yaml ├── SLOWFAST_32x2_R101_LFB.yaml └── SLOWFAST_32x2_R50_LFB.yaml ├── README.md └── DATASET.md /slowfast/utils/ava_evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /slowfast/config/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | -------------------------------------------------------------------------------- /slowfast/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | -------------------------------------------------------------------------------- /slowfast/visualization/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | -------------------------------------------------------------------------------- /slowfast/utils/ava_evaluation/README.md: -------------------------------------------------------------------------------- 1 | The code under this folder is from the official [ActivityNet repo](https://github.com/activitynet/ActivityNet). 2 | -------------------------------------------------------------------------------- /slowfast/models/custom_video_model_builder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | 5 | """A More Flexible Video models.""" 6 | -------------------------------------------------------------------------------- /slowfast/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
3 | 4 | from slowfast.utils.env import setup_environment 5 | 6 | setup_environment() 7 | -------------------------------------------------------------------------------- /slowfast/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | from .ava_dataset import Ava # noqa 5 | from .build import DATASET_REGISTRY, build_dataset # noqa 6 | -------------------------------------------------------------------------------- /slowfast/config/custom_config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """Add custom configs and default values""" 5 | 6 | 7 | def add_custom_config(_C): 8 | # Add your own customized configs. 9 | pass 10 | -------------------------------------------------------------------------------- /slowfast/models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | from .build import MODEL_REGISTRY, build_model # noqa 5 | from .custom_video_model_builder import * # noqa 6 | from .video_model_builder import ResNet, SlowFast # noqa 7 | from .bank_model_builder import BankContext 8 | -------------------------------------------------------------------------------- /slowfast/utils/env.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """Set up Environment.""" 5 | 6 | import slowfast.utils.logging as logging 7 | 8 | _ENV_SETUP_DONE = False 9 | 10 | 11 | def setup_environment(): 12 | global _ENV_SETUP_DONE 13 | if _ENV_SETUP_DONE: 14 | return 15 | _ENV_SETUP_DONE = True 16 | -------------------------------------------------------------------------------- /linter.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -e 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | # Run this script at project root by ".linter.sh" before you commit. 4 | echo "Running isort..." 5 | isort -y -sp . 6 | 7 | echo "Running black..." 8 | black -l 80 . 9 | 10 | echo "Running flake..." 11 | flake8 . 12 | 13 | command -v arc > /dev/null && { 14 | echo "Running arc lint ..." 15 | arc lint 16 | } 17 | -------------------------------------------------------------------------------- /tools/benchmark.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | """ 4 | A script to benchmark data loading. 
5 | """ 6 | 7 | import slowfast.utils.logging as logging 8 | from slowfast.utils.benchmark import benchmark_data_loading 9 | from slowfast.utils.misc import launch_job 10 | from slowfast.utils.parser import load_config, parse_args 11 | 12 | logger = logging.get_logger(__name__) 13 | 14 | 15 | def main(): 16 | args = parse_args() 17 | cfg = load_config(args) 18 | 19 | launch_job( 20 | cfg=cfg, init_method=args.init_method, func=benchmark_data_loading 21 | ) 22 | 23 | 24 | if __name__ == "__main__": 25 | main() 26 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | line_length=100 3 | multi_line_output=4 4 | known_standard_library=numpy,setuptools 5 | known_myself=slowfast 6 | known_third_party=fvcore,av,torch,pycocotools,yacs,termcolor,scipy,simplejson,matplotlib,detectron2,torchvision,yaml,tqdm,psutil,opencv-python,pandas,tensorboard,moviepy 7 | no_lines_before=STDLIB,THIRDPARTY 8 | sections=FUTURE,STDLIB,THIRDPARTY,myself,FIRSTPARTY,LOCALFOLDER 9 | default_section=FIRSTPARTY 10 | 11 | [mypy] 12 | python_version=3.6 13 | ignore_missing_imports = True 14 | warn_unused_configs = True 15 | disallow_untyped_defs = True 16 | check_untyped_defs = True 17 | warn_unused_ignores = True 18 | warn_redundant_casts = True 19 | show_column_numbers = True 20 | follow_imports = silent 21 | allow_redefinition = True 22 | ; Require all functions to be annotated 23 | disallow_incomplete_defs = True 24 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | from setuptools import find_packages, setup 5 | 6 | setup( 7 | name="slowfast", 8 | version="1.0", 9 | author="FAIR", 10 | url="unknown", 11 | description="SlowFast Video Understanding", 12 | install_requires=[ 13 | "yacs>=0.1.6", 14 | "pyyaml>=5.1", 15 | "av", 16 | "matplotlib", 17 | "termcolor>=1.1", 18 | "simplejson", 19 | "tqdm", 20 | "psutil", 21 | "matplotlib", 22 | "detectron2", 23 | "opencv-python", 24 | "pandas", 25 | "torchvision>=0.4.2", 26 | "sklearn", 27 | "lmdb", 28 | "tensorboard", 29 | ], 30 | extras_require={"tensorboard_video_visualization": ["moviepy"]}, 31 | packages=find_packages(exclude=("configs", "tests")), 32 | ) 33 | -------------------------------------------------------------------------------- /tools/run_net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """Wrapper to train and test a video classification model.""" 5 | from slowfast.utils.misc import launch_job 6 | from slowfast.utils.parser import load_config, parse_args 7 | 8 | from demo_net import demo 9 | from test_net import test 10 | from train_net import train 11 | from visualization import visualize 12 | 13 | def main(): 14 | """ 15 | Main function to spawn the train and test process. 16 | """ 17 | args = parse_args() 18 | cfg = load_config(args) 19 | 20 | # Perform training. 21 | if cfg.TRAIN.ENABLE: 22 | launch_job(cfg=cfg, args=args, func=train, start_method='cmd') 23 | 24 | # Perform multi-clip testing. 25 | if cfg.TEST.ENABLE: 26 | launch_job(cfg=cfg, args=args, func=test, start_method='cmd') 27 | 28 | # Run demo. 
29 | if cfg.DEMO.ENABLE: 30 | demo(cfg) 31 | 32 | 33 | if __name__ == "__main__": 34 | main() 35 | -------------------------------------------------------------------------------- /slowfast/datasets/video_container.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | import av 5 | 6 | 7 | def get_video_container(path_to_vid, multi_thread_decode=False, backend="pyav"): 8 | """ 9 | Given the path to the video, return the pyav video container. 10 | Args: 11 | path_to_vid (str): path to the video. 12 | multi_thread_decode (bool): if True, perform multi-thread decoding. 13 | backend (str): decoder backend, options include `pyav` and 14 | `torchvision`, default is `pyav`. 15 | Returns: 16 | container (container): video container. 17 | """ 18 | if backend == "torchvision": 19 | with open(path_to_vid, "rb") as fp: 20 | container = fp.read() 21 | return container 22 | elif backend == "pyav": 23 | container = av.open(path_to_vid) 24 | if multi_thread_decode: 25 | # Enable multiple threads for decoding. 26 | container.streams.video[0].thread_type = "AUTO" 27 | return container 28 | else: 29 | raise NotImplementedError("Unknown backend {}".format(backend)) 30 | -------------------------------------------------------------------------------- /slowfast/datasets/build.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | from fvcore.common.registry import Registry 5 | 6 | DATASET_REGISTRY = Registry("DATASET") 7 | DATASET_REGISTRY.__doc__ = """ 8 | Registry for dataset. 9 | 10 | The registered object will be called with `obj(cfg, split)`. 11 | The call should return a `torch.utils.data.Dataset` object. 12 | """ 13 | 14 | 15 | def build_dataset(dataset_name, cfg, split): 16 | """ 17 | Build a dataset, defined by `dataset_name`. 18 | Args: 19 | dataset_name (str): the name of the dataset to be constructed. 20 | cfg (CfgNode): configs. Details can be found in 21 | slowfast/config/defaults.py 22 | split (str): the split of the data loader. Options include `train`, 23 | `val`, and `test`. 24 | Returns: 25 | Dataset: a constructed dataset specified by dataset_name. 26 | """ 27 | # Capitalize the the first letter of the dataset_name since the dataset_name 28 | # in configs may be in lowercase but the name of dataset class should always 29 | # start with an uppercase letter. 30 | name = dataset_name.capitalize() 31 | return DATASET_REGISTRY.get(name)(cfg, split) 32 | -------------------------------------------------------------------------------- /slowfast/utils/weight_init_helper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """Utility function for weight initialization""" 5 | 6 | import torch.nn as nn 7 | from fvcore.nn.weight_init import c2_msra_fill 8 | 9 | 10 | def init_weights(model, fc_init_std=0.01, zero_init_final_bn=True): 11 | """ 12 | Performs ResNet style weight initialization. 13 | Args: 14 | fc_init_std (float): the expected standard deviation for fc layer. 15 | zero_init_final_bn (bool): if True, zero initialize the final bn for 16 | every bottleneck. 
17 | """ 18 | for m in model.modules(): 19 | if isinstance(m, nn.Conv3d): 20 | """ 21 | Follow the initialization method proposed in: 22 | {He, Kaiming, et al. 23 | "Delving deep into rectifiers: Surpassing human-level 24 | performance on imagenet classification." 25 | arXiv preprint arXiv:1502.01852 (2015)} 26 | """ 27 | c2_msra_fill(m) 28 | elif isinstance(m, nn.BatchNorm3d): 29 | if ( 30 | hasattr(m, "transform_final_bn") 31 | and m.transform_final_bn 32 | and zero_init_final_bn 33 | ): 34 | batchnorm_weight = 0.0 35 | else: 36 | batchnorm_weight = 1.0 37 | if m.weight is not None: 38 | m.weight.data.fill_(batchnorm_weight) 39 | if m.bias is not None: 40 | m.bias.data.zero_() 41 | if isinstance(m, nn.Linear): 42 | m.weight.data.normal_(mean=0.0, std=fc_init_std) 43 | if m.bias is not None: 44 | m.bias.data.zero_() 45 | -------------------------------------------------------------------------------- /slowfast/models/build.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """Model construction functions.""" 5 | 6 | import torch 7 | from fvcore.common.registry import Registry 8 | 9 | MODEL_REGISTRY = Registry("MODEL") 10 | MODEL_REGISTRY.__doc__ = """ 11 | Registry for video model. 12 | 13 | The registered object will be called with `obj(cfg)`. 14 | The call should return a `torch.nn.Module` object. 15 | """ 16 | 17 | 18 | def build_model(cfg): 19 | """ 20 | Builds the video model. 21 | Args: 22 | cfg (configs): configs that contains the hyper-parameters to build the 23 | backbone. Details can be seen in slowfast/config/defaults.py. 24 | """ 25 | if torch.cuda.is_available(): 26 | assert ( 27 | cfg.NUM_GPUS <= torch.cuda.device_count() 28 | ), "Cannot use more GPU devices than available" 29 | else: 30 | assert ( 31 | cfg.NUM_GPUS == 0 32 | ), "Cuda is not available. Please set `NUM_GPUS: 0 for running on CPUs." 33 | 34 | # Construct the model 35 | name = cfg.MODEL.MODEL_NAME 36 | model = MODEL_REGISTRY.get(name)(cfg) 37 | 38 | if cfg.NUM_GPUS: 39 | # Determine the GPU used by the current process 40 | cur_device = torch.cuda.current_device() 41 | # Transfer the model to the current GPU device 42 | model = model.cuda(device=cur_device) 43 | # Use multi-process data parallel model in the multi-gpu setting 44 | if cfg.NUM_GPUS > 1: 45 | # Make model replica operate on the current device 46 | model = torch.nn.parallel.DistributedDataParallel( 47 | module=model, device_ids=[cur_device], output_device=cur_device, 48 | find_unused_parameters=False 49 | ) 50 | return model 51 | -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | ## Requirements 4 | - Python >= 3.6 5 | - Numpy 6 | - PyTorch 1.3 7 | - [fvcore](https://github.com/facebookresearch/fvcore/): `pip install 'git+https://github.com/facebookresearch/fvcore'` 8 | - [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation. 9 | You can install them together at [pytorch.org](https://pytorch.org) to make sure of this. 
10 | - simplejson: `pip install simplejson` 11 | - GCC >= 4.9 12 | - PyAV: `conda install av -c conda-forge` 13 | - ffmpeg (4.0 is preferred, will be installed along with PyAV) 14 | - PyYaml: (will be installed along with fvcore) 15 | - tqdm: (will be installed along with fvcore) 16 | - psutil: `pip install psutil` 17 | - OpenCV: `pip install opencv-python` 18 | - torchvision: `pip install torchvision` or `conda install torchvision -c pytorch` 19 | - tensorboard: `pip install tensorboard` 20 | - moviepy: (optional, for visualizing video on tensorboard) `conda install -c conda-forge moviepy` or `pip install moviepy` 21 | - [Detectron2](https://github.com/facebookresearch/detectron2): 22 | ``` 23 | pip install -U torch torchvision cython 24 | pip install -U 'git+https://github.com/facebookresearch/fvcore.git' 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI' 25 | git clone https://github.com/facebookresearch/detectron2 detectron2_repo 26 | pip install -e detectron2_repo 27 | # You can find more details at https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md 28 | ``` 29 | 30 | ## Setup 31 | 32 | Clone the LSTC repository and set up the pyslowfast package 33 | ``` 34 | git clone https://github.com/facebookresearch/slowfast 35 | cd LSTC 36 | python3 setup.py build 37 | ``` 38 | 39 | Add the installed package to $PYTHONPATH. 40 | ``` 41 | export PYTHONPATH=/path/to/LSTC/build/lib:$PYTHONPATH 42 | ``` 43 | -------------------------------------------------------------------------------- /configs/I3D_8x8_R50.yaml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | ENABLE: True 3 | DATASET: ava 4 | BATCH_SIZE: 64 5 | EVAL_PERIOD: 20 6 | CHECKPOINT_PERIOD: 1 7 | AUTO_RESUME: True 8 | CHECKPOINT_FILE_PATH: "backbones/I3D_8x8_R50.pkl" 9 | CHECKPOINT_TYPE: caffe2 10 | DATA: 11 | NUM_FRAMES: 32 12 | SAMPLING_RATE: 2 13 | TRAIN_JITTER_SCALES: [256, 320] 14 | TRAIN_CROP_SIZE: 224 15 | TEST_CROP_SIZE: 224 16 | INPUT_CHANNEL_NUM: [3] 17 | DETECTION: 18 | ENABLE: True 19 | ALIGNED: True 20 | AVA: 21 | DETECTION_SCORE_THRESH: 0.9 22 | TRAIN_PREDICT_BOX_LISTS: [ 23 | "ava_train_predicted_boxes.csv", 24 | ] 25 | TEST_PREDICT_BOX_LISTS: ["ava_val_predicted_boxes.csv"] 26 | FEATURE_BANK_PATH: "output/feature_bank_i3d" 27 | SLIDING_WINDOW_SIZE: 25 28 | GATHER_BANK: True 29 | FEATURE_BANK_DIM: 2048 30 | TEMPORAL_EMBED: 0 31 | RESNET: 32 | ZERO_INIT_FINAL_BN: True 33 | WIDTH_PER_GROUP: 64 34 | NUM_GROUPS: 1 35 | DEPTH: 50 36 | TRANS_FUNC: bottleneck_transform 37 | STRIDE_1X1: False 38 | NUM_BLOCK_TEMP_KERNEL: [[3], [4], [6], [3]] 39 | NONLOCAL: 40 | LOCATION: [[[]], [[]], [[]], [[]]] 41 | GROUP: [[1], [1], [1], [1]] 42 | INSTANTIATION: softmax 43 | BN: 44 | USE_PRECISE_STATS: False 45 | NUM_BATCHES_PRECISE: 200 46 | SOLVER: 47 | BASE_LR: 0.1 48 | LR_POLICY: steps_with_relative_lrs 49 | STEPS: [0, 10, 13] 50 | LRS: [1, 0.1, 0.01, 0.001] 51 | MAX_EPOCH: 13 52 | MOMENTUM: 0.9 53 | WEIGHT_DECAY: 1e-7 54 | WARMUP_EPOCHS: 4.0 55 | WARMUP_START_LR: 0.000125 56 | OPTIMIZING_METHOD: sgd 57 | MODEL: 58 | NUM_CLASSES: 80 59 | ARCH: i3d 60 | MODEL_NAME: ResNet 61 | LOSS_FUNC: bce 62 | DROPOUT_RATE: 0.5 63 | HEAD_ACT: sigmoid 64 | TEST: 65 | ENABLE: False 66 | DATASET: ava 67 | BATCH_SIZE: 8 68 | CHECKPOINT_FILE_PATH: "" 69 | CACHE: 70 | ENABLE: True 71 | CONTEXT_KEY: [256] 72 | CONTEXT_VAL: [512] 73 | DATA_LOADER: 74 | NUM_WORKERS: 2 75 | PIN_MEMORY: True 76 | NUM_GPUS: 8 77 | NUM_SHARDS: 1 78 | RNG_SEED: 0 79 | OUTPUT_DIR: "output/val_i3d_pair" 80 | LOG_MODEL_INFO: False 81 |
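The YAML files under `configs/` are not standalone: they are yacs overrides that `load_config` in `slowfast/utils/parser.py` merges on top of `slowfast/config/defaults.py`, with any trailing `KEY VALUE` pairs from the command line applied last. A minimal sketch of doing that merge by hand, assuming the `slowfast` package built above is on `PYTHONPATH`; the config path and override keys below are only examples:

```python
# Sketch only: reproduce what parse_args()/load_config() do, outside the CLI.
# Assumes the slowfast package from this repo is importable; paths are examples.
from slowfast.config.defaults import get_cfg

cfg = get_cfg()                                   # full default CfgNode
cfg.merge_from_file("configs/I3D_8x8_R50.yaml")   # overlay the experiment YAML
# Equivalent of trailing `KEY VALUE` pairs on the run_net.py command line.
cfg.merge_from_list(["TRAIN.ENABLE", "False", "TEST.ENABLE", "True"])
print(cfg.MODEL.MODEL_NAME, cfg.TEST.ENABLE)
```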
-------------------------------------------------------------------------------- /slowfast/utils/multiprocessing.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """Multiprocessing helpers.""" 5 | 6 | import torch 7 | 8 | 9 | def run( 10 | local_rank, num_proc, func, init_method, shard_id, num_shards, backend, cfg 11 | ): 12 | """ 13 | Runs a function from a child process. 14 | Args: 15 | local_rank (int): rank of the current process on the current machine. 16 | num_proc (int): number of processes per machine. 17 | func (function): function to execute on each of the processes. 18 | init_method (string): method to initialize the distributed training. 19 | TCP initialization: requires a network address reachable from all 20 | processes, followed by the port. 21 | Shared file-system initialization: makes use of a file system that 22 | is shared and visible from all machines. The URL should start with 23 | file:// and contain a path to a non-existent file on a shared file 24 | system. 25 | shard_id (int): the rank of the current machine. 26 | num_shards (int): number of overall machines for the distributed 27 | training job. 28 | backend (string): three distributed backends ('nccl', 'gloo', 'mpi') are 29 | supported, each with different capabilities. Details can be found 30 | here: 31 | https://pytorch.org/docs/stable/distributed.html 32 | cfg (CfgNode): configs. Details can be found in 33 | slowfast/config/defaults.py 34 | """ 35 | # Initialize the process group. 36 | world_size = num_proc * num_shards 37 | rank = shard_id * num_proc + local_rank 38 | 39 | try: 40 | torch.distributed.init_process_group( 41 | backend=backend, 42 | init_method=init_method, 43 | world_size=world_size, 44 | rank=rank, 45 | ) 46 | except Exception as e: 47 | raise e 48 | 49 | torch.cuda.set_device(local_rank) 50 | func(cfg) 51 |
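For reference, `run` above is the per-process entry point; in this codebase it is normally driven by `launch_job` from `slowfast/utils/misc.py` (imported by `tools/run_net.py`). A hedged sketch of the underlying pattern on a single node, using `torch.multiprocessing.spawn`, which supplies the process index as `local_rank`; the init address, backend, and GPU count are illustrative:

```python
# Illustrative single-node launch of the `run` helper above; not the repo's
# actual entry point (that is launch_job in slowfast/utils/misc.py).
import torch
from slowfast.utils.multiprocessing import run

def launch_single_node(cfg, per_process_fn, num_gpus=8):
    torch.multiprocessing.spawn(
        run,                          # spawn passes the process index as local_rank
        args=(
            num_gpus,                 # num_proc: processes on this machine
            per_process_fn,           # func: e.g. a train/test fn, called as func(cfg)
            "tcp://localhost:9999",   # init_method (TCP initialization)
            0,                        # shard_id: rank of this machine
            1,                        # num_shards: single machine
            "nccl",                   # backend
            cfg,
        ),
        nprocs=num_gpus,
        daemon=False,
    )
```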
3 | 4 | """Loss functions.""" 5 | 6 | import torch.nn as nn 7 | import torch 8 | 9 | class CustomizeCrossEntropy(nn.Module): 10 | 11 | def __init__(self, reduction="mean"): 12 | super(CustomizeCrossEntropy, self).__init__() 13 | self.reduction = reduction 14 | 15 | def forward(self, pred, labels): 16 | 17 | pred = -1 * torch.log(pred + 1e-5) 18 | loss = torch.sum(pred * labels, dim=-1) 19 | 20 | if self.reduction == "mean": 21 | return torch.mean(loss) 22 | else: 23 | return torch.sum(loss) 24 | 25 | class CrossScopeFocalLoss(nn.Module): 26 | 27 | def __init__(self, reduction="mean", gamma=1.0): 28 | super(CrossScopeFocalLoss, self).__init__() 29 | self.reduction = reduction 30 | self.gamma = gamma 31 | 32 | def forward(self, pred1, pred2, labels): 33 | 34 | prob1 = pred1.clone() 35 | prob2 = pred2.clone() 36 | 37 | pred1 = -1 * torch.log(prob1 + 1e-5) 38 | pred2 = -1 * torch.log(prob2 + 1e-5) 39 | 40 | loss1 = torch.sum(((1 - prob2) ** self.gamma) * pred1 * labels, dim=-1) 41 | loss2 = torch.sum(((1 - prob1) ** self.gamma) * pred2 * labels, dim=-1) 42 | 43 | loss = loss1 + loss2 44 | 45 | if self.reduction == "mean": 46 | return torch.mean(loss) 47 | else: 48 | return torch.sum(loss) 49 | 50 | def get_loss_func(loss_name): 51 | """ 52 | Retrieve the loss given the loss name. 53 | Args (int): 54 | loss_name: the name of the loss to use. 55 | """ 56 | if loss_name not in _LOSSES.keys(): 57 | raise NotImplementedError("Loss {} is not supported".format(loss_name)) 58 | return _LOSSES[loss_name] 59 | 60 | _LOSSES = { 61 | "cross_entropy": CustomizeCrossEntropy, # nn.CrossEntropyLoss, 62 | "bce": nn.BCELoss, 63 | "bce_logit": nn.BCEWithLogitsLoss, 64 | "scope_focal_loss": CrossScopeFocalLoss 65 | } 66 | 67 | -------------------------------------------------------------------------------- /configs/SLOWFAST_32x2_BANK.yaml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | ENABLE: True 3 | DATASET: ava 4 | BATCH_SIZE: 64 5 | EVAL_PERIOD: 20 6 | CHECKPOINT_PERIOD: 1 7 | AUTO_RESUME: True 8 | CHECKPOINT_FILE_PATH: "" 9 | DATA: 10 | NUM_FRAMES: 32 11 | SAMPLING_RATE: 2 12 | TRAIN_JITTER_SCALES: [256, 320] 13 | TRAIN_CROP_SIZE: 224 14 | TEST_CROP_SIZE: 224 15 | INPUT_CHANNEL_NUM: [3, 3] 16 | DETECTION: 17 | ENABLE: True 18 | ALIGNED: True 19 | AVA: 20 | DETECTION_SCORE_THRESH: 0.9 21 | TRAIN_PREDICT_BOX_LISTS: [] 22 | TEST_PREDICT_BOX_LISTS: ["ava_val_predicted_boxes.csv"] 23 | TEST_GT_BOX_LISTS: [] 24 | FEATURE_BANK_PATH: "output/feature_bank" 25 | SLIDING_WINDOW_SIZE: 15 26 | GATHER_BANK: False 27 | SLOWFAST: 28 | ALPHA: 4 29 | BETA_INV: 8 30 | FUSION_CONV_CHANNEL_RATIO: 2 31 | FUSION_KERNEL_SZ: 7 32 | RESNET: 33 | ZERO_INIT_FINAL_BN: True 34 | WIDTH_PER_GROUP: 64 35 | NUM_GROUPS: 1 36 | DEPTH: 50 37 | TRANS_FUNC: bottleneck_transform 38 | STRIDE_1X1: False 39 | NUM_BLOCK_TEMP_KERNEL: [[3, 3], [4, 4], [6, 6], [3, 3]] 40 | SPATIAL_DILATIONS: [[1, 1], [1, 1], [1, 1], [2, 2]] 41 | SPATIAL_STRIDES: [[1, 1], [2, 2], [2, 2], [1, 1]] 42 | NONLOCAL: 43 | LOCATION: [[[], []], [[], []], [[], []], [[], []]] 44 | GROUP: [[1, 1], [1, 1], [1, 1], [1, 1]] 45 | INSTANTIATION: dot_product 46 | POOL: [[[1, 2, 2], [1, 2, 2]], [[1, 2, 2], [1, 2, 2]], [[1, 2, 2], [1, 2, 2]], [[1, 2, 2], [1, 2, 2]]] 47 | BN: 48 | USE_PRECISE_STATS: False 49 | FREEZE: False 50 | NUM_BATCHES_PRECISE: 200 51 | SOLVER: 52 | BASE_LR: 0.1 53 | LR_POLICY: steps_with_relative_lrs 54 | STEPS: [0, 10, 15, 20] 55 | LRS: [1, 0.1, 0.01, 0.001] 56 | MAX_EPOCH: 20 57 | MOMENTUM: 0.9 58 | WEIGHT_DECAY: 1e-7 
59 | WARMUP_EPOCHS: 5.0 60 | WARMUP_START_LR: 0.000125 61 | OPTIMIZING_METHOD: sgd 62 | MODEL: 63 | NUM_CLASSES: 80 64 | ARCH: slowfast 65 | MODEL_NAME: BankContext 66 | LOSS_FUNC: bce 67 | DROPOUT_RATE: 0.5 68 | HEAD_ACT: sigmoid 69 | TEST: 70 | ENABLE: False 71 | DATASET: ava 72 | BATCH_SIZE: 8 73 | CHECKPOINT_FILE_PATH: "" 74 | DATA_LOADER: 75 | NUM_WORKERS: 2 76 | PIN_MEMORY: True 77 | NUM_GPUS: 8 78 | NUM_SHARDS: 1 79 | RNG_SEED: 0 80 | OUTPUT_DIR: "output/raw_bank" 81 | CACHE: 82 | ENABLE: True 83 | LOG_MODEL_INFO: False 84 | -------------------------------------------------------------------------------- /configs/SLOWFAST_32x2_R101_LFB.yaml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | ENABLE: True 3 | DATASET: ava 4 | BATCH_SIZE: 64 5 | EVAL_PERIOD: 20 6 | CHECKPOINT_PERIOD: 1 7 | AUTO_RESUME: True 8 | CHECKPOINT_FILE_PATH: "backbones/SLOWFAST_32x2_R101_Kinetics.pkl" 9 | CHECKPOINT_TYPE: caffe2 10 | DATA: 11 | NUM_FRAMES: 32 12 | SAMPLING_RATE: 2 13 | TRAIN_JITTER_SCALES: [256, 320] 14 | TRAIN_CROP_SIZE: 224 15 | TEST_CROP_SIZE: 256 16 | INPUT_CHANNEL_NUM: [3, 3] 17 | DETECTION: 18 | ENABLE: True 19 | ALIGNED: True 20 | AVA: 21 | DETECTION_SCORE_THRESH: 0.9 22 | TRAIN_PREDICT_BOX_LISTS: [ 23 | "ava_train_predicted_boxes.csv", 24 | ] 25 | TEST_PREDICT_BOX_LISTS: [ 26 | "ava_val_predicted_boxes.csv" 27 | ] 28 | FEATURE_BANK_PATH: "feature_bank/feature_bank_res101" 29 | SLIDING_WINDOW_SIZE: 25 30 | GATHER_BANK: True 31 | FEATURE_BANK_DIM: 2304 32 | GROUNDTRUTH_FILE: "ava_val_v2.2.csv" 33 | SLOWFAST: 34 | ALPHA: 4 35 | BETA_INV: 8 36 | FUSION_CONV_CHANNEL_RATIO: 2 37 | FUSION_KERNEL_SZ: 5 38 | RESNET: 39 | ZERO_INIT_FINAL_BN: True 40 | WIDTH_PER_GROUP: 64 41 | NUM_GROUPS: 1 42 | DEPTH: 101 43 | TRANS_FUNC: bottleneck_transform 44 | STRIDE_1X1: False 45 | NUM_BLOCK_TEMP_KERNEL: [[3, 3], [4, 4], [6, 6], [3, 3]] 46 | SPATIAL_DILATIONS: [[1, 1], [1, 1], [1, 1], [2, 2]] 47 | SPATIAL_STRIDES: [[1, 1], [2, 2], [2, 2], [1, 1]] 48 | NONLOCAL: 49 | LOCATION: [[[], []], [[], []], [[6, 13, 20], []], [[], []]] 50 | GROUP: [[1, 1], [1, 1], [1, 1], [1, 1]] 51 | INSTANTIATION: dot_product 52 | POOL: [[[2, 2, 2], [2, 2, 2]], [[2, 2, 2], [2, 2, 2]], [[2, 2, 2], [2, 2, 2]], [[2, 2, 2], [2, 2, 2]]] 53 | BN: 54 | USE_PRECISE_STATS: False 55 | NUM_BATCHES_PRECISE: 200 56 | SOLVER: 57 | MOMENTUM: 0.9 58 | WEIGHT_DECAY: 1e-7 59 | LR_POLICY: steps_with_relative_lrs 60 | BASE_LR: 0.1 61 | MAX_EPOCH: 20 62 | WARMUP_EPOCHS: 5.0 63 | LRS: [1.0, 0.1, 0.01, 0.001] 64 | STEPS: [0, 10, 15, 20] 65 | MODEL: 66 | NUM_CLASSES: 80 67 | ARCH: slowfast 68 | MODEL_NAME: SlowFast 69 | LOSS_FUNC: bce 70 | DROPOUT_RATE: 0.5 71 | HEAD_ACT: sigmoid 72 | TEST: 73 | ENABLE: False 74 | DATASET: ava 75 | BATCH_SIZE: 8 76 | CHECKPOINT_FILE_PATH: "" 77 | DATA_LOADER: 78 | NUM_WORKERS: 2 79 | PIN_MEMORY: True 80 | NUM_GPUS: 8 81 | NUM_SHARDS: 1 82 | RNG_SEED: 0 83 | OUTPUT_DIR: "output/val_101_pair" 84 | LSTC: 85 | ENABLE: True 86 | NUM_READERS: 2 87 | NUM_PAIRS: 2 88 | LOG_MODEL_INFO: False -------------------------------------------------------------------------------- /configs/SLOWFAST_32x2_R50_LFB.yaml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | ENABLE: True 3 | DATASET: ava 4 | BATCH_SIZE: 64 5 | EVAL_PERIOD: 20 6 | CHECKPOINT_PERIOD: 1 7 | AUTO_RESUME: True 8 | CHECKPOINT_FILE_PATH: "backbones/SLOWFAST_8x8_R50_KINETICS.pkl" 9 | CHECKPOINT_TYPE: caffe2 10 | DATA: 11 | NUM_FRAMES: 32 12 | SAMPLING_RATE: 2 13 | TRAIN_JITTER_SCALES: 
[256, 320] 14 | TRAIN_CROP_SIZE: 224 15 | TEST_CROP_SIZE: 256 16 | INPUT_CHANNEL_NUM: [3, 3] 17 | DETECTION: 18 | ENABLE: True 19 | ALIGNED: True 20 | AVA: 21 | DETECTION_SCORE_THRESH: 0.9 22 | TRAIN_PREDICT_BOX_LISTS: [ 23 | "ava_train_predicted_boxes.csv", 24 | ] 25 | TEST_PREDICT_BOX_LISTS: [ 26 | "ava_val_predicted_boxes.csv" 27 | ] 28 | FEATURE_BANK_PATH: "feature_bank/feature_bank_res50" 29 | SLIDING_WINDOW_SIZE: 25 30 | GATHER_BANK: True 31 | FEATURE_BANK_DIM: 1280 32 | GROUNDTRUTH_FILE: "ava_val_v2.2.csv" 33 | SLOWFAST: 34 | ALPHA: 4 35 | BETA_INV: 8 36 | FUSION_CONV_CHANNEL_RATIO: 2 37 | FUSION_KERNEL_SZ: 7 38 | RESNET: 39 | ZERO_INIT_FINAL_BN: True 40 | WIDTH_PER_GROUP: 64 41 | NUM_GROUPS: 1 42 | DEPTH: 50 43 | TRANS_FUNC: bottleneck_transform 44 | STRIDE_1X1: False 45 | NUM_BLOCK_TEMP_KERNEL: [[3, 3], [4, 4], [6, 6], [3, 3]] 46 | SPATIAL_DILATIONS: [[1, 1], [1, 1], [1, 1], [2, 2]] 47 | SPATIAL_STRIDES: [[1, 1], [2, 2], [2, 2], [1, 1]] 48 | NONLOCAL: 49 | LOCATION: [[[], []], [[], []], [[], []], [[], []]] 50 | GROUP: [[1, 1], [1, 1], [1, 1], [1, 1]] 51 | INSTANTIATION: dot_product 52 | POOL: [[[1, 2, 2], [1, 2, 2]], [[1, 2, 2], [1, 2, 2]], [[1, 2, 2], [1, 2, 2]], [[1, 2, 2], [1, 2, 2]]] 53 | BN: 54 | USE_PRECISE_STATS: False 55 | FREEZE: False 56 | NUM_BATCHES_PRECISE: 200 57 | SOLVER: 58 | BASE_LR: 0.1 59 | LR_POLICY: steps_with_relative_lrs 60 | STEPS: [0, 10, 15, 20] 61 | LRS: [1, 0.1, 0.01, 0.001] 62 | MAX_EPOCH: 16 63 | MOMENTUM: 0.9 64 | WEIGHT_DECAY: 1e-7 65 | WARMUP_EPOCHS: 5.0 66 | WARMUP_START_LR: 0.000125 67 | OPTIMIZING_METHOD: sgd 68 | MODEL: 69 | NUM_CLASSES: 80 70 | ARCH: slowfast 71 | MODEL_NAME: SlowFast 72 | LOSS_FUNC: bce 73 | DROPOUT_RATE: 0.5 74 | HEAD_ACT: sigmoid 75 | TEST: 76 | ENABLE: False 77 | DATASET: ava 78 | BATCH_SIZE: 8 79 | CHECKPOINT_FILE_PATH: "" 80 | DATA_LOADER: 81 | NUM_WORKERS: 2 82 | PIN_MEMORY: True 83 | NUM_GPUS: 8 84 | NUM_SHARDS: 1 85 | RNG_SEED: 0 86 | OUTPUT_DIR: "output/val_pair" 87 | LSTC: 88 | ENABLE: True 89 | NUM_READERS: 2 90 | NUM_PAIRS: 2 91 | LOG_MODEL_INFO: False 92 | -------------------------------------------------------------------------------- /slowfast/models/bank_model_builder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
3 | 4 | """Video models.""" 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from slowfast.utils.logging import get_logger 10 | from .build import MODEL_REGISTRY 11 | 12 | 13 | logger = get_logger(__name__) 14 | 15 | 16 | @MODEL_REGISTRY.register() 17 | class BankContext(nn.Module): 18 | 19 | act = { 20 | 'sigmoid': nn.Sigmoid, 21 | 'softmax': nn.Softmax 22 | } 23 | 24 | def __init__(self, cfg): 25 | super(BankContext, self).__init__() 26 | self.cfg = cfg 27 | self.window_size = cfg.AVA.SLIDING_WINDOW_SIZE * 2 + 1 28 | self.feature_size = sum(cfg.SLOWFAST.OUTPUT_CHANNEL) 29 | self.classifier = nn.Linear(self.feature_size, cfg.MODEL.NUM_CLASSES) 30 | self.dropout = nn.Dropout(cfg.MODEL.DROPOUT_RATE) 31 | self.act_func = self.act[cfg.MODEL.HEAD_ACT]() 32 | 33 | self._build_aggregators() 34 | 35 | def _build_aggregators(self, ratio = 2): 36 | 37 | inter_size = self.feature_size // ratio 38 | self.norm1 = nn.LayerNorm(self.feature_size) 39 | self.ffn1 = nn.Sequential( 40 | nn.Linear(self.feature_size, inter_size), 41 | nn.ReLU(), 42 | nn.Linear(inter_size, self.feature_size) 43 | ) 44 | 45 | self.norm2 = nn.LayerNorm(self.feature_size) 46 | self.ffn2 = nn.Sequential( 47 | nn.Linear(self.feature_size, inter_size), 48 | nn.ReLU(), 49 | nn.Linear(inter_size, self.feature_size) 50 | ) 51 | 52 | def forward(self, FBs): 53 | """ 54 | aggregate context information from banks 55 | Args: 56 | FBs: list[torch.Tensor] 57 | 58 | Returns: 59 | torch.Tensor 60 | """ 61 | num_batch = len(FBs) 62 | output = [] 63 | for b in range(num_batch): 64 | feature_bank = FBs[b] 65 | clip_feat = [torch.mean(val, dim=0) 66 | for k, val in enumerate(feature_bank) 67 | if val is not None and k != self.window_size // 2] 68 | 69 | clip_feat = torch.stack(clip_feat, dim=0) 70 | clip_feat = self.ffn1(self.norm1(clip_feat)) 71 | 72 | feat = torch.mean(clip_feat, dim=0).unsqueeze(0) 73 | output.append(self.ffn2(self.norm2(feat))) 74 | 75 | x = torch.cat(output, dim=0) 76 | x = self.dropout(x) 77 | 78 | return self.act_func(self.classifier(x)) 79 | -------------------------------------------------------------------------------- /slowfast/utils/metrics.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """Functions for computing metrics.""" 5 | 6 | import torch 7 | 8 | 9 | def topks_correct(preds, labels, ks): 10 | """ 11 | Given the predictions, labels, and a list of top-k values, compute the 12 | number of correct predictions for each top-k value. 13 | 14 | Args: 15 | preds (array): array of predictions. Dimension is batchsize 16 | N x ClassNum. 17 | labels (array): array of labels. Dimension is batchsize N. 18 | ks (list): list of top-k values. For example, ks = [1, 5] correspods 19 | to top-1 and top-5. 20 | 21 | Returns: 22 | topks_correct (list): list of numbers, where the `i`-th entry 23 | corresponds to the number of top-`ks[i]` correct predictions. 24 | """ 25 | assert preds.size(0) == labels.size( 26 | 0 27 | ), "Batch dim of predictions and labels must match" 28 | # Find the top max_k predictions for each sample 29 | _top_max_k_vals, top_max_k_inds = torch.topk( 30 | preds, max(ks), dim=1, largest=True, sorted=True 31 | ) 32 | # (batch_size, max_k) -> (max_k, batch_size). 33 | top_max_k_inds = top_max_k_inds.t() 34 | # (batch_size, ) -> (max_k, batch_size). 
35 | rep_max_k_labels = labels.view(1, -1).expand_as(top_max_k_inds) 36 | # (i, j) = 1 if top i-th prediction for the j-th sample is correct. 37 | top_max_k_correct = top_max_k_inds.eq(rep_max_k_labels) 38 | # Compute the number of topk correct predictions for each k. 39 | topks_correct = [ 40 | top_max_k_correct[:k, :].view(-1).float().sum() for k in ks 41 | ] 42 | return topks_correct 43 | 44 | 45 | def topk_errors(preds, labels, ks): 46 | """ 47 | Computes the top-k error for each k. 48 | Args: 49 | preds (array): array of predictions. Dimension is N. 50 | labels (array): array of labels. Dimension is N. 51 | ks (list): list of ks to calculate the top accuracies. 52 | """ 53 | num_topks_correct = topks_correct(preds, labels, ks) 54 | return [(1.0 - x / preds.size(0)) * 100.0 for x in num_topks_correct] 55 | 56 | 57 | def topk_accuracies(preds, labels, ks): 58 | """ 59 | Computes the top-k accuracy for each k. 60 | Args: 61 | preds (array): array of predictions. Dimension is N. 62 | labels (array): array of labels. Dimension is N. 63 | ks (list): list of ks to calculate the top accuracies. 64 | """ 65 | num_topks_correct = topks_correct(preds, labels, ks) 66 | return [(x / preds.size(0)) * 100.0 for x in num_topks_correct] 67 | -------------------------------------------------------------------------------- /slowfast/datasets/multigrid_helper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """Helper functions for multigrid training.""" 5 | 6 | import numpy as np 7 | from torch._six import int_classes as _int_classes 8 | from torch.utils.data.sampler import Sampler 9 | 10 | 11 | class ShortCycleBatchSampler(Sampler): 12 | """ 13 | Extend Sampler to support "short cycle" sampling. 14 | See paper "A Multigrid Method for Efficiently Training Video Models", 15 | Wu et al., 2019 (https://arxiv.org/abs/1912.00998) for details. 
16 | """ 17 | 18 | def __init__(self, sampler, batch_size, drop_last, cfg): 19 | if not isinstance(sampler, Sampler): 20 | raise ValueError( 21 | "sampler should be an instance of " 22 | "torch.utils.data.Sampler, but got sampler={}".format(sampler) 23 | ) 24 | if ( 25 | not isinstance(batch_size, _int_classes) 26 | or isinstance(batch_size, bool) 27 | or batch_size <= 0 28 | ): 29 | raise ValueError( 30 | "batch_size should be a positive integer value, " 31 | "but got batch_size={}".format(batch_size) 32 | ) 33 | if not isinstance(drop_last, bool): 34 | raise ValueError( 35 | "drop_last should be a boolean value, but got " 36 | "drop_last={}".format(drop_last) 37 | ) 38 | self.sampler = sampler 39 | self.drop_last = drop_last 40 | 41 | bs_factor = [ 42 | int( 43 | round( 44 | ( 45 | float(cfg.DATA.TRAIN_CROP_SIZE) 46 | / (s * cfg.MULTIGRID.DEFAULT_S) 47 | ) 48 | ** 2 49 | ) 50 | ) 51 | for s in cfg.MULTIGRID.SHORT_CYCLE_FACTORS 52 | ] 53 | 54 | self.batch_sizes = [ 55 | batch_size * bs_factor[0], 56 | batch_size * bs_factor[1], 57 | batch_size, 58 | ] 59 | 60 | def __iter__(self): 61 | counter = 0 62 | batch_size = self.batch_sizes[0] 63 | batch = [] 64 | for idx in self.sampler: 65 | batch.append((idx, counter % 3)) 66 | if len(batch) == batch_size: 67 | yield batch 68 | counter += 1 69 | batch_size = self.batch_sizes[counter % 3] 70 | batch = [] 71 | if len(batch) > 0 and not self.drop_last: 72 | yield batch 73 | 74 | def __len__(self): 75 | avg_batch_size = sum(self.batch_sizes) / 3.0 76 | if self.drop_last: 77 | return int(np.floor(len(self.sampler) / avg_batch_size)) 78 | else: 79 | return int(np.ceil(len(self.sampler) / avg_batch_size)) 80 | -------------------------------------------------------------------------------- /slowfast/utils/ava_evaluation/np_box_mask_list.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Numpy BoxMaskList classes and functions.""" 17 | 18 | from __future__ import ( 19 | absolute_import, 20 | division, 21 | print_function, 22 | unicode_literals, 23 | ) 24 | import numpy as np 25 | 26 | from . import np_box_list 27 | 28 | 29 | class BoxMaskList(np_box_list.BoxList): 30 | """Convenience wrapper for BoxList with masks. 31 | 32 | BoxMaskList extends the np_box_list.BoxList to contain masks as well. 33 | In particular, its constructor receives both boxes and masks. Note that the 34 | masks correspond to the full image. 35 | """ 36 | 37 | def __init__(self, box_data, mask_data): 38 | """Constructs box collection. 39 | 40 | Args: 41 | box_data: a numpy array of shape [N, 4] representing box coordinates 42 | mask_data: a numpy array of shape [N, height, width] representing masks 43 | with values are in {0,1}. The masks correspond to the full 44 | image. 
The height and the width will be equal to image height and width. 45 | 46 | Raises: 47 | ValueError: if bbox data is not a numpy array 48 | ValueError: if invalid dimensions for bbox data 49 | ValueError: if mask data is not a numpy array 50 | ValueError: if invalid dimension for mask data 51 | """ 52 | super(BoxMaskList, self).__init__(box_data) 53 | if not isinstance(mask_data, np.ndarray): 54 | raise ValueError("Mask data must be a numpy array.") 55 | if len(mask_data.shape) != 3: 56 | raise ValueError("Invalid dimensions for mask data.") 57 | if mask_data.dtype != np.uint8: 58 | raise ValueError( 59 | "Invalid data type for mask data: uint8 is required." 60 | ) 61 | if mask_data.shape[0] != box_data.shape[0]: 62 | raise ValueError( 63 | "There should be the same number of boxes and masks." 64 | ) 65 | self.data["masks"] = mask_data 66 | 67 | def get_masks(self): 68 | """Convenience function for accessing masks. 69 | 70 | Returns: 71 | a numpy array of shape [N, height, width] representing masks 72 | """ 73 | return self.get_field("masks") 74 | -------------------------------------------------------------------------------- /slowfast/utils/logging.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """Logging.""" 5 | 6 | import builtins 7 | import decimal 8 | import functools 9 | import logging 10 | import os 11 | import sys 12 | import simplejson 13 | from fvcore.common.file_io import PathManager 14 | 15 | import slowfast.utils.distributed as du 16 | 17 | def _suppress_print(): 18 | """ 19 | Suppresses printing from the current process. 20 | """ 21 | 22 | def print_pass(*objects, sep=" ", end="\n", file=sys.stdout, flush=False): 23 | pass 24 | 25 | builtins.print = print_pass 26 | 27 | @functools.lru_cache(maxsize=None) 28 | def _cached_log_stream(filename): 29 | return PathManager.open(filename, "a") 30 | 31 | 32 | def setup_logging(output_dir=None): 33 | """ 34 | Sets up the logging for multiple processes. Only enable the logging for the 35 | master process, and suppress logging for the non-master processes. 36 | """ 37 | # Set up logging format. 38 | _FORMAT = "[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s" 39 | 40 | if du.is_master_proc(): 41 | # Enable logging for the master process. 42 | logging.root.handlers = [] 43 | logging.basicConfig( 44 | level=logging.INFO, format=_FORMAT, stream=sys.stdout 45 | ) 46 | else: 47 | # Suppress logging for non-master processes. 48 | _suppress_print() 49 | 50 | # setup root logger 51 | logger = logging.getLogger() 52 | logger.setLevel(logging.DEBUG) 53 | logger.propagate = False 54 | plain_formatter = logging.Formatter( 55 | "[%(asctime)s][%(levelname)s] %(name)s: %(lineno)4d: %(message)s", 56 | datefmt="%m/%d %H:%M:%S", 57 | ) 58 | 59 | if du.is_master_proc(): 60 | ch = logging.StreamHandler(stream=sys.stdout) 61 | ch.setLevel(logging.DEBUG) 62 | ch.setFormatter(plain_formatter) 63 | logger.addHandler(ch) 64 | 65 | if output_dir is not None and du.is_master_proc(du.get_world_size()): 66 | filename = os.path.join(output_dir, "stdout.log") 67 | fh = logging.StreamHandler(_cached_log_stream(filename)) 68 | fh.setLevel(logging.DEBUG) 69 | fh.setFormatter(plain_formatter) 70 | logger.addHandler(fh) 71 | 72 | 73 | def get_logger(name): 74 | """ 75 | Retrieve the logger with the specified name or, if name is None, return a 76 | logger which is the root logger of the hierarchy. 
77 | Args: 78 | name (string): name of the logger. 79 | """ 80 | return logging.getLogger(name) 81 | 82 | 83 | def log_json_stats(stats): 84 | """ 85 | Logs json stats. 86 | Args: 87 | stats (dict): a dictionary of statistical information to log. 88 | """ 89 | stats = { 90 | k: decimal.Decimal("{:.6f}".format(v)) if isinstance(v, float) else v 91 | for k, v in stats.items() 92 | } 93 | json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True) 94 | logger = get_logger(__name__) 95 | logger.info("json_stats: {:s}".format(json_stats)) 96 | -------------------------------------------------------------------------------- /slowfast/utils/bn_helper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """bn helper.""" 5 | 6 | import itertools 7 | import torch 8 | 9 | 10 | @torch.no_grad() 11 | def compute_and_update_bn_stats(model, data_loader, num_batches=200): 12 | """ 13 | Compute and update the batch norm stats to make it more precise. During 14 | training both bn stats and the weight are changing after every iteration, 15 | so the bn can not precisely reflect the latest stats of the current model. 16 | Here the bn stats is recomputed without change of weights, to make the 17 | running mean and running var more precise. 18 | Args: 19 | model (model): the model using to compute and update the bn stats. 20 | data_loader (dataloader): dataloader using to provide inputs. 21 | num_batches (int): running iterations using to compute the stats. 22 | """ 23 | 24 | # Prepares all the bn layers. 25 | bn_layers = [ 26 | m 27 | for m in model.modules() 28 | if any( 29 | ( 30 | isinstance(m, bn_type) 31 | for bn_type in ( 32 | torch.nn.BatchNorm1d, 33 | torch.nn.BatchNorm2d, 34 | torch.nn.BatchNorm3d, 35 | ) 36 | ) 37 | ) 38 | ] 39 | 40 | # In order to make the running stats only reflect the current batch, the 41 | # momentum is disabled. 42 | # bn.running_mean = (1 - momentum) * bn.running_mean + momentum * batch_mean 43 | # Setting the momentum to 1.0 to compute the stats without momentum. 44 | momentum_actual = [bn.momentum for bn in bn_layers] 45 | for bn in bn_layers: 46 | bn.momentum = 1.0 47 | 48 | # Calculates the running iterations for precise stats computation. 49 | running_mean = [torch.zeros_like(bn.running_mean) for bn in bn_layers] 50 | running_square_mean = [torch.zeros_like(bn.running_var) for bn in bn_layers] 51 | 52 | for ind, (inputs, _, _) in enumerate( 53 | itertools.islice(data_loader, num_batches) 54 | ): 55 | # Forwards the model to update the bn stats. 56 | if isinstance(inputs, (list,)): 57 | for i in range(len(inputs)): 58 | inputs[i] = inputs[i].float().cuda(non_blocking=True) 59 | else: 60 | inputs = inputs.cuda(non_blocking=True) 61 | model(inputs) 62 | 63 | for i, bn in enumerate(bn_layers): 64 | # Accumulates the bn stats. 65 | running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1) 66 | # $E(x^2) = Var(x) + E(x)^2$. 67 | cur_square_mean = bn.running_var + bn.running_mean ** 2 68 | running_square_mean[i] += ( 69 | cur_square_mean - running_square_mean[i] 70 | ) / (ind + 1) 71 | 72 | for i, bn in enumerate(bn_layers): 73 | bn.running_mean = running_mean[i] 74 | # Var(x) = $E(x^2) - E(x)^2$. 75 | bn.running_var = running_square_mean[i] - bn.running_mean ** 2 76 | # Sets the precise bn stats. 
77 | bn.momentum = momentum_actual[i] 78 | -------------------------------------------------------------------------------- /slowfast/utils/lr_policy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """Learning rate policy.""" 5 | 6 | import math 7 | 8 | 9 | def get_lr_at_epoch(cfg, cur_epoch): 10 | """ 11 | Retrieve the learning rate of the current epoch with the option to perform 12 | warm up in the beginning of the training stage. 13 | Args: 14 | cfg (CfgNode): configs. Details can be found in 15 | slowfast/config/defaults.py 16 | cur_epoch (float): the number of epoch of the current training stage. 17 | """ 18 | lr = get_lr_func(cfg.SOLVER.LR_POLICY)(cfg, cur_epoch) 19 | # Perform warm up. 20 | if cur_epoch < cfg.SOLVER.WARMUP_EPOCHS: 21 | lr_start = cfg.SOLVER.WARMUP_START_LR 22 | lr_end = get_lr_func(cfg.SOLVER.LR_POLICY)( 23 | cfg, cfg.SOLVER.WARMUP_EPOCHS 24 | ) 25 | alpha = (lr_end - lr_start) / cfg.SOLVER.WARMUP_EPOCHS 26 | lr = cur_epoch * alpha + lr_start 27 | return lr 28 | 29 | 30 | def lr_func_cosine(cfg, cur_epoch): 31 | """ 32 | Retrieve the learning rate to specified values at specified epoch with the 33 | cosine learning rate schedule. Details can be found in: 34 | Ilya Loshchilov, and Frank Hutter 35 | SGDR: Stochastic Gradient Descent With Warm Restarts. 36 | Args: 37 | cfg (CfgNode): configs. Details can be found in 38 | slowfast/config/defaults.py 39 | cur_epoch (float): the number of epoch of the current training stage. 40 | """ 41 | return ( 42 | cfg.SOLVER.BASE_LR 43 | * (math.cos(math.pi * cur_epoch / cfg.SOLVER.MAX_EPOCH) + 1.0) 44 | * 0.5 45 | ) 46 | 47 | 48 | def lr_func_steps_with_relative_lrs(cfg, cur_epoch): 49 | """ 50 | Retrieve the learning rate to specified values at specified epoch with the 51 | steps with relative learning rate schedule. 52 | Args: 53 | cfg (CfgNode): configs. Details can be found in 54 | slowfast/config/defaults.py 55 | cur_epoch (float): the number of epoch of the current training stage. 56 | """ 57 | ind = get_step_index(cfg, cur_epoch) 58 | return cfg.SOLVER.LRS[ind] * cfg.SOLVER.BASE_LR 59 | 60 | 61 | def get_step_index(cfg, cur_epoch): 62 | """ 63 | Retrieves the lr step index for the given epoch. 64 | Args: 65 | cfg (CfgNode): configs. Details can be found in 66 | slowfast/config/defaults.py 67 | cur_epoch (float): the number of epoch of the current training stage. 68 | """ 69 | steps = cfg.SOLVER.STEPS + [cfg.SOLVER.MAX_EPOCH] 70 | for ind, step in enumerate(steps): # NoQA 71 | if cur_epoch < step: 72 | break 73 | return ind - 1 74 | 75 | 76 | def get_lr_func(lr_policy): 77 | """ 78 | Given the configs, retrieve the specified lr policy function. 79 | Args: 80 | lr_policy (string): the learning rate policy to use for the job. 81 | """ 82 | policy = "lr_func_" + lr_policy 83 | if policy not in globals(): 84 | raise NotImplementedError("Unknown LR policy: {}".format(lr_policy)) 85 | else: 86 | return globals()[policy] 87 | -------------------------------------------------------------------------------- /slowfast/models/optimizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
3 | 4 | """Optimizer.""" 5 | 6 | import torch 7 | 8 | import slowfast.utils.lr_policy as lr_policy 9 | 10 | 11 | def construct_optimizer(model, cfg): 12 | """ 13 | Construct a stochastic gradient descent or ADAM optimizer with momentum. 14 | Details can be found in: 15 | Herbert Robbins, and Sutton Monro. "A stochastic approximation method." 16 | and 17 | Diederik P.Kingma, and Jimmy Ba. 18 | "Adam: A Method for Stochastic Optimization." 19 | 20 | Args: 21 | model (model): model to perform stochastic gradient descent 22 | optimization or ADAM optimization. 23 | cfg (config): configs of hyper-parameters of SGD or ADAM, includes base 24 | learning rate, momentum, weight_decay, dampening, and etc. 25 | """ 26 | # Batchnorm parameters. 27 | bn_params = [] 28 | # Non-batchnorm parameters. 29 | non_bn_parameters = [] 30 | for name, p in model.named_parameters(): 31 | if "bn" in name: 32 | bn_params.append(p) 33 | else: 34 | non_bn_parameters.append(p) 35 | # Apply different weight decay to Batchnorm and non-batchnorm parameters. 36 | # In Caffe2 classification codebase the weight decay for batchnorm is 0.0. 37 | # Having a different weight decay on batchnorm might cause a performance 38 | # drop. 39 | optim_params = [ 40 | {"params": bn_params, "weight_decay": cfg.BN.WEIGHT_DECAY}, 41 | {"params": non_bn_parameters, "weight_decay": cfg.SOLVER.WEIGHT_DECAY}, 42 | ] 43 | # Check all parameters will be passed into optimizer. 44 | assert len(list(model.parameters())) == len(non_bn_parameters) + len( 45 | bn_params 46 | ), "parameter size does not match: {} + {} != {}".format( 47 | len(non_bn_parameters), len(bn_params), len(list(model.parameters())) 48 | ) 49 | 50 | if cfg.SOLVER.OPTIMIZING_METHOD == "sgd": 51 | return torch.optim.SGD( 52 | optim_params, 53 | lr=cfg.SOLVER.BASE_LR, 54 | momentum=cfg.SOLVER.MOMENTUM, 55 | weight_decay=cfg.SOLVER.WEIGHT_DECAY, 56 | dampening=cfg.SOLVER.DAMPENING, 57 | nesterov=cfg.SOLVER.NESTEROV, 58 | ) 59 | elif cfg.SOLVER.OPTIMIZING_METHOD == "adam": 60 | return torch.optim.Adam( 61 | optim_params, 62 | lr=cfg.SOLVER.BASE_LR, 63 | betas=(0.9, 0.999), 64 | weight_decay=cfg.SOLVER.WEIGHT_DECAY, 65 | ) 66 | else: 67 | raise NotImplementedError( 68 | "Does not support {} optimizer".format(cfg.SOLVER.OPTIMIZING_METHOD) 69 | ) 70 | 71 | 72 | def get_epoch_lr(cur_epoch, cfg): 73 | """ 74 | Retrieves the lr for the given epoch (as specified by the lr policy). 75 | Args: 76 | cfg (config): configs of hyper-parameters of ADAM, includes base 77 | learning rate, betas, and weight decays. 78 | cur_epoch (float): the number of epoch of the current training stage. 79 | """ 80 | return lr_policy.get_lr_at_epoch(cfg, cur_epoch) 81 | 82 | 83 | def set_lr(optimizer, new_lr): 84 | """ 85 | Sets the optimizer lr to the specified value. 86 | Args: 87 | optimizer (optim): the optimizer using to optimize the current network. 88 | new_lr (float): the new learning rate to set. 89 | """ 90 | for param_group in optimizer.param_groups: 91 | param_group["lr"] = new_lr 92 | -------------------------------------------------------------------------------- /tools/demo_net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
3 | 4 | import numpy as np 5 | import cv2 6 | import torch 7 | import tqdm 8 | 9 | from slowfast.utils import logging 10 | from slowfast.visualization.demo_loader import VideoReader 11 | from slowfast.visualization.ava_demo_precomputed_boxes import AVAVisualizerWithPrecomputedBox 12 | from slowfast.visualization.predictor import ( 13 | ActionPredictor, 14 | Detectron2Predictor, 15 | draw_predictions, 16 | ) 17 | from slowfast.visualization.utils import init_task_info 18 | from slowfast.visualization.video_visualizer import VideoVisualizer 19 | 20 | logger = logging.get_logger(__name__) 21 | 22 | 23 | def run_demo(cfg, frame_provider): 24 | """ 25 | Run demo visualization. 26 | Args: 27 | cfg (CfgNode): configs. Details can be found in 28 | slowfast/config/defaults.py 29 | frame_provider (iterator): Python iterator that return task objects that are filled 30 | with necessary information such as `frames`, `id` and `num_buffer_frames` for the 31 | prediction and visualization pipeline. 32 | """ 33 | # Set random seed from configs. 34 | np.random.seed(cfg.RNG_SEED) 35 | torch.manual_seed(cfg.RNG_SEED) 36 | # Setup logging format. 37 | logging.setup_logging(cfg.OUTPUT_DIR) 38 | # Print config. 39 | logger.info("Run demo with config:") 40 | logger.info(cfg) 41 | assert cfg.NUM_GPUS <= 1, "Cannot run demo on multiple GPUs." 42 | # Print config. 43 | logger.info("Run demo with config:") 44 | logger.info(cfg) 45 | video_vis = VideoVisualizer( 46 | cfg.MODEL.NUM_CLASSES, 47 | cfg.DEMO.LABEL_FILE_PATH, 48 | cfg.TENSORBOARD.MODEL_VIS.TOPK_PREDS, 49 | cfg.TENSORBOARD.MODEL_VIS.COLORMAP, 50 | ) 51 | 52 | if cfg.DETECTION.ENABLE: 53 | object_detector = Detectron2Predictor(cfg) 54 | 55 | model = ActionPredictor(cfg) 56 | 57 | seq_len = cfg.DATA.NUM_FRAMES * cfg.DATA.SAMPLING_RATE 58 | assert ( 59 | cfg.DEMO.BUFFER_SIZE <= seq_len // 2 60 | ), "Buffer size cannot be greater than half of sequence length." 61 | init_task_info( 62 | frame_provider.display_height, 63 | frame_provider.display_width, 64 | cfg.DATA.TEST_CROP_SIZE, 65 | cfg.DEMO.CLIP_VIS_SIZE, 66 | ) 67 | for able_to_read, task in frame_provider: 68 | if not able_to_read: 69 | break 70 | 71 | if cfg.DETECTION.ENABLE: 72 | task = object_detector(task) 73 | 74 | task = model(task) 75 | frames = draw_predictions(task, video_vis) 76 | # hit Esc to quit the demo. 77 | key = cv2.waitKey(1) 78 | if key == 27: 79 | break 80 | yield frames 81 | 82 | 83 | def demo(cfg): 84 | """ 85 | Run inference on an input video or stream from webcam. 86 | Args: 87 | cfg (CfgNode): configs. Details can be found in 88 | slowfast/config/defaults.py 89 | """ 90 | # AVA format-specific visualization with precomputed boxes. 
91 | if cfg.DETECTION.ENABLE and cfg.DEMO.PREDS_BOXES != "": 92 | precomputed_box_vis = AVAVisualizerWithPrecomputedBox(cfg) 93 | precomputed_box_vis() 94 | else: 95 | frame_provider = VideoReader(cfg) 96 | 97 | for frames in tqdm.tqdm(run_demo(cfg, frame_provider)): 98 | for frame in frames: 99 | frame_provider.display(frame) 100 | 101 | frame_provider.clean() 102 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LSTC: Boosting Atomic Action Detection with Long-Short-Term Context 2 | 3 | This Repository contains the code on AVA of our ACM MM 2021 paper: LSTC: Boosting Atomic Action Detection with Long-Short-Term Context 4 | 5 | ## Installation 6 | 7 | See [INSTALL.md](./INSTALL.md) for details on installing the codebase, including requirement and environment settings 8 | 9 | ## Data 10 | 11 | For data preparation and setup, our LSTC strictly follows the processing of [PySlowFast](https://github.com/facebookresearch/SlowFast/blob/master/INSTALL.md), 12 | See [DATASET.md](./DATASET.md) for details on preparing the data. 13 | 14 | ## Run the code 15 | 16 | We take SlowFast-ResNet50 as an example 17 | 18 | * train the model 19 | ```shell script 20 | python3 tools/run_net.py --cfg config/AVA/SLOWFAST_32x12_R50_LFB.yaml \ 21 | AVA.FEATURE_BANK_PATH 'path/to/feature/bank/folder' \ 22 | TRAIN.CHECKPOINT_FILE_PATH 'path/to/pretrained/backbone' \ 23 | OUTPUT_DIR 'path/to/output/folder' 24 | ``` 25 | 26 | * test the model 27 | ```shell script 28 | python3 tools/run_net.py --cfg config/AVA/SLOWFAST_32x12_R50_LFB.yaml \ 29 | AVA.FEATURE_BANK_PATH 'path/to/feature/bank/folder' \ 30 | OUTPUT_DIR 'path/to/output/folder' \ 31 | TRAIN.ENABLE False \ 32 | TEST.ENABLE True 33 | ``` 34 | 35 | *If you want to start the DDP training from command line with `torch.distributed.launch`, please set `start_method='cmd'` in `tools/run_net.py`* 36 | 37 | ## Resource 38 | 39 | The codebase provide following resources for fast training and validation 40 | 41 | ### Pretrained backbone on Kinetics 42 | 43 | | backbone | dataset | model type | link | 44 | |----------|:---------------------:|:------------:|:--------------:| 45 | |ResNet50|Kinetics400|Caffe2|[Google Drive](https://drive.google.com/file/d/1zxS57DAXiLswWG-hI8s76zGdtRFNRgxa/view?usp=sharing)/[Baidu Disk](https://pan.baidu.com/s/1VaOY-GBBY9oTc2m-A-9Ogw) (Code: y1wl)| 46 | |ResNet101|Kinetics600|Caffe2|[Google Drive](https://drive.google.com/file/d/1U6i2lGo8-qdtL_UDPHHCHwmfOERJxfnK/view?usp=sharing)/[Baidu Disk](https://pan.baidu.com/s/17I-3YaAAj0I2RELaG6P-xw) (Code: slde)| 47 | 48 | ### Extracted long term feature bank 49 | 50 | | backbone | feature bank (LMDB) | dimension | 51 | |----------|:---------------------:|:------------:| 52 | |ResNet50|[Google Drive](https://drive.google.com/file/d/1IqFuq7GMSBFnHopjbNcDJAIES1EtxpQR/view?usp=sharing)|1280| 53 | |ResNet101|[Google Drive](https://drive.google.com/file/d/1ND4sSGwAv2SFR42J90Vj9cNn1glz1Ex3/view?usp=sharing)|2304| 54 | 55 | ### Checkpoint file 56 | 57 | | backbone | checkpoint | model type | 58 | |----------|:---------------------:|:-----------:| 59 | |ResNet50|[Google Drive](https://drive.google.com/file/d/1yimMvcOXaASOFOmp64HKO13LzS5b_YCj/view?usp=sharing)/[Baidu Disk](https://pan.baidu.com/s/1deRNnxgSwlAuOWHAMrzntQ) (Code: fi0s)|pytorch| 60 | |ResNet101|[Google Drive](https://drive.google.com/file/d/1BZ4MzlhUOzuvBPyaS8DAHcyGikh6TAJh/view?usp=sharing)/[Baidu 
Disk](https://pan.baidu.com/s/11LesMQk6dU7XNYw_ftsADQ) (Code: g63o)|pytorch|
61 | 
62 | ## Acknowledgement
63 | 
64 | This codebase is built upon [PySlowFast](https://github.com/facebookresearch/SlowFast).
65 | 
66 | ## Citation
67 | 
68 | If you find this repository helpful for your research, please cite the following paper:
69 | ```bibtex
70 | @InProceedings{Yuxi_2021_ACM,
71 | author = {Li, Yuxi and Zhang, Boshen and Li, Jian and Wang, Yabiao and Wang, Chengjie and Li, Jilin and Huang, Feiyue and Lin, Weiyao},
72 | title = {LSTC: Boosting Atomic Action Detection with Long-Short-Term Context},
73 | booktitle = {ACM Conference on Multimedia},
74 | month = {October},
75 | year = {2021}
76 | }
77 | ```
--------------------------------------------------------------------------------
/slowfast/utils/parser.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3 | 
4 | """Argument parser functions."""
5 | 
6 | import argparse
7 | import sys
8 | 
9 | import slowfast.utils.checkpoint as cu
10 | from slowfast.config.defaults import get_cfg
11 | 
12 | 
13 | def parse_args():
14 |     """
15 |     Parse the following arguments for a default parser for PySlowFast users.
16 |     Args:
17 |         shard_id (int): shard id for the current machine. Ranges from 0 to
18 |             num_shards - 1. If a single machine is used, set the shard id to 0.
19 |         num_shards (int): number of shards used by the job.
20 |         init_method (str): initialization method to launch the job with multiple
21 |             devices. Options include TCP or shared file-system
22 |             initialization. Details can be found in
23 |             https://pytorch.org/docs/stable/distributed.html#tcp-initialization
24 |         cfg (str): path to the config file.
25 |         opts (argument): provide additional options from the command line; they
26 |             overwrite the config loaded from file.
27 |     """
28 |     parser = argparse.ArgumentParser(
29 |         description="Provide SlowFast video training and testing pipeline."
30 |     )
31 |     parser.add_argument(
32 |         "--shard_id",
33 |         help="The shard id of the current node; ranges from 0 to num_shards - 1",
34 |         default=0,
35 |         type=int,
36 |     )
37 |     parser.add_argument(
38 |         "--num_shards",
39 |         help="Number of shards used by the job",
40 |         default=1,
41 |         type=int,
42 |     )
43 |     parser.add_argument(
44 |         "--init_method",
45 |         help="Initialization method, includes TCP or shared file-system",
46 |         default="tcp://localhost:9999",
47 |         type=str,
48 |     )
49 |     parser.add_argument(
50 |         "--cfg",
51 |         dest="cfg_file",
52 |         help="Path to the config file",
53 |         default="configs/Kinetics/SLOWFAST_4x16_R50.yaml",
54 |         type=str,
55 |     )
56 |     parser.add_argument(
57 |         "--local_rank",
58 |         help="Rank number of current device",
59 |         default=-1,
60 |         type=int
61 |     )
62 |     parser.add_argument(
63 |         "opts",
64 |         help="See slowfast/config/defaults.py for all options",
65 |         default=None,
66 |         nargs=argparse.REMAINDER,
67 |     )
68 |     if len(sys.argv) == 1:
69 |         parser.print_help()
70 |     return parser.parse_args()
71 | 
72 | 
73 | def load_config(args):
74 |     """
75 |     Given the arguments, load and initialize the configs.
76 |     Args:
77 |         args (argument): arguments include `shard_id`, `num_shards`,
78 |             `init_method`, `cfg_file`, and `opts`.
79 |     """
80 |     # Setup cfg.
81 |     cfg = get_cfg()
82 |     # Load config from cfg.
83 |     if args.cfg_file is not None:
84 |         cfg.merge_from_file(args.cfg_file)
85 |     # Load config from command line, overwrite config from opts.
86 | if args.opts is not None: 87 | cfg.merge_from_list(args.opts) 88 | 89 | # Inherit parameters from args. 90 | if hasattr(args, "num_shards") and hasattr(args, "shard_id"): 91 | cfg.NUM_SHARDS = args.num_shards 92 | cfg.SHARD_ID = args.shard_id 93 | if hasattr(args, "rng_seed"): 94 | cfg.RNG_SEED = args.rng_seed 95 | if hasattr(args, "output_dir"): 96 | cfg.OUTPUT_DIR = args.output_dir 97 | 98 | # Create the checkpoint dir. 99 | cu.make_checkpoint_dir(cfg.OUTPUT_DIR) 100 | return cfg 101 | -------------------------------------------------------------------------------- /slowfast/utils/benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Functions for benchmarks. 4 | """ 5 | 6 | import numpy as np 7 | import pprint 8 | import torch 9 | import tqdm 10 | from fvcore.common.timer import Timer 11 | 12 | import slowfast.utils.logging as logging 13 | import slowfast.utils.misc as misc 14 | from slowfast.datasets import loader 15 | from slowfast.utils.env import setup_environment 16 | 17 | logger = logging.get_logger(__name__) 18 | 19 | 20 | def benchmark_data_loading(cfg): 21 | """ 22 | Benchmark the speed of data loading in PySlowFast. 23 | Args: 24 | 25 | cfg (CfgNode): configs. Details can be found in 26 | slowfast/config/defaults.py 27 | """ 28 | # Set up environment. 29 | setup_environment() 30 | # Set random seed from configs. 31 | np.random.seed(cfg.RNG_SEED) 32 | torch.manual_seed(cfg.RNG_SEED) 33 | 34 | # Setup logging format. 35 | logging.setup_logging(cfg.OUTPUT_DIR) 36 | 37 | # Print config. 38 | logger.info("Benchmark data loading with config:") 39 | logger.info(pprint.pformat(cfg)) 40 | 41 | timer = Timer() 42 | dataloader = loader.construct_loader(cfg, "train") 43 | logger.info( 44 | "Initialize loader using {:.2f} seconds.".format(timer.seconds()) 45 | ) 46 | # Total batch size across different machines. 47 | batch_size = cfg.TRAIN.BATCH_SIZE * cfg.NUM_SHARDS 48 | log_period = cfg.BENCHMARK.LOG_PERIOD 49 | epoch_times = [] 50 | # Test for a few epochs. 51 | for cur_epoch in range(cfg.BENCHMARK.NUM_EPOCHS): 52 | timer = Timer() 53 | timer_epoch = Timer() 54 | iter_times = [] 55 | if cfg.BENCHMARK.SHUFFLE: 56 | loader.shuffle_dataset(dataloader, cur_epoch) 57 | for cur_iter, _ in enumerate(tqdm.tqdm(dataloader)): 58 | if cur_iter > 0 and cur_iter % log_period == 0: 59 | iter_times.append(timer.seconds()) 60 | ram_usage, ram_total = misc.cpu_mem_usage() 61 | logger.info( 62 | "Epoch {}: {} iters ({} videos) in {:.2f} seconds. " 63 | "RAM Usage: {:.2f}/{:.2f} GB.".format( 64 | cur_epoch, 65 | log_period, 66 | log_period * batch_size, 67 | iter_times[-1], 68 | ram_usage, 69 | ram_total, 70 | ) 71 | ) 72 | timer.reset() 73 | epoch_times.append(timer_epoch.seconds()) 74 | ram_usage, ram_total = misc.cpu_mem_usage() 75 | logger.info( 76 | "Epoch {}: in total {} iters ({} videos) in {:.2f} seconds. 
" 77 | "RAM Usage: {:.2f}/{:.2f} GB.".format( 78 | cur_epoch, 79 | len(dataloader), 80 | len(dataloader) * batch_size, 81 | epoch_times[-1], 82 | ram_usage, 83 | ram_total, 84 | ) 85 | ) 86 | logger.info( 87 | "Epoch {}: on average every {} iters ({} videos) take {:.2f}/{:.2f} " 88 | "(avg/std) seconds.".format( 89 | cur_epoch, 90 | log_period, 91 | log_period * batch_size, 92 | np.mean(iter_times), 93 | np.std(iter_times), 94 | ) 95 | ) 96 | logger.info( 97 | "On average every epoch ({} videos) takes {:.2f}/{:.2f} " 98 | "(avg/std) seconds.".format( 99 | len(dataloader) * batch_size, 100 | np.mean(epoch_times), 101 | np.std(epoch_times), 102 | ) 103 | ) 104 | -------------------------------------------------------------------------------- /slowfast/utils/ava_evaluation/ava_action_list_v2.1_for_activitynet_2018.pbtxt.txt: -------------------------------------------------------------------------------- 1 | item { 2 | name: "bend/bow (at the waist)" 3 | id: 1 4 | } 5 | item { 6 | name: "crouch/kneel" 7 | id: 3 8 | } 9 | item { 10 | name: "dance" 11 | id: 4 12 | } 13 | item { 14 | name: "fall down" 15 | id: 5 16 | } 17 | item { 18 | name: "get up" 19 | id: 6 20 | } 21 | item { 22 | name: "jump/leap" 23 | id: 7 24 | } 25 | item { 26 | name: "lie/sleep" 27 | id: 8 28 | } 29 | item { 30 | name: "martial art" 31 | id: 9 32 | } 33 | item { 34 | name: "run/jog" 35 | id: 10 36 | } 37 | item { 38 | name: "sit" 39 | id: 11 40 | } 41 | item { 42 | name: "stand" 43 | id: 12 44 | } 45 | item { 46 | name: "swim" 47 | id: 13 48 | } 49 | item { 50 | name: "walk" 51 | id: 14 52 | } 53 | item { 54 | name: "answer phone" 55 | id: 15 56 | } 57 | item { 58 | name: "carry/hold (an object)" 59 | id: 17 60 | } 61 | item { 62 | name: "climb (e.g., a mountain)" 63 | id: 20 64 | } 65 | item { 66 | name: "close (e.g., a door, a box)" 67 | id: 22 68 | } 69 | item { 70 | name: "cut" 71 | id: 24 72 | } 73 | item { 74 | name: "dress/put on clothing" 75 | id: 26 76 | } 77 | item { 78 | name: "drink" 79 | id: 27 80 | } 81 | item { 82 | name: "drive (e.g., a car, a truck)" 83 | id: 28 84 | } 85 | item { 86 | name: "eat" 87 | id: 29 88 | } 89 | item { 90 | name: "enter" 91 | id: 30 92 | } 93 | item { 94 | name: "hit (an object)" 95 | id: 34 96 | } 97 | item { 98 | name: "lift/pick up" 99 | id: 36 100 | } 101 | item { 102 | name: "listen (e.g., to music)" 103 | id: 37 104 | } 105 | item { 106 | name: "open (e.g., a window, a car door)" 107 | id: 38 108 | } 109 | item { 110 | name: "play musical instrument" 111 | id: 41 112 | } 113 | item { 114 | name: "point to (an object)" 115 | id: 43 116 | } 117 | item { 118 | name: "pull (an object)" 119 | id: 45 120 | } 121 | item { 122 | name: "push (an object)" 123 | id: 46 124 | } 125 | item { 126 | name: "put down" 127 | id: 47 128 | } 129 | item { 130 | name: "read" 131 | id: 48 132 | } 133 | item { 134 | name: "ride (e.g., a bike, a car, a horse)" 135 | id: 49 136 | } 137 | item { 138 | name: "sail boat" 139 | id: 51 140 | } 141 | item { 142 | name: "shoot" 143 | id: 52 144 | } 145 | item { 146 | name: "smoke" 147 | id: 54 148 | } 149 | item { 150 | name: "take a photo" 151 | id: 56 152 | } 153 | item { 154 | name: "text on/look at a cellphone" 155 | id: 57 156 | } 157 | item { 158 | name: "throw" 159 | id: 58 160 | } 161 | item { 162 | name: "touch (an object)" 163 | id: 59 164 | } 165 | item { 166 | name: "turn (e.g., a screwdriver)" 167 | id: 60 168 | } 169 | item { 170 | name: "watch (e.g., TV)" 171 | id: 61 172 | } 173 | item { 174 | name: "work on a computer" 175 | id: 62 176 | 
} 177 | item { 178 | name: "write" 179 | id: 63 180 | } 181 | item { 182 | name: "fight/hit (a person)" 183 | id: 64 184 | } 185 | item { 186 | name: "give/serve (an object) to (a person)" 187 | id: 65 188 | } 189 | item { 190 | name: "grab (a person)" 191 | id: 66 192 | } 193 | item { 194 | name: "hand clap" 195 | id: 67 196 | } 197 | item { 198 | name: "hand shake" 199 | id: 68 200 | } 201 | item { 202 | name: "hand wave" 203 | id: 69 204 | } 205 | item { 206 | name: "hug (a person)" 207 | id: 70 208 | } 209 | item { 210 | name: "kiss (a person)" 211 | id: 72 212 | } 213 | item { 214 | name: "lift (a person)" 215 | id: 73 216 | } 217 | item { 218 | name: "listen to (a person)" 219 | id: 74 220 | } 221 | item { 222 | name: "push (another person)" 223 | id: 76 224 | } 225 | item { 226 | name: "sing to (e.g., self, a person, a group)" 227 | id: 77 228 | } 229 | item { 230 | name: "take (an object) from (a person)" 231 | id: 78 232 | } 233 | item { 234 | name: "talk to (e.g., self, a person, a group)" 235 | id: 79 236 | } 237 | item { 238 | name: "watch (a person)" 239 | id: 80 240 | } 241 | -------------------------------------------------------------------------------- /slowfast/utils/ava_evaluation/np_box_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Operations for [N, 4] numpy arrays representing bounding boxes. 17 | 18 | Example box operations that are supported: 19 | * Areas: compute bounding box areas 20 | * IOU: pairwise intersection-over-union scores 21 | """ 22 | from __future__ import ( 23 | absolute_import, 24 | division, 25 | print_function, 26 | unicode_literals, 27 | ) 28 | import numpy as np 29 | 30 | 31 | def area(boxes): 32 | """Computes area of boxes. 33 | 34 | Args: 35 | boxes: Numpy array with shape [N, 4] holding N boxes 36 | 37 | Returns: 38 | a numpy array with shape [N*1] representing box areas 39 | """ 40 | return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) 41 | 42 | 43 | def intersection(boxes1, boxes2): 44 | """Compute pairwise intersection areas between boxes. 
45 | 46 | Args: 47 | boxes1: a numpy array with shape [N, 4] holding N boxes 48 | boxes2: a numpy array with shape [M, 4] holding M boxes 49 | 50 | Returns: 51 | a numpy array with shape [N*M] representing pairwise intersection area 52 | """ 53 | [y_min1, x_min1, y_max1, x_max1] = np.split(boxes1, 4, axis=1) 54 | [y_min2, x_min2, y_max2, x_max2] = np.split(boxes2, 4, axis=1) 55 | 56 | all_pairs_min_ymax = np.minimum(y_max1, np.transpose(y_max2)) 57 | all_pairs_max_ymin = np.maximum(y_min1, np.transpose(y_min2)) 58 | intersect_heights = np.maximum( 59 | np.zeros(all_pairs_max_ymin.shape), 60 | all_pairs_min_ymax - all_pairs_max_ymin, 61 | ) 62 | all_pairs_min_xmax = np.minimum(x_max1, np.transpose(x_max2)) 63 | all_pairs_max_xmin = np.maximum(x_min1, np.transpose(x_min2)) 64 | intersect_widths = np.maximum( 65 | np.zeros(all_pairs_max_xmin.shape), 66 | all_pairs_min_xmax - all_pairs_max_xmin, 67 | ) 68 | return intersect_heights * intersect_widths 69 | 70 | 71 | def iou(boxes1, boxes2): 72 | """Computes pairwise intersection-over-union between box collections. 73 | 74 | Args: 75 | boxes1: a numpy array with shape [N, 4] holding N boxes. 76 | boxes2: a numpy array with shape [M, 4] holding N boxes. 77 | 78 | Returns: 79 | a numpy array with shape [N, M] representing pairwise iou scores. 80 | """ 81 | intersect = intersection(boxes1, boxes2) 82 | area1 = area(boxes1) 83 | area2 = area(boxes2) 84 | union = ( 85 | np.expand_dims(area1, axis=1) 86 | + np.expand_dims(area2, axis=0) 87 | - intersect 88 | ) 89 | return intersect / union 90 | 91 | 92 | def ioa(boxes1, boxes2): 93 | """Computes pairwise intersection-over-area between box collections. 94 | 95 | Intersection-over-area (ioa) between two boxes box1 and box2 is defined as 96 | their intersection area over box2's area. Note that ioa is not symmetric, 97 | that is, IOA(box1, box2) != IOA(box2, box1). 98 | 99 | Args: 100 | boxes1: a numpy array with shape [N, 4] holding N boxes. 101 | boxes2: a numpy array with shape [M, 4] holding N boxes. 102 | 103 | Returns: 104 | a numpy array with shape [N, M] representing pairwise ioa scores. 105 | """ 106 | intersect = intersection(boxes1, boxes2) 107 | areas = np.expand_dims(area(boxes2), axis=0) 108 | return intersect / areas 109 | -------------------------------------------------------------------------------- /slowfast/visualization/demo_loader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | import cv2 5 | 6 | from slowfast.visualization.utils import TaskInfo 7 | 8 | 9 | class VideoReader: 10 | """ 11 | VideoReader object for getting frames from video source for real-time inference. 12 | """ 13 | 14 | def __init__(self, cfg): 15 | """ 16 | Args: 17 | cfg (CfgNode): configs. Details can be found in 18 | slowfast/config/defaults.py 19 | """ 20 | assert ( 21 | cfg.DEMO.WEBCAM > -1 or cfg.DEMO.INPUT_VIDEO != "" 22 | ), "Must specify a data source as input." 
23 | 24 | self.source = ( 25 | cfg.DEMO.WEBCAM if cfg.DEMO.WEBCAM > -1 else cfg.DEMO.INPUT_VIDEO 26 | ) 27 | 28 | self.display_width = cfg.DEMO.DISPLAY_WIDTH 29 | self.display_height = cfg.DEMO.DISPLAY_HEIGHT 30 | 31 | self.cap = cv2.VideoCapture(self.source) 32 | 33 | if self.display_width > 0 and self.display_height > 0: 34 | self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, self.display_width) 35 | self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, self.display_height) 36 | else: 37 | self.display_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 38 | self.display_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 39 | 40 | if not self.cap.isOpened(): 41 | raise IOError("Video {} cannot be opened".format(self.source)) 42 | 43 | self.output_file = None 44 | if cfg.DEMO.OUTPUT_FILE != "": 45 | if cfg.DEMO.OUTPUT_FPS == -1: 46 | output_fps = self.cap.get(cv2.CAP_PROP_FPS) 47 | else: 48 | output_fps = cfg.DEMO.OUTPUT_FPS 49 | self.output_file = self.get_output_file( 50 | cfg.DEMO.OUTPUT_FILE, fps=output_fps 51 | ) 52 | self.id = -1 53 | self.buffer = [] 54 | self.buffer_size = cfg.DEMO.BUFFER_SIZE 55 | self.seq_length = cfg.DATA.NUM_FRAMES * cfg.DATA.SAMPLING_RATE 56 | 57 | def __iter__(self): 58 | return self 59 | 60 | def __next__(self): 61 | """ 62 | Read and return the required number of frames for 1 clip. 63 | Returns: 64 | was_read (bool): False if not enough frames to return. 65 | task (TaskInfo object): object contains metadata for the current clips. 66 | """ 67 | self.id += 1 68 | task = TaskInfo() 69 | 70 | frames = [] 71 | if len(self.buffer) != 0: 72 | frames = self.buffer 73 | was_read = True 74 | while was_read and len(frames) < self.seq_length: 75 | was_read, frame = self.cap.read() 76 | frames.append(frame) 77 | if was_read: 78 | self.buffer = frames[-self.buffer_size :] 79 | 80 | task.add_frames(self.id, frames) 81 | task.num_buffer_frames = 0 if self.id == 0 else self.buffer_size 82 | 83 | return was_read, task 84 | 85 | def get_output_file(self, path, fps=30): 86 | """ 87 | Return a video writer object. 88 | Args: 89 | path (str): path to the output video file. 90 | fps (int or float): frames per second. 91 | """ 92 | return cv2.VideoWriter( 93 | filename=path, 94 | fourcc=cv2.VideoWriter_fourcc(*"mp4v"), 95 | fps=float(fps), 96 | frameSize=(self.display_width, self.display_height), 97 | isColor=True, 98 | ) 99 | 100 | def display(self, frame): 101 | """ 102 | Either display a single frame (BGR image) to a window or write to 103 | an output file if output path is provided. 104 | """ 105 | if self.output_file is None: 106 | cv2.imshow("SlowFast", frame) 107 | else: 108 | self.output_file.write(frame) 109 | 110 | def clean(self): 111 | """ 112 | Clean up open video files and windows. 113 | """ 114 | self.cap.release() 115 | if self.output_file is None: 116 | cv2.destroyAllWindows() 117 | else: 118 | self.output_file.release() 119 | -------------------------------------------------------------------------------- /slowfast/utils/ava_evaluation/np_mask_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
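Note how `VideoReader.__next__` above overlaps consecutive clips: every task carries `DATA.NUM_FRAMES * DATA.SAMPLING_RATE` frames, the last `DEMO.BUFFER_SIZE` of which are kept and re-used as the head of the next clip, and `num_buffer_frames` records how many leading frames were already shown. A small sketch of that bookkeeping, with illustrative numbers only:

```python
# Illustrative values (e.g. a SLOWFAST_32x2-style demo config), not taken from any config file.
num_frames, sampling_rate, buffer_size = 32, 2, 16
seq_length = num_frames * sampling_rate  # 64 frames read per task

# Task 0 reads frames [0, 64) and buffers the last 16 of them.
# Task 1 starts from those 16 buffered frames and reads 48 new ones, covering [48, 112).
task0_frames = range(0, seq_length)
task1_frames = range(seq_length - buffer_size, 2 * seq_length - buffer_size)

# num_buffer_frames is 0 for task 0 and buffer_size (16) for every later task, so
# downstream drawing in tools/demo_net.py can skip frames that were already rendered.
```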
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Operations for [N, height, width] numpy arrays representing masks. 17 | 18 | Example mask operations that are supported: 19 | * Areas: compute mask areas 20 | * IOU: pairwise intersection-over-union scores 21 | """ 22 | from __future__ import ( 23 | absolute_import, 24 | division, 25 | print_function, 26 | unicode_literals, 27 | ) 28 | import numpy as np 29 | 30 | EPSILON = 1e-7 31 | 32 | 33 | def area(masks): 34 | """Computes area of masks. 35 | 36 | Args: 37 | masks: Numpy array with shape [N, height, width] holding N masks. Masks 38 | values are of type np.uint8 and values are in {0,1}. 39 | 40 | Returns: 41 | a numpy array with shape [N*1] representing mask areas. 42 | 43 | Raises: 44 | ValueError: If masks.dtype is not np.uint8 45 | """ 46 | if masks.dtype != np.uint8: 47 | raise ValueError("Masks type should be np.uint8") 48 | return np.sum(masks, axis=(1, 2), dtype=np.float32) 49 | 50 | 51 | def intersection(masks1, masks2): 52 | """Compute pairwise intersection areas between masks. 53 | 54 | Args: 55 | masks1: a numpy array with shape [N, height, width] holding N masks. Masks 56 | values are of type np.uint8 and values are in {0,1}. 57 | masks2: a numpy array with shape [M, height, width] holding M masks. Masks 58 | values are of type np.uint8 and values are in {0,1}. 59 | 60 | Returns: 61 | a numpy array with shape [N*M] representing pairwise intersection area. 62 | 63 | Raises: 64 | ValueError: If masks1 and masks2 are not of type np.uint8. 65 | """ 66 | if masks1.dtype != np.uint8 or masks2.dtype != np.uint8: 67 | raise ValueError("masks1 and masks2 should be of type np.uint8") 68 | n = masks1.shape[0] 69 | m = masks2.shape[0] 70 | answer = np.zeros([n, m], dtype=np.float32) 71 | for i in np.arange(n): 72 | for j in np.arange(m): 73 | answer[i, j] = np.sum( 74 | np.minimum(masks1[i], masks2[j]), dtype=np.float32 75 | ) 76 | return answer 77 | 78 | 79 | def iou(masks1, masks2): 80 | """Computes pairwise intersection-over-union between mask collections. 81 | 82 | Args: 83 | masks1: a numpy array with shape [N, height, width] holding N masks. Masks 84 | values are of type np.uint8 and values are in {0,1}. 85 | masks2: a numpy array with shape [M, height, width] holding N masks. Masks 86 | values are of type np.uint8 and values are in {0,1}. 87 | 88 | Returns: 89 | a numpy array with shape [N, M] representing pairwise iou scores. 90 | 91 | Raises: 92 | ValueError: If masks1 and masks2 are not of type np.uint8. 93 | """ 94 | if masks1.dtype != np.uint8 or masks2.dtype != np.uint8: 95 | raise ValueError("masks1 and masks2 should be of type np.uint8") 96 | intersect = intersection(masks1, masks2) 97 | area1 = area(masks1) 98 | area2 = area(masks2) 99 | union = ( 100 | np.expand_dims(area1, axis=1) 101 | + np.expand_dims(area2, axis=0) 102 | - intersect 103 | ) 104 | return intersect / np.maximum(union, EPSILON) 105 | 106 | 107 | def ioa(masks1, masks2): 108 | """Computes pairwise intersection-over-area between box collections. 
109 | 110 | Intersection-over-area (ioa) between two masks, mask1 and mask2 is defined as 111 | their intersection area over mask2's area. Note that ioa is not symmetric, 112 | that is, IOA(mask1, mask2) != IOA(mask2, mask1). 113 | 114 | Args: 115 | masks1: a numpy array with shape [N, height, width] holding N masks. Masks 116 | values are of type np.uint8 and values are in {0,1}. 117 | masks2: a numpy array with shape [M, height, width] holding N masks. Masks 118 | values are of type np.uint8 and values are in {0,1}. 119 | 120 | Returns: 121 | a numpy array with shape [N, M] representing pairwise ioa scores. 122 | 123 | Raises: 124 | ValueError: If masks1 and masks2 are not of type np.uint8. 125 | """ 126 | if masks1.dtype != np.uint8 or masks2.dtype != np.uint8: 127 | raise ValueError("masks1 and masks2 should be of type np.uint8") 128 | intersect = intersection(masks1, masks2) 129 | areas = np.expand_dims(area(masks2), axis=0) 130 | return intersect / (areas + EPSILON) 131 | -------------------------------------------------------------------------------- /DATASET.md: -------------------------------------------------------------------------------- 1 | # Dataset Preparation 2 | 3 | The AVA Dataset could be downloaded from the [official site](https://research.google.com/ava/download.html#ava_actions_download) 4 | 5 | We followed the same [downloading and preprocessing procedure](https://github.com/facebookresearch/video-long-term-feature-banks/blob/master/DATASET.md) as the [Long-Term Feature Banks for Detailed Video Understanding](https://arxiv.org/abs/1812.05038) do. 6 | 7 | You could follow these steps to download and preprocess the data: 8 | 9 | 1. Download videos 10 | 11 | ``` 12 | DATA_DIR="../../data/ava/videos" 13 | 14 | if [[ ! -d "${DATA_DIR}" ]]; then 15 | echo "${DATA_DIR} doesn't exist. Creating it."; 16 | mkdir -p ${DATA_DIR} 17 | fi 18 | 19 | wget https://s3.amazonaws.com/ava-dataset/annotations/ava_file_names_trainval_v2.1.txt 20 | 21 | for line in $(cat ava_file_names_trainval_v2.1.txt) 22 | do 23 | wget https://s3.amazonaws.com/ava-dataset/trainval/$line -P ${DATA_DIR} 24 | done 25 | ``` 26 | 27 | 2. Cut each video from its 15th to 30th minute 28 | 29 | ``` 30 | IN_DATA_DIR="../../data/ava/videos" 31 | OUT_DATA_DIR="../../data/ava/videos_15min" 32 | 33 | if [[ ! -d "${OUT_DATA_DIR}" ]]; then 34 | echo "${OUT_DATA_DIR} doesn't exist. Creating it."; 35 | mkdir -p ${OUT_DATA_DIR} 36 | fi 37 | 38 | for video in $(ls -A1 -U ${IN_DATA_DIR}/*) 39 | do 40 | out_name="${OUT_DATA_DIR}/${video##*/}" 41 | if [ ! -f "${out_name}" ]; then 42 | ffmpeg -ss 900 -t 901 -i "${video}" "${out_name}" 43 | fi 44 | done 45 | ``` 46 | 47 | 3. Extract frames 48 | 49 | ``` 50 | IN_DATA_DIR="../../data/ava/videos_15min" 51 | OUT_DATA_DIR="../../data/ava/frames" 52 | 53 | if [[ ! -d "${OUT_DATA_DIR}" ]]; then 54 | echo "${OUT_DATA_DIR} doesn't exist. Creating it."; 55 | mkdir -p ${OUT_DATA_DIR} 56 | fi 57 | 58 | for video in $(ls -A1 -U ${IN_DATA_DIR}/*) 59 | do 60 | video_name=${video##*/} 61 | 62 | if [[ $video_name = *".webm" ]]; then 63 | video_name=${video_name::-5} 64 | else 65 | video_name=${video_name::-4} 66 | fi 67 | 68 | out_video_dir=${OUT_DATA_DIR}/${video_name}/ 69 | mkdir -p "${out_video_dir}" 70 | 71 | out_name="${out_video_dir}/${video_name}_%06d.jpg" 72 | 73 | ffmpeg -i "${video}" -r 30 -q:v 1 "${out_name}" 74 | done 75 | ``` 76 | 77 | 4. Download annotations 78 | 79 | ``` 80 | DATA_DIR="../../data/ava/annotations" 81 | 82 | if [[ ! 
-d "${DATA_DIR}" ]]; then 83 | echo "${DATA_DIR} doesn't exist. Creating it."; 84 | mkdir -p ${DATA_DIR} 85 | fi 86 | 87 | wget https://research.google.com/ava/download/ava_train_v2.1.csv -P ${DATA_DIR} 88 | wget https://research.google.com/ava/download/ava_val_v2.1.csv -P ${DATA_DIR} 89 | wget https://research.google.com/ava/download/ava_action_list_v2.1_for_activitynet_2018.pbtxt -P ${DATA_DIR} 90 | wget https://research.google.com/ava/download/ava_train_excluded_timestamps_v2.1.csv -P ${DATA_DIR} 91 | wget https://research.google.com/ava/download/ava_val_excluded_timestamps_v2.1.csv -P ${DATA_DIR} 92 | ``` 93 | 94 | 5. Download "frame lists" ([train](https://dl.fbaipublicfiles.com/video-long-term-feature-banks/data/ava/frame_lists/train.csv), [val](https://dl.fbaipublicfiles.com/video-long-term-feature-banks/data/ava/frame_lists/val.csv)) and put them in 95 | the `frame_lists` folder (see structure above). 96 | 97 | 6. Download person boxes ([train](https://dl.fbaipublicfiles.com/video-long-term-feature-banks/data/ava/annotations/ava_train_predicted_boxes.csv), [val](https://dl.fbaipublicfiles.com/video-long-term-feature-banks/data/ava/annotations/ava_val_predicted_boxes.csv), [test](https://dl.fbaipublicfiles.com/video-long-term-feature-banks/data/ava/annotations/ava_test_predicted_boxes.csv)) and put them in the `annotations` folder (see structure above). 98 | If you prefer to use your own person detector, please see details 99 | in [here](https://github.com/facebookresearch/video-long-term-feature-banks/blob/master/GETTING_STARTED.md#ava-person-detector). 100 | 101 | 102 | Download the ava dataset with the following structure: 103 | 104 | ``` 105 | ava 106 | |_ frames 107 | | |_ [video name 0] 108 | | | |_ [video name 0]_000001.jpg 109 | | | |_ [video name 0]_000002.jpg 110 | | | |_ ... 111 | | |_ [video name 1] 112 | | |_ [video name 1]_000001.jpg 113 | | |_ [video name 1]_000002.jpg 114 | | |_ ... 115 | |_ frame_lists 116 | | |_ train.csv 117 | | |_ val.csv 118 | |_ annotations 119 | |_ [official AVA annotation files] 120 | |_ ava_train_predicted_boxes.csv 121 | |_ ava_val_predicted_boxes.csv 122 | ``` 123 | 124 | You could also replace the `v2.1` by `v2.2` if you need the AVA v2.2 annotation. You can also download some pre-prepared annotations from [here](https://dl.fbaipublicfiles.com/pyslowfast/annotation/ava/ava_annotations.tar). 125 | 126 | 7. Setup the root folder. In your training and testing phase please ensure your root folder is correctly set in the config file. 127 | You can set `_C.DATA_DIR=/path/to/AVA/folder` in `slowfast/config/defaults.py` before setting up, or config them in the command line 128 | 129 | ``` 130 | DATA_DIR /path/to/AVA/folder ${OTHER COMMAND} 131 | ``` -------------------------------------------------------------------------------- /slowfast/utils/c2_model_loading.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """Caffe2 to PyTorch checkpoint name converting utility.""" 5 | 6 | import re 7 | 8 | 9 | def get_name_convert_func(): 10 | """ 11 | Get the function to convert Caffe2 layer names to PyTorch layer names. 12 | Returns: 13 | (func): function to convert parameter name from Caffe2 format to PyTorch 14 | format. 
15 | """ 16 | pairs = [ 17 | # ------------------------------------------------------------ 18 | # 'nonlocal_conv3_1_theta_w' -> 's3.pathway0_nonlocal3.conv_g.weight' 19 | [ 20 | r"^nonlocal_conv([0-9]+)_([0-9]+)_(.*)", 21 | r"s\1.pathway0_nonlocal\2_\3", 22 | ], 23 | # 'theta' -> 'conv_theta' 24 | [r"^(.*)_nonlocal([0-9]+)_(theta)(.*)", r"\1_nonlocal\2.conv_\3\4"], 25 | # 'g' -> 'conv_g' 26 | [r"^(.*)_nonlocal([0-9]+)_(g)(.*)", r"\1_nonlocal\2.conv_\3\4"], 27 | # 'phi' -> 'conv_phi' 28 | [r"^(.*)_nonlocal([0-9]+)_(phi)(.*)", r"\1_nonlocal\2.conv_\3\4"], 29 | # 'out' -> 'conv_out' 30 | [r"^(.*)_nonlocal([0-9]+)_(out)(.*)", r"\1_nonlocal\2.conv_\3\4"], 31 | # 'nonlocal_conv4_5_bn_s' -> 's4.pathway0_nonlocal3.bn.weight' 32 | [r"^(.*)_nonlocal([0-9]+)_(bn)_(.*)", r"\1_nonlocal\2.\3.\4"], 33 | # ------------------------------------------------------------ 34 | # 't_pool1_subsample_bn' -> 's1_fuse.conv_f2s.bn.running_mean' 35 | [r"^t_pool1_subsample_bn_(.*)", r"s1_fuse.bn.\1"], 36 | # 't_pool1_subsample' -> 's1_fuse.conv_f2s' 37 | [r"^t_pool1_subsample_(.*)", r"s1_fuse.conv_f2s.\1"], 38 | # 't_res4_5_branch2c_bn_subsample_bn_rm' -> 's4_fuse.conv_f2s.bias' 39 | [ 40 | r"^t_res([0-9]+)_([0-9]+)_branch2c_bn_subsample_bn_(.*)", 41 | r"s\1_fuse.bn.\3", 42 | ], 43 | # 't_pool1_subsample' -> 's1_fuse.conv_f2s' 44 | [ 45 | r"^t_res([0-9]+)_([0-9]+)_branch2c_bn_subsample_(.*)", 46 | r"s\1_fuse.conv_f2s.\3", 47 | ], 48 | # ------------------------------------------------------------ 49 | # 'res4_4_branch_2c_bn_b' -> 's4.pathway0_res4.branch2.c_bn_b' 50 | [ 51 | r"^res([0-9]+)_([0-9]+)_branch([0-9]+)([a-z])_(.*)", 52 | r"s\1.pathway0_res\2.branch\3.\4_\5", 53 | ], 54 | # 'res_conv1_bn_' -> 's1.pathway0_stem.bn.' 55 | [r"^res_conv1_bn_(.*)", r"s1.pathway0_stem.bn.\1"], 56 | # 'conv1_w_momentum' -> 's1.pathway0_stem.conv.' 57 | [r"^conv1_(.*)", r"s1.pathway0_stem.conv.\1"], 58 | # 'res4_0_branch1_w' -> 'S4.pathway0_res0.branch1.weight' 59 | [ 60 | r"^res([0-9]+)_([0-9]+)_branch([0-9]+)_(.*)", 61 | r"s\1.pathway0_res\2.branch\3_\4", 62 | ], 63 | # 'res_conv1_' -> 's1.pathway0_stem.conv.' 64 | [r"^res_conv1_(.*)", r"s1.pathway0_stem.conv.\1"], 65 | # ------------------------------------------------------------ 66 | # 'res4_4_branch_2c_bn_b' -> 's4.pathway0_res4.branch2.c_bn_b' 67 | [ 68 | r"^t_res([0-9]+)_([0-9]+)_branch([0-9]+)([a-z])_(.*)", 69 | r"s\1.pathway1_res\2.branch\3.\4_\5", 70 | ], 71 | # 'res_conv1_bn_' -> 's1.pathway0_stem.bn.' 72 | [r"^t_res_conv1_bn_(.*)", r"s1.pathway1_stem.bn.\1"], 73 | # 'conv1_w_momentum' -> 's1.pathway0_stem.conv.' 74 | [r"^t_conv1_(.*)", r"s1.pathway1_stem.conv.\1"], 75 | # 'res4_0_branch1_w' -> 'S4.pathway0_res0.branch1.weight' 76 | [ 77 | r"^t_res([0-9]+)_([0-9]+)_branch([0-9]+)_(.*)", 78 | r"s\1.pathway1_res\2.branch\3_\4", 79 | ], 80 | # 'res_conv1_' -> 's1.pathway0_stem.conv.' 81 | [r"^t_res_conv1_(.*)", r"s1.pathway1_stem.conv.\1"], 82 | # ------------------------------------------------------------ 83 | # pred_ -> head.projection. 
84 | [r"pred_(.*)", r"head.projection.\1"], 85 | # '.bn_b' -> '.weight' 86 | [r"(.*)bn.b\Z", r"\1bn.bias"], 87 | # '.bn_s' -> '.weight' 88 | [r"(.*)bn.s\Z", r"\1bn.weight"], 89 | # '_bn_rm' -> '.running_mean' 90 | [r"(.*)bn.rm\Z", r"\1bn.running_mean"], 91 | # '_bn_riv' -> '.running_var' 92 | [r"(.*)bn.riv\Z", r"\1bn.running_var"], 93 | # '_b' -> '.bias' 94 | [r"(.*)[\._]b\Z", r"\1.bias"], 95 | # '_w' -> '.weight' 96 | [r"(.*)[\._]w\Z", r"\1.weight"], 97 | ] 98 | 99 | def convert_caffe2_name_to_pytorch(caffe2_layer_name): 100 | """ 101 | Convert the caffe2_layer_name to pytorch format by apply the list of 102 | regular expressions. 103 | Args: 104 | caffe2_layer_name (str): caffe2 layer name. 105 | Returns: 106 | (str): pytorch layer name. 107 | """ 108 | for source, dest in pairs: 109 | caffe2_layer_name = re.sub(source, dest, caffe2_layer_name) 110 | return caffe2_layer_name 111 | 112 | return convert_caffe2_name_to_pytorch 113 | -------------------------------------------------------------------------------- /slowfast/utils/ava_evaluation/np_box_list.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Numpy BoxList classes and functions.""" 17 | 18 | from __future__ import ( 19 | absolute_import, 20 | division, 21 | print_function, 22 | unicode_literals, 23 | ) 24 | import numpy as np 25 | 26 | 27 | class BoxList(object): 28 | """Box collection. 29 | 30 | BoxList represents a list of bounding boxes as numpy array, where each 31 | bounding box is represented as a row of 4 numbers, 32 | [y_min, x_min, y_max, x_max]. It is assumed that all bounding boxes within a 33 | given list correspond to a single image. 34 | 35 | Optionally, users can add additional related fields (such as 36 | objectness/classification scores). 37 | """ 38 | 39 | def __init__(self, data): 40 | """Constructs box collection. 41 | 42 | Args: 43 | data: a numpy array of shape [N, 4] representing box coordinates 44 | 45 | Raises: 46 | ValueError: if bbox data is not a numpy array 47 | ValueError: if invalid dimensions for bbox data 48 | """ 49 | if not isinstance(data, np.ndarray): 50 | raise ValueError("data must be a numpy array.") 51 | if len(data.shape) != 2 or data.shape[1] != 4: 52 | raise ValueError("Invalid dimensions for box data.") 53 | if data.dtype != np.float32 and data.dtype != np.float64: 54 | raise ValueError( 55 | "Invalid data type for box data: float is required." 56 | ) 57 | if not self._is_valid_boxes(data): 58 | raise ValueError( 59 | "Invalid box data. 
data must be a numpy array of " 60 | "N*[y_min, x_min, y_max, x_max]" 61 | ) 62 | self.data = {"boxes": data} 63 | 64 | def num_boxes(self): 65 | """Return number of boxes held in collections.""" 66 | return self.data["boxes"].shape[0] 67 | 68 | def get_extra_fields(self): 69 | """Return all non-box fields.""" 70 | return [k for k in self.data.keys() if k != "boxes"] 71 | 72 | def has_field(self, field): 73 | return field in self.data 74 | 75 | def add_field(self, field, field_data): 76 | """Add data to a specified field. 77 | 78 | Args: 79 | field: a string parameter used to speficy a related field to be accessed. 80 | field_data: a numpy array of [N, ...] representing the data associated 81 | with the field. 82 | Raises: 83 | ValueError: if the field is already exist or the dimension of the field 84 | data does not matches the number of boxes. 85 | """ 86 | if self.has_field(field): 87 | raise ValueError("Field " + field + "already exists") 88 | if len(field_data.shape) < 1 or field_data.shape[0] != self.num_boxes(): 89 | raise ValueError("Invalid dimensions for field data") 90 | self.data[field] = field_data 91 | 92 | def get(self): 93 | """Convenience function for accesssing box coordinates. 94 | 95 | Returns: 96 | a numpy array of shape [N, 4] representing box corners 97 | """ 98 | return self.get_field("boxes") 99 | 100 | def get_field(self, field): 101 | """Accesses data associated with the specified field in the box collection. 102 | 103 | Args: 104 | field: a string parameter used to speficy a related field to be accessed. 105 | 106 | Returns: 107 | a numpy 1-d array representing data of an associated field 108 | 109 | Raises: 110 | ValueError: if invalid field 111 | """ 112 | if not self.has_field(field): 113 | raise ValueError("field {} does not exist".format(field)) 114 | return self.data[field] 115 | 116 | def get_coordinates(self): 117 | """Get corner coordinates of boxes. 118 | 119 | Returns: 120 | a list of 4 1-d numpy arrays [y_min, x_min, y_max, x_max] 121 | """ 122 | box_coordinates = self.get() 123 | y_min = box_coordinates[:, 0] 124 | x_min = box_coordinates[:, 1] 125 | y_max = box_coordinates[:, 2] 126 | x_max = box_coordinates[:, 3] 127 | return [y_min, x_min, y_max, x_max] 128 | 129 | def _is_valid_boxes(self, data): 130 | """Check whether data fullfills the format of N*[ymin, xmin, ymax, xmin]. 131 | 132 | Args: 133 | data: a numpy array of shape [N, 4] representing box coordinates 134 | 135 | Returns: 136 | a boolean indicating whether all ymax of boxes are equal or greater than 137 | ymin, and all xmax of boxes are equal or greater than xmin. 138 | """ 139 | if data.shape[0] > 0: 140 | for i in range(data.shape[0]): 141 | if data[i, 0] > data[i, 2] or data[i, 1] > data[i, 3]: 142 | return False 143 | return True 144 | -------------------------------------------------------------------------------- /slowfast/models/nonlocal_helper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """Non-local helper""" 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class Nonlocal(nn.Module): 11 | """ 12 | Builds Non-local Neural Networks as a generic family of building 13 | blocks for capturing long-range dependencies. Non-local Network 14 | computes the response at a position as a weighted sum of the 15 | features at all positions. This building block can be plugged into 16 | many computer vision architectures. 
17 | More details in the paper: https://arxiv.org/pdf/1711.07971.pdf 18 | """ 19 | 20 | def __init__( 21 | self, 22 | dim, 23 | dim_inner, 24 | pool_size=None, 25 | instantiation="softmax", 26 | zero_init_final_conv=False, 27 | zero_init_final_norm=True, 28 | norm_eps=1e-5, 29 | norm_momentum=0.1, 30 | norm_module=nn.BatchNorm3d, 31 | ): 32 | """ 33 | Args: 34 | dim (int): number of dimension for the input. 35 | dim_inner (int): number of dimension inside of the Non-local block. 36 | pool_size (list): the kernel size of spatial temporal pooling, 37 | temporal pool kernel size, spatial pool kernel size, spatial 38 | pool kernel size in order. By default pool_size is None, 39 | then there would be no pooling used. 40 | instantiation (string): supports two different instantiation method: 41 | "dot_product": normalizing correlation matrix with L2. 42 | "softmax": normalizing correlation matrix with Softmax. 43 | zero_init_final_conv (bool): If true, zero initializing the final 44 | convolution of the Non-local block. 45 | zero_init_final_norm (bool): 46 | If true, zero initializing the final batch norm of the Non-local 47 | block. 48 | norm_module (nn.Module): nn.Module for the normalization layer. The 49 | default is nn.BatchNorm3d. 50 | """ 51 | super(Nonlocal, self).__init__() 52 | self.dim = dim 53 | self.dim_inner = dim_inner 54 | self.pool_size = pool_size 55 | self.instantiation = instantiation 56 | self.use_pool = ( 57 | False 58 | if pool_size is None 59 | else any((size > 1 for size in pool_size)) 60 | ) 61 | self.norm_eps = norm_eps 62 | self.norm_momentum = norm_momentum 63 | self._construct_nonlocal( 64 | zero_init_final_conv, zero_init_final_norm, norm_module 65 | ) 66 | 67 | def _construct_nonlocal( 68 | self, zero_init_final_conv, zero_init_final_norm, norm_module 69 | ): 70 | # Three convolution heads: theta, phi, and g. 71 | self.conv_theta = nn.Conv3d( 72 | self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0 73 | ) 74 | self.conv_phi = nn.Conv3d( 75 | self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0 76 | ) 77 | self.conv_g = nn.Conv3d( 78 | self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0 79 | ) 80 | 81 | # Final convolution output. 82 | self.conv_out = nn.Conv3d( 83 | self.dim_inner, self.dim, kernel_size=1, stride=1, padding=0 84 | ) 85 | # Zero initializing the final convolution output. 86 | self.conv_out.zero_init = zero_init_final_conv 87 | 88 | # TODO: change the name to `norm` 89 | self.bn = norm_module( 90 | num_features=self.dim, 91 | eps=self.norm_eps, 92 | momentum=self.norm_momentum, 93 | ) 94 | # Zero initializing the final bn. 95 | self.bn.transform_final_bn = zero_init_final_norm 96 | 97 | # Optional to add the spatial-temporal pooling. 98 | if self.use_pool: 99 | self.pool = nn.MaxPool3d( 100 | kernel_size=self.pool_size, 101 | stride=self.pool_size, 102 | padding=[0, 0, 0], 103 | ) 104 | 105 | def forward(self, x): 106 | x_identity = x 107 | N, C, T, H, W = x.size() 108 | 109 | theta = self.conv_theta(x) 110 | 111 | # Perform temporal-spatial pooling to reduce the computation. 112 | if self.use_pool: 113 | x = self.pool(x) 114 | 115 | phi = self.conv_phi(x) 116 | g = self.conv_g(x) 117 | 118 | theta = theta.view(N, self.dim_inner, -1) 119 | phi = phi.view(N, self.dim_inner, -1) 120 | g = g.view(N, self.dim_inner, -1) 121 | 122 | # (N, C, TxHxW) * (N, C, TxHxW) => (N, TxHxW, TxHxW). 
123 | theta_phi = torch.einsum("nct,ncp->ntp", (theta, phi)) 124 | # For original Non-local paper, there are two main ways to normalize 125 | # the affinity tensor: 126 | # 1) Softmax normalization (norm on exp). 127 | # 2) dot_product normalization. 128 | if self.instantiation == "softmax": 129 | # Normalizing the affinity tensor theta_phi before softmax. 130 | theta_phi = theta_phi * (self.dim_inner ** -0.5) 131 | theta_phi = nn.functional.softmax(theta_phi, dim=2) 132 | elif self.instantiation == "dot_product": 133 | spatial_temporal_dim = theta_phi.shape[2] 134 | theta_phi = theta_phi / spatial_temporal_dim 135 | else: 136 | raise NotImplementedError( 137 | "Unknown norm type {}".format(self.instantiation) 138 | ) 139 | 140 | # (N, TxHxW, TxHxW) * (N, C, TxHxW) => (N, C, TxHxW). 141 | theta_phi_g = torch.einsum("ntg,ncg->nct", (theta_phi, g)) 142 | 143 | # (N, C, TxHxW) => (N, C, T, H, W). 144 | theta_phi_g = theta_phi_g.view(N, self.dim_inner, T, H, W) 145 | 146 | p = self.conv_out(theta_phi_g) 147 | p = self.bn(p) 148 | return x_identity + p 149 | -------------------------------------------------------------------------------- /slowfast/utils/ava_evaluation/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Functions for computing metrics like precision, recall, CorLoc and etc.""" 17 | from __future__ import division 18 | import numpy as np 19 | 20 | 21 | def compute_precision_recall(scores, labels, num_gt): 22 | """Compute precision and recall. 23 | 24 | Args: 25 | scores: A float numpy array representing detection score 26 | labels: A boolean numpy array representing true/false positive labels 27 | num_gt: Number of ground truth instances 28 | 29 | Raises: 30 | ValueError: if the input is not of the correct format 31 | 32 | Returns: 33 | precision: Fraction of positive instances over detected ones. This value is 34 | None if no ground truth labels are present. 35 | recall: Fraction of detected positive instance over all positive instances. 36 | This value is None if no ground truth labels are present. 37 | 38 | """ 39 | if ( 40 | not isinstance(labels, np.ndarray) 41 | or labels.dtype != np.bool 42 | or len(labels.shape) != 1 43 | ): 44 | raise ValueError("labels must be single dimension bool numpy array") 45 | 46 | if not isinstance(scores, np.ndarray) or len(scores.shape) != 1: 47 | raise ValueError("scores must be single dimension numpy array") 48 | 49 | if num_gt < np.sum(labels): 50 | raise ValueError( 51 | "Number of true positives must be smaller than num_gt." 
52 | ) 53 | 54 | if len(scores) != len(labels): 55 | raise ValueError("scores and labels must be of the same size.") 56 | 57 | if num_gt == 0: 58 | return None, None 59 | 60 | sorted_indices = np.argsort(scores) 61 | sorted_indices = sorted_indices[::-1] 62 | labels = labels.astype(int) 63 | true_positive_labels = labels[sorted_indices] 64 | false_positive_labels = 1 - true_positive_labels 65 | cum_true_positives = np.cumsum(true_positive_labels) 66 | cum_false_positives = np.cumsum(false_positive_labels) 67 | precision = cum_true_positives.astype(float) / ( 68 | cum_true_positives + cum_false_positives 69 | ) 70 | recall = cum_true_positives.astype(float) / num_gt 71 | return precision, recall 72 | 73 | 74 | def compute_average_precision(precision, recall): 75 | """Compute Average Precision according to the definition in VOCdevkit. 76 | 77 | Precision is modified to ensure that it does not decrease as recall 78 | decrease. 79 | 80 | Args: 81 | precision: A float [N, 1] numpy array of precisions 82 | recall: A float [N, 1] numpy array of recalls 83 | 84 | Raises: 85 | ValueError: if the input is not of the correct format 86 | 87 | Returns: 88 | average_precison: The area under the precision recall curve. NaN if 89 | precision and recall are None. 90 | 91 | """ 92 | if precision is None: 93 | if recall is not None: 94 | raise ValueError("If precision is None, recall must also be None") 95 | return np.NAN 96 | 97 | if not isinstance(precision, np.ndarray) or not isinstance( 98 | recall, np.ndarray 99 | ): 100 | raise ValueError("precision and recall must be numpy array") 101 | if precision.dtype != np.float or recall.dtype != np.float: 102 | raise ValueError("input must be float numpy array.") 103 | if len(precision) != len(recall): 104 | raise ValueError("precision and recall must be of the same size.") 105 | if not precision.size: 106 | return 0.0 107 | if np.amin(precision) < 0 or np.amax(precision) > 1: 108 | raise ValueError("Precision must be in the range of [0, 1].") 109 | if np.amin(recall) < 0 or np.amax(recall) > 1: 110 | raise ValueError("recall must be in the range of [0, 1].") 111 | if not all(recall[i] <= recall[i + 1] for i in range(len(recall) - 1)): 112 | raise ValueError("recall must be a non-decreasing array") 113 | 114 | recall = np.concatenate([[0], recall, [1]]) 115 | precision = np.concatenate([[0], precision, [0]]) 116 | 117 | # Preprocess precision to be a non-decreasing array 118 | for i in range(len(precision) - 2, -1, -1): 119 | precision[i] = np.maximum(precision[i], precision[i + 1]) 120 | 121 | indices = np.where(recall[1:] != recall[:-1])[0] + 1 122 | average_precision = np.sum( 123 | (recall[indices] - recall[indices - 1]) * precision[indices] 124 | ) 125 | return average_precision 126 | 127 | 128 | def compute_cor_loc( 129 | num_gt_imgs_per_class, num_images_correctly_detected_per_class 130 | ): 131 | """Compute CorLoc according to the definition in the following paper. 132 | 133 | https://www.robots.ox.ac.uk/~vgg/rg/papers/deselaers-eccv10.pdf 134 | 135 | Returns nans if there are no ground truth images for a class. 
136 | 137 | Args: 138 | num_gt_imgs_per_class: 1D array, representing number of images containing 139 | at least one object instance of a particular class 140 | num_images_correctly_detected_per_class: 1D array, representing number of 141 | images that are correctly detected at least one object instance of a 142 | particular class 143 | 144 | Returns: 145 | corloc_per_class: A float numpy array represents the corloc score of each 146 | class 147 | """ 148 | # Divide by zero expected for classes with no gt examples. 149 | with np.errstate(divide="ignore", invalid="ignore"): 150 | return np.where( 151 | num_gt_imgs_per_class == 0, 152 | np.nan, 153 | num_images_correctly_detected_per_class / num_gt_imgs_per_class, 154 | ) 155 | -------------------------------------------------------------------------------- /tools/extract_feature.py: -------------------------------------------------------------------------------- 1 | #---------------------------------------------------------------------------- 2 | # 3 | # A Basic Script for Feature Bank Extraction Using Trained Models 4 | # 5 | #---------------------------------------------------------------------------- 6 | 7 | import numpy as np 8 | import torch 9 | import pickle 10 | import os 11 | import os.path as osp 12 | import lmdb 13 | 14 | import slowfast.utils.checkpoint as cu 15 | import slowfast.utils.distributed as du 16 | import slowfast.utils.logging as logging 17 | 18 | import slowfast.utils.misc as misc 19 | from slowfast.datasets import loader 20 | from slowfast.models import build_model 21 | from slowfast.utils.meters import AVAMeter, TestMeter, TrainMeter, HieveMeter 22 | from slowfast.utils.parser import load_config, parse_args 23 | from slowfast.utils.misc import launch_job 24 | 25 | logger = logging.get_logger(__name__) 26 | 27 | @torch.no_grad() 28 | def feature_extraction_launch(loader, model, meter, cfg, feature_bank): 29 | # Enable eval mode. 30 | model.eval() 31 | 32 | logger.info('extract feature for {} iters'.format(len(loader))) 33 | for cur_iter, (inputs, labels, video_idx, meta, _, _, _) in enumerate(loader): 34 | 35 | meter.iter_tic() 36 | if cfg.NUM_GPUS: 37 | # Transfer the data to the current GPU device. 38 | if isinstance(inputs, (list,)): 39 | for i in range(len(inputs)): 40 | inputs[i] = inputs[i].cuda(non_blocking=True) 41 | else: 42 | inputs = inputs.cuda(non_blocking=True) 43 | 44 | # Transfer the data to the current GPU device. 45 | labels = labels.cuda() 46 | video_idx = video_idx.cuda() 47 | for key, val in meta.items(): 48 | if isinstance(val, (list,)): 49 | for i in range(len(val)): 50 | val[i] = val[i].cuda(non_blocking=True) 51 | else: 52 | meta[key] = val.cuda(non_blocking=True) 53 | 54 | if cfg.DETECTION.ENABLE: 55 | # Compute the predictions. 
56 | feat, ctx = model(inputs, meta["boxes"], meta["metadata"], extract=True) 57 | 58 | ori_boxes = meta["ori_boxes"].cpu() 59 | metadata = meta["metadata"].cpu() 60 | 61 | if cfg.NUM_GPUS > 1: 62 | ori_boxes = torch.cat(du.all_gather_unaligned(ori_boxes), dim=0) 63 | metadata = torch.cat(du.all_gather_unaligned(metadata), dim=0) 64 | feats = torch.cat(du.all_gather_unaligned(feat), dim=0) 65 | 66 | feats = feats.detach().cpu() if cfg.NUM_GPUS else feats.detach() 67 | ori_boxes = ( 68 | ori_boxes.detach().cpu() if cfg.NUM_GPUS else ori_boxes.detach() 69 | ) 70 | metadata = ( 71 | metadata.detach().cpu() if cfg.NUM_GPUS else metadata.detach() 72 | ) 73 | 74 | if du.is_master_proc(du.get_world_size()): 75 | num = metadata.shape[0] 76 | for i in range(num): 77 | vid, sec = int(metadata[i, 0].item()), int(metadata[i, 1].item()) 78 | 79 | if vid not in feature_bank: 80 | feature_bank[vid] = dict() 81 | 82 | if sec not in feature_bank[vid]: 83 | feature_bank[vid][sec] = [] 84 | 85 | feature_bank[vid][sec].append(feats[i].squeeze()) 86 | 87 | meter.iter_toc() 88 | meter.log_iter_stats(cur_epoch=0, cur_iter=cur_iter) 89 | du.synchronize() 90 | else: 91 | raise NotImplementedError() 92 | 93 | meter.reset() 94 | 95 | def write(feature_bank, cfg): 96 | 97 | # save as lmdb 98 | 99 | env = lmdb.open(os.path.join(cfg.AVA.FEATURE_BANK_PATH, 'rdb'), map_size=3e10) 100 | txn = env.begin(write=True) 101 | count = 0 102 | 103 | for split in feature_bank: 104 | for vid in feature_bank[split]: 105 | for sec in feature_bank[split][vid]: 106 | feat_key = f"{split}/{vid}/{sec}/feature" 107 | feat_val = pickle.dumps(feature_bank[split][vid][sec]) 108 | txn.put(key=feat_key.encode(), value=feat_val) 109 | 110 | count += 1 111 | if count % 2000 == 0: 112 | logger.info(f"commit for {count} frames") 113 | txn.commit() 114 | txn = env.begin(write=True) 115 | 116 | txn.commit() 117 | env.close() 118 | 119 | def extract_feature(cfg): 120 | """ 121 | Perform multi-view testing on the pretrained video model. 122 | Args: 123 | cfg (CfgNode): configs. Details can be found in 124 | slowfast/config/defaults.py 125 | """ 126 | # Set up environment. 127 | du.init_distributed_training(cfg) 128 | # Set random seed from configs. 129 | np.random.seed(cfg.RNG_SEED) 130 | torch.manual_seed(cfg.RNG_SEED) 131 | 132 | # Setup logging format. 133 | logging.setup_logging(cfg.OUTPUT_DIR) 134 | 135 | # Print config. 136 | logger.info("Extracting AVA features with config:") 137 | logger.info(cfg) 138 | 139 | # Build the video model and print model statistics. 140 | model = build_model(cfg) 141 | if du.is_master_proc() and cfg.LOG_MODEL_INFO: 142 | misc.log_model_info(model, cfg, use_train_input=False) 143 | 144 | cu.load_test_checkpoint(cfg, model) 145 | 146 | # Create video testing loaders. 
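The `write` function above stores one pickled list of feature tensors per `"{split}/{vid}/{sec}/feature"` key inside an LMDB environment under the `rdb` sub-folder of `AVA.FEATURE_BANK_PATH`. Reading an entry back is symmetric; a minimal sketch, where the helper name and the example call are hypothetical:

```python
import os
import pickle

import lmdb


def read_bank_entry(bank_path, split, vid, sec):
    """Fetch the list of feature tensors stored for one (split, video, second) key."""
    env = lmdb.open(os.path.join(bank_path, "rdb"), readonly=True, lock=False)
    with env.begin(write=False) as txn:
        raw = txn.get(f"{split}/{vid}/{sec}/feature".encode())
    env.close()
    return None if raw is None else pickle.loads(raw)


# e.g. feats = read_bank_entry(cfg.AVA.FEATURE_BANK_PATH, "train", vid=42, sec=905)
```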
147 | train_loader = loader.construct_loader(cfg, "train") 148 | test_loader = loader.construct_loader(cfg, "test") 149 | 150 | # Create video feature bank 151 | # format {video_idx: {sec: {name: [feature1, feature2, ....]}}} name in ['feature', 'context'] 152 | feature_bank = { 153 | 'train': dict(), 154 | 'test': dict() 155 | } 156 | 157 | if du.is_master_proc() and not osp.exists(cfg.AVA.FEATURE_BANK_PATH): 158 | os.makedirs(cfg.AVA.FEATURE_BANK_PATH) 159 | 160 | test_meter = AVAMeter(len(test_loader), cfg, mode="test") 161 | train_meter = AVAMeter(len(train_loader), cfg, mode="test") 162 | 163 | # main process for feature extraction 164 | feature_extraction_launch(train_loader, model, train_meter, cfg, feature_bank['train']) 165 | feature_extraction_launch(test_loader, model, test_meter, cfg, feature_bank['test']) 166 | 167 | if du.is_master_proc(du.get_world_size()): 168 | write(feature_bank, cfg) 169 | 170 | if __name__ == '__main__': 171 | 172 | args = parse_args() 173 | cfg = load_config(args) 174 | 175 | launch_job( 176 | cfg=cfg, 177 | args=args, 178 | func=extract_feature 179 | ) 180 | -------------------------------------------------------------------------------- /slowfast/models/stem_helper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """ResNe(X)t 3D stem helper.""" 5 | 6 | import torch.nn as nn 7 | 8 | 9 | class VideoModelStem(nn.Module): 10 | """ 11 | Video 3D stem module. Provides stem operations of Conv, BN, ReLU, MaxPool 12 | on input data tensor for one or multiple pathways. 13 | """ 14 | 15 | def __init__( 16 | self, 17 | dim_in, 18 | dim_out, 19 | kernel, 20 | stride, 21 | padding, 22 | inplace_relu=True, 23 | eps=1e-5, 24 | bn_mmt=0.1, 25 | norm_module=nn.BatchNorm3d, 26 | ): 27 | """ 28 | The `__init__` method of any subclass should also contain these 29 | arguments. List size of 1 for single pathway models (C2D, I3D, Slow 30 | and etc), list size of 2 for two pathway models (SlowFast). 31 | 32 | Args: 33 | dim_in (list): the list of channel dimensions of the inputs. 34 | dim_out (list): the output dimension of the convolution in the stem 35 | layer. 36 | kernel (list): the kernels' size of the convolutions in the stem 37 | layers. Temporal kernel size, height kernel size, width kernel 38 | size in order. 39 | stride (list): the stride sizes of the convolutions in the stem 40 | layer. Temporal kernel stride, height kernel size, width kernel 41 | size in order. 42 | padding (list): the paddings' sizes of the convolutions in the stem 43 | layer. Temporal padding size, height padding size, width padding 44 | size in order. 45 | inplace_relu (bool): calculate the relu on the original input 46 | without allocating new memory. 47 | eps (float): epsilon for batch norm. 48 | bn_mmt (float): momentum for batch norm. Noted that BN momentum in 49 | PyTorch = 1 - BN momentum in Caffe2. 50 | norm_module (nn.Module): nn.Module for the normalization layer. The 51 | default is nn.BatchNorm3d. 52 | """ 53 | super(VideoModelStem, self).__init__() 54 | 55 | assert ( 56 | len( 57 | { 58 | len(dim_in), 59 | len(dim_out), 60 | len(kernel), 61 | len(stride), 62 | len(padding), 63 | } 64 | ) 65 | == 1 66 | ), "Input pathway dimensions are not consistent." 
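# ---------------------------------------------------------------------------
# [Illustrative aside] A hedged sketch of reading one entry back from the LMDB
# bank produced by write() above. The "split/vid/sec/feature" key layout and
# the pickled list-of-tensors value mirror the writer; the path argument is a
# placeholder for cfg.AVA.FEATURE_BANK_PATH + "/rdb".
import pickle
import lmdb

def read_bank_entry(bank_path, split, vid, sec):
    # Open the writer's "rdb" environment read-only and fetch a single key.
    env = lmdb.open(bank_path, readonly=True, lock=False)
    with env.begin(write=False) as txn:
        raw = txn.get(f"{split}/{vid}/{sec}/feature".encode())
    env.close()
    return None if raw is None else pickle.loads(raw)

# feats = read_bank_entry("/path/to/feature_bank/rdb", "train", 3, 910)
# ---------------------------------------------------------------------------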
67 | self.num_pathways = len(dim_in) 68 | self.kernel = kernel 69 | self.stride = stride 70 | self.padding = padding 71 | self.inplace_relu = inplace_relu 72 | self.eps = eps 73 | self.bn_mmt = bn_mmt 74 | # Construct the stem layer. 75 | self._construct_stem(dim_in, dim_out, norm_module) 76 | 77 | def _construct_stem(self, dim_in, dim_out, norm_module): 78 | for pathway in range(len(dim_in)): 79 | stem = ResNetBasicStem( 80 | dim_in[pathway], 81 | dim_out[pathway], 82 | self.kernel[pathway], 83 | self.stride[pathway], 84 | self.padding[pathway], 85 | self.inplace_relu, 86 | self.eps, 87 | self.bn_mmt, 88 | norm_module, 89 | ) 90 | self.add_module("pathway{}_stem".format(pathway), stem) 91 | 92 | def forward(self, x): 93 | assert ( 94 | len(x) == self.num_pathways 95 | ), "Input tensor does not contain {} pathway".format(self.num_pathways) 96 | for pathway in range(len(x)): 97 | m = getattr(self, "pathway{}_stem".format(pathway)) 98 | x[pathway] = m(x[pathway]) 99 | return x 100 | 101 | 102 | class ResNetBasicStem(nn.Module): 103 | """ 104 | ResNe(X)t 3D stem module. 105 | Performs spatiotemporal Convolution, BN, and Relu following by a 106 | spatiotemporal pooling. 107 | """ 108 | 109 | def __init__( 110 | self, 111 | dim_in, 112 | dim_out, 113 | kernel, 114 | stride, 115 | padding, 116 | inplace_relu=True, 117 | eps=1e-5, 118 | bn_mmt=0.1, 119 | norm_module=nn.BatchNorm3d, 120 | ): 121 | """ 122 | The `__init__` method of any subclass should also contain these arguments. 123 | 124 | Args: 125 | dim_in (int): the channel dimension of the input. Normally 3 is used 126 | for rgb input, and 2 or 3 is used for optical flow input. 127 | dim_out (int): the output dimension of the convolution in the stem 128 | layer. 129 | kernel (list): the kernel size of the convolution in the stem layer. 130 | temporal kernel size, height kernel size, width kernel size in 131 | order. 132 | stride (list): the stride size of the convolution in the stem layer. 133 | temporal kernel stride, height kernel size, width kernel size in 134 | order. 135 | padding (int): the padding size of the convolution in the stem 136 | layer, temporal padding size, height padding size, width 137 | padding size in order. 138 | inplace_relu (bool): calculate the relu on the original input 139 | without allocating new memory. 140 | eps (float): epsilon for batch norm. 141 | bn_mmt (float): momentum for batch norm. Noted that BN momentum in 142 | PyTorch = 1 - BN momentum in Caffe2. 143 | norm_module (nn.Module): nn.Module for the normalization layer. The 144 | default is nn.BatchNorm3d. 145 | """ 146 | super(ResNetBasicStem, self).__init__() 147 | self.kernel = kernel 148 | self.stride = stride 149 | self.padding = padding 150 | self.inplace_relu = inplace_relu 151 | self.eps = eps 152 | self.bn_mmt = bn_mmt 153 | # Construct the stem layer. 
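# ---------------------------------------------------------------------------
# [Illustrative aside] A small usage sketch for the two-pathway stem defined
# above, assuming the import path shown in this dump. The SlowFast-style
# kernel/stride/padding values are illustrative, not taken from a config file.
import torch
from slowfast.models.stem_helper import VideoModelStem

stem = VideoModelStem(
    dim_in=[3, 3],                      # RGB input for slow and fast pathways
    dim_out=[64, 8],                    # slow / fast channel widths
    kernel=[[1, 7, 7], [5, 7, 7]],
    stride=[[1, 2, 2], [1, 2, 2]],
    padding=[[0, 3, 3], [2, 3, 3]],
)
slow = torch.randn(2, 3, 8, 224, 224)   # (N, C, T, H, W)
fast = torch.randn(2, 3, 32, 224, 224)
out = stem([slow, fast])
# Conv stride 2 followed by max-pool stride 2 reduces 224 -> 56 spatially,
# while each pathway keeps its temporal length.
print(out[0].shape, out[1].shape)
# -> torch.Size([2, 64, 8, 56, 56]) torch.Size([2, 8, 32, 56, 56])
# ---------------------------------------------------------------------------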
154 | self._construct_stem(dim_in, dim_out, norm_module) 155 | 156 | def _construct_stem(self, dim_in, dim_out, norm_module): 157 | self.conv = nn.Conv3d( 158 | dim_in, 159 | dim_out, 160 | self.kernel, 161 | stride=self.stride, 162 | padding=self.padding, 163 | bias=False, 164 | ) 165 | self.bn = norm_module( 166 | num_features=dim_out, eps=self.eps, momentum=self.bn_mmt 167 | ) 168 | self.relu = nn.ReLU(self.inplace_relu) 169 | self.pool_layer = nn.MaxPool3d( 170 | kernel_size=[1, 3, 3], stride=[1, 2, 2], padding=[0, 1, 1] 171 | ) 172 | 173 | def forward(self, x): 174 | x = self.conv(x) 175 | x = self.bn(x) 176 | x = self.relu(x) 177 | x = self.pool_layer(x) 178 | return x 179 | -------------------------------------------------------------------------------- /slowfast/utils/ava_evaluation/label_map_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Label map utility functions.""" 16 | 17 | from __future__ import ( 18 | absolute_import, 19 | division, 20 | print_function, 21 | unicode_literals, 22 | ) 23 | import logging 24 | 25 | # from google.protobuf import text_format 26 | # from google3.third_party.tensorflow_models.object_detection.protos import string_int_label_map_pb2 27 | 28 | 29 | def _validate_label_map(label_map): 30 | """Checks if a label map is valid. 31 | 32 | Args: 33 | label_map: StringIntLabelMap to validate. 34 | 35 | Raises: 36 | ValueError: if label map is invalid. 37 | """ 38 | for item in label_map.item: 39 | if item.id < 1: 40 | raise ValueError("Label map ids should be >= 1.") 41 | 42 | 43 | def create_category_index(categories): 44 | """Creates dictionary of COCO compatible categories keyed by category id. 45 | 46 | Args: 47 | categories: a list of dicts, each of which has the following keys: 48 | 'id': (required) an integer id uniquely identifying this category. 49 | 'name': (required) string representing category name 50 | e.g., 'cat', 'dog', 'pizza'. 51 | 52 | Returns: 53 | category_index: a dict containing the same entries as categories, but keyed 54 | by the 'id' field of each category. 55 | """ 56 | category_index = {} 57 | for cat in categories: 58 | category_index[cat["id"]] = cat 59 | return category_index 60 | 61 | 62 | def get_max_label_map_index(label_map): 63 | """Get maximum index in label map. 64 | 65 | Args: 66 | label_map: a StringIntLabelMapProto 67 | 68 | Returns: 69 | an integer 70 | """ 71 | return max([item.id for item in label_map.item]) 72 | 73 | 74 | def convert_label_map_to_categories( 75 | label_map, max_num_classes, use_display_name=True 76 | ): 77 | """Loads label map proto and returns categories list compatible with eval. 
78 | 79 | This function loads a label map and returns a list of dicts, each of which 80 | has the following keys: 81 | 'id': (required) an integer id uniquely identifying this category. 82 | 'name': (required) string representing category name 83 | e.g., 'cat', 'dog', 'pizza'. 84 | We only allow class into the list if its id-label_id_offset is 85 | between 0 (inclusive) and max_num_classes (exclusive). 86 | If there are several items mapping to the same id in the label map, 87 | we will only keep the first one in the categories list. 88 | 89 | Args: 90 | label_map: a StringIntLabelMapProto or None. If None, a default categories 91 | list is created with max_num_classes categories. 92 | max_num_classes: maximum number of (consecutive) label indices to include. 93 | use_display_name: (boolean) choose whether to load 'display_name' field 94 | as category name. If False or if the display_name field does not exist, 95 | uses 'name' field as category names instead. 96 | Returns: 97 | categories: a list of dictionaries representing all possible categories. 98 | """ 99 | categories = [] 100 | list_of_ids_already_added = [] 101 | if not label_map: 102 | label_id_offset = 1 103 | for class_id in range(max_num_classes): 104 | categories.append( 105 | { 106 | "id": class_id + label_id_offset, 107 | "name": "category_{}".format(class_id + label_id_offset), 108 | } 109 | ) 110 | return categories 111 | for item in label_map.item: 112 | if not 0 < item.id <= max_num_classes: 113 | logging.info( 114 | "Ignore item %d since it falls outside of requested " 115 | "label range.", 116 | item.id, 117 | ) 118 | continue 119 | if use_display_name and item.HasField("display_name"): 120 | name = item.display_name 121 | else: 122 | name = item.name 123 | if item.id not in list_of_ids_already_added: 124 | list_of_ids_already_added.append(item.id) 125 | categories.append({"id": item.id, "name": name}) 126 | return categories 127 | 128 | 129 | def load_labelmap(path): 130 | """Loads label map proto. 131 | 132 | Args: 133 | path: path to StringIntLabelMap proto text file. 134 | Returns: 135 | a StringIntLabelMapProto 136 | """ 137 | with open(path, "r") as fid: 138 | label_map_string = fid.read() 139 | label_map = string_int_label_map_pb2.StringIntLabelMap() 140 | try: 141 | text_format.Merge(label_map_string, label_map) 142 | except text_format.ParseError: 143 | label_map.ParseFromString(label_map_string) 144 | _validate_label_map(label_map) 145 | return label_map 146 | 147 | 148 | def get_label_map_dict(label_map_path, use_display_name=False): 149 | """Reads a label map and returns a dictionary of label names to id. 150 | 151 | Args: 152 | label_map_path: path to label_map. 153 | use_display_name: whether to use the label map items' display names as keys. 154 | 155 | Returns: 156 | A dictionary mapping label names to id. 157 | """ 158 | label_map = load_labelmap(label_map_path) 159 | label_map_dict = {} 160 | for item in label_map.item: 161 | if use_display_name: 162 | label_map_dict[item.display_name] = item.id 163 | else: 164 | label_map_dict[item.name] = item.id 165 | return label_map_dict 166 | 167 | 168 | def create_category_index_from_labelmap(label_map_path): 169 | """Reads a label map and returns a category index. 170 | 171 | Args: 172 | label_map_path: Path to `StringIntLabelMap` proto text file. 173 | 174 | Returns: 175 | A category index, which is a dictionary that maps integer ids to dicts 176 | containing categories, e.g. 
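# ---------------------------------------------------------------------------
# [Illustrative aside] A quick check of the None-label-map fallback documented
# in convert_label_map_to_categories above: without a label map, consecutive
# placeholder categories starting at id 1 are generated. Assumes the import
# path shown in this dump.
from slowfast.utils.ava_evaluation.label_map_util import (
    convert_label_map_to_categories,
)

categories = convert_label_map_to_categories(label_map=None, max_num_classes=3)
print(categories)
# -> [{'id': 1, 'name': 'category_1'},
#     {'id': 2, 'name': 'category_2'},
#     {'id': 3, 'name': 'category_3'}]
# ---------------------------------------------------------------------------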
177 | {1: {'id': 1, 'name': 'dog'}, 2: {'id': 2, 'name': 'cat'}, ...} 178 | """ 179 | label_map = load_labelmap(label_map_path) 180 | max_num_classes = max(item.id for item in label_map.item) 181 | categories = convert_label_map_to_categories(label_map, max_num_classes) 182 | return create_category_index(categories) 183 | 184 | 185 | def create_class_agnostic_category_index(): 186 | """Creates a category index with a single `object` class.""" 187 | return {1: {"id": 1, "name": "object"}} 188 | -------------------------------------------------------------------------------- /slowfast/visualization/predictor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | import cv2 5 | import torch 6 | from detectron2 import model_zoo 7 | from detectron2.config import get_cfg 8 | from detectron2.engine import DefaultPredictor 9 | 10 | import slowfast.utils.checkpoint as cu 11 | from slowfast.datasets import cv2_transform 12 | from slowfast.models import build_model 13 | from slowfast.utils import logging, misc 14 | from slowfast.visualization.utils import process_cv2_inputs 15 | 16 | logger = logging.get_logger(__name__) 17 | 18 | 19 | class ActionPredictor: 20 | """ 21 | Action Predictor for action recognition. 22 | """ 23 | 24 | def __init__(self, cfg): 25 | """ 26 | Args: 27 | cfg (CfgNode): configs. Details can be found in 28 | slowfast/config/defaults.py 29 | """ 30 | # Build the video model and print model statistics. 31 | self.model = build_model(cfg) 32 | self.model.eval() 33 | self.cfg = cfg 34 | logger.info("Start loading model info") 35 | misc.log_model_info(self.model, cfg, use_train_input=False) 36 | logger.info("Start loading model weights") 37 | cu.load_test_checkpoint(cfg, self.model) 38 | logger.info("Finish loading model weights") 39 | 40 | def __call__(self, task): 41 | """ 42 | Returns the prediction results for the current task. 43 | Args: 44 | task (TaskInfo object): task object that contain 45 | the necessary information for action prediction. (e.g. frames, boxes) 46 | Returns: 47 | task (TaskInfo object): the same task info object but filled with 48 | prediction values (a tensor) and the corresponding boxes for 49 | action detection task. 50 | """ 51 | frames, bboxes = task.frames, task.bboxes 52 | if bboxes is not None: 53 | bboxes = cv2_transform.scale_boxes( 54 | self.cfg.DATA.TEST_CROP_SIZE, 55 | bboxes, 56 | task.img_height, 57 | task.img_width, 58 | ) 59 | if self.cfg.DEMO.INPUT_FORMAT == "BGR": 60 | frames = [ 61 | cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in frames 62 | ] 63 | 64 | frames = [ 65 | cv2_transform.scale(self.cfg.DATA.TEST_CROP_SIZE, frame) 66 | for frame in frames 67 | ] 68 | inputs = process_cv2_inputs(frames, self.cfg) 69 | if bboxes is not None: 70 | index_pad = torch.full( 71 | size=(bboxes.shape[0], 1), 72 | fill_value=float(0), 73 | device=bboxes.device, 74 | ) 75 | 76 | # Pad frame index for each box. 77 | bboxes = torch.cat([index_pad, bboxes], axis=1) 78 | if self.cfg.NUM_GPUS > 0: 79 | # Transfer the data to the current GPU device. 
80 | if isinstance(inputs, (list,)): 81 | for i in range(len(inputs)): 82 | inputs[i] = inputs[i].cuda(non_blocking=True) 83 | else: 84 | inputs = inputs.cuda(non_blocking=True) 85 | if self.cfg.DETECTION.ENABLE and not bboxes.shape[0]: 86 | preds = torch.tensor([]) 87 | else: 88 | preds = self.model(inputs, bboxes) 89 | 90 | if self.cfg.NUM_GPUS: 91 | preds = preds.cpu() 92 | if bboxes is not None: 93 | bboxes = bboxes.cpu() 94 | 95 | preds = preds.detach() 96 | 97 | task.add_action_preds(preds) 98 | if bboxes is not None: 99 | task.add_bboxes(bboxes[:, 1:]) 100 | 101 | return task 102 | 103 | 104 | class Detectron2Predictor: 105 | """ 106 | Wrapper around Detectron2 to return the required predicted bounding boxes 107 | as a ndarray. 108 | """ 109 | 110 | def __init__(self, cfg): 111 | """ 112 | Args: 113 | cfg (CfgNode): configs. Details can be found in 114 | slowfast/config/defaults.py 115 | """ 116 | 117 | self.cfg = get_cfg() 118 | self.cfg.merge_from_file( 119 | model_zoo.get_config_file(cfg.DEMO.DETECTRON2_CFG) 120 | ) 121 | self.cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = cfg.DEMO.DETECTRON2_THRESH 122 | self.cfg.MODEL.WEIGHTS = cfg.DEMO.DETECTRON2_WEIGHTS 123 | self.cfg.INPUT.FORMAT = cfg.DEMO.INPUT_FORMAT 124 | self.cfg.MODEL.DEVICE = "cuda:0" if cfg.NUM_GPUS > 0 else "cpu" 125 | 126 | logger.info("Initialized Detectron2 Object Detection Model.") 127 | 128 | self.predictor = DefaultPredictor(self.cfg) 129 | 130 | def __call__(self, task): 131 | """ 132 | Return bounding boxes predictions as a tensor. 133 | Args: 134 | task (TaskInfo object): task object that contain 135 | the necessary information for action prediction. (e.g. frames, boxes) 136 | Returns: 137 | task (TaskInfo object): the same task info object but filled with 138 | prediction values (a tensor) and the corresponding boxes for 139 | action detection task. 140 | """ 141 | middle_frame = task.frames[len(task.frames) // 2] 142 | outputs = self.predictor(middle_frame) 143 | # Get only human instances 144 | mask = outputs["instances"].pred_classes == 0 145 | pred_boxes = outputs["instances"].pred_boxes.tensor[mask] 146 | task.add_bboxes(pred_boxes) 147 | 148 | return task 149 | 150 | 151 | def draw_predictions(task, video_vis): 152 | """ 153 | Draw prediction for the given task. 154 | Args: 155 | task (TaskInfo object): task object that contain 156 | the necessary information for visualization. (e.g. frames, preds) 157 | All attributes must lie on CPU devices. 158 | video_vis (VideoVisualizer object): the video visualizer object. 159 | Returns: 160 | frames (list of ndarray): visualized frames in the clip. 
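# ---------------------------------------------------------------------------
# [Illustrative aside] The person-box filtering used by Detectron2Predictor
# above, shown on plain tensors so it runs without Detectron2: class id 0 is
# the person class, and only those rows of the predicted boxes are kept.
# All values below are made up.
import torch

pred_classes = torch.tensor([0, 56, 0, 41])           # person, other, person, other
pred_boxes = torch.tensor(
    [[10.0, 20.0, 110.0, 220.0],
     [300.0, 40.0, 400.0, 140.0],
     [50.0, 60.0, 150.0, 260.0],
     [200.0, 10.0, 230.0, 40.0]]
)
mask = pred_classes == 0
person_boxes = pred_boxes[mask]
print(person_boxes.shape)  # -> torch.Size([2, 4])
# ---------------------------------------------------------------------------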
161 | """ 162 | boxes = task.bboxes 163 | frames = task.frames 164 | preds = task.action_preds 165 | if boxes is not None: 166 | img_width = task.img_width 167 | img_height = task.img_height 168 | boxes = cv2_transform.revert_scaled_boxes( 169 | task.crop_size, boxes, img_height, img_width 170 | ) 171 | 172 | keyframe_idx = len(frames) // 2 - task.num_buffer_frames 173 | draw_range = [ 174 | keyframe_idx - task.clip_vis_size, 175 | keyframe_idx + task.clip_vis_size, 176 | ] 177 | frames = frames[task.num_buffer_frames :] 178 | if boxes is not None: 179 | if len(boxes) != 0: 180 | frames = video_vis.draw_clip_range( 181 | frames, 182 | preds, 183 | boxes, 184 | keyframe_idx=keyframe_idx, 185 | draw_range=draw_range, 186 | ) 187 | else: 188 | frames = video_vis.draw_clip_range( 189 | frames, preds, keyframe_idx=keyframe_idx, draw_range=draw_range 190 | ) 191 | del task 192 | 193 | return frames 194 | -------------------------------------------------------------------------------- /slowfast/models/batchnorm_helper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """BatchNorm (BN) utility functions and custom batch-size BN implementations""" 5 | 6 | from functools import partial 7 | import torch 8 | import torch.distributed as dist 9 | import torch.nn as nn 10 | from torch.autograd.function import Function 11 | 12 | import slowfast.utils.distributed as du 13 | 14 | 15 | def get_norm(cfg): 16 | """ 17 | Args: 18 | cfg (CfgNode): model building configs, details are in the comments of 19 | the config file. 20 | Returns: 21 | nn.Module: the normalization layer. 22 | """ 23 | if cfg.BN.NORM_TYPE == "batchnorm": 24 | return nn.BatchNorm3d 25 | elif cfg.BN.NORM_TYPE == "sub_batchnorm": 26 | return partial(SubBatchNorm3d, num_splits=cfg.BN.NUM_SPLITS) 27 | elif cfg.BN.NORM_TYPE == "sync_batchnorm": 28 | return partial( 29 | NaiveSyncBatchNorm3d, num_sync_devices=cfg.BN.NUM_SYNC_DEVICES 30 | ) 31 | else: 32 | raise NotImplementedError( 33 | "Norm type {} is not supported".format(cfg.BN.NORM_TYPE) 34 | ) 35 | 36 | 37 | class SubBatchNorm3d(nn.Module): 38 | """ 39 | The standard BN layer computes stats across all examples in a GPU. In some 40 | cases it is desirable to compute stats across only a subset of examples 41 | (e.g., in multigrid training https://arxiv.org/abs/1912.00998). 42 | SubBatchNorm3d splits the batch dimension into N splits, and run BN on 43 | each of them separately (so that the stats are computed on each subset of 44 | examples (1/N of batch) independently. During evaluation, it aggregates 45 | the stats from all splits into one BN. 46 | """ 47 | 48 | def __init__(self, num_splits, **args): 49 | """ 50 | Args: 51 | num_splits (int): number of splits. 52 | args (list): other arguments. 53 | """ 54 | super(SubBatchNorm3d, self).__init__() 55 | self.num_splits = num_splits 56 | num_features = args["num_features"] 57 | # Keep only one set of weight and bias. 
58 | if args.get("affine", True): 59 | self.affine = True 60 | args["affine"] = False 61 | self.weight = torch.nn.Parameter(torch.ones(num_features)) 62 | self.bias = torch.nn.Parameter(torch.zeros(num_features)) 63 | else: 64 | self.affine = False 65 | self.bn = nn.BatchNorm3d(**args) 66 | args["num_features"] = num_features * num_splits 67 | self.split_bn = nn.BatchNorm3d(**args) 68 | 69 | def _get_aggregated_mean_std(self, means, stds, n): 70 | """ 71 | Calculate the aggregated mean and stds. 72 | Args: 73 | means (tensor): mean values. 74 | stds (tensor): standard deviations. 75 | n (int): number of sets of means and stds. 76 | """ 77 | mean = means.view(n, -1).sum(0) / n 78 | std = ( 79 | stds.view(n, -1).sum(0) / n 80 | + ((means.view(n, -1) - mean) ** 2).view(n, -1).sum(0) / n 81 | ) 82 | return mean.detach(), std.detach() 83 | 84 | def aggregate_stats(self): 85 | """ 86 | Synchronize running_mean, and running_var. Call this before eval. 87 | """ 88 | if self.split_bn.track_running_stats: 89 | ( 90 | self.bn.running_mean.data, 91 | self.bn.running_var.data, 92 | ) = self._get_aggregated_mean_std( 93 | self.split_bn.running_mean, 94 | self.split_bn.running_var, 95 | self.num_splits, 96 | ) 97 | 98 | def forward(self, x): 99 | if self.training: 100 | n, c, t, h, w = x.shape 101 | x = x.view(n // self.num_splits, c * self.num_splits, t, h, w) 102 | x = self.split_bn(x) 103 | x = x.view(n, c, t, h, w) 104 | else: 105 | x = self.bn(x) 106 | if self.affine: 107 | x = x * self.weight.view((-1, 1, 1, 1)) 108 | x = x + self.bias.view((-1, 1, 1, 1)) 109 | return x 110 | 111 | 112 | class GroupGather(Function): 113 | """ 114 | GroupGather performs all gather on each of the local process/ GPU groups. 115 | """ 116 | 117 | @staticmethod 118 | def forward(ctx, input, num_sync_devices, num_groups): 119 | """ 120 | Perform forwarding, gathering the stats across different process/ GPU 121 | group. 122 | """ 123 | ctx.num_sync_devices = num_sync_devices 124 | ctx.num_groups = num_groups 125 | 126 | input_list = [ 127 | torch.zeros_like(input) for k in range(du.get_local_size()) 128 | ] 129 | dist.all_gather( 130 | input_list, input, async_op=False, group=du._LOCAL_PROCESS_GROUP 131 | ) 132 | 133 | inputs = torch.stack(input_list, dim=0) 134 | if num_groups > 1: 135 | rank = du.get_local_rank() 136 | group_idx = rank // num_sync_devices 137 | inputs = inputs[ 138 | group_idx 139 | * num_sync_devices : (group_idx + 1) 140 | * num_sync_devices 141 | ] 142 | inputs = torch.sum(inputs, dim=0) 143 | return inputs 144 | 145 | @staticmethod 146 | def backward(ctx, grad_output): 147 | """ 148 | Perform backwarding, gathering the gradients across different process/ GPU 149 | group. 150 | """ 151 | grad_output_list = [ 152 | torch.zeros_like(grad_output) for k in range(du.get_local_size()) 153 | ] 154 | dist.all_gather( 155 | grad_output_list, 156 | grad_output, 157 | async_op=False, 158 | group=du._LOCAL_PROCESS_GROUP, 159 | ) 160 | 161 | grads = torch.stack(grad_output_list, dim=0) 162 | if ctx.num_groups > 1: 163 | rank = du.get_local_rank() 164 | group_idx = rank // ctx.num_sync_devices 165 | grads = grads[ 166 | group_idx 167 | * ctx.num_sync_devices : (group_idx + 1) 168 | * ctx.num_sync_devices 169 | ] 170 | grads = torch.sum(grads, dim=0) 171 | return grads, None, None 172 | 173 | 174 | class NaiveSyncBatchNorm3d(nn.BatchNorm3d): 175 | def __init__(self, num_sync_devices, **args): 176 | """ 177 | Naive version of Synchronized 3D BatchNorm. 
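# ---------------------------------------------------------------------------
# [Illustrative aside] The reshaping trick SubBatchNorm3d uses above, in
# isolation: folding num_splits groups of the batch into the channel dimension
# lets a single BatchNorm3d compute per-split statistics. Sizes are toy values.
import torch
import torch.nn as nn

n, c, t, h, w, num_splits = 4, 2, 3, 5, 5, 2
x = torch.randn(n, c, t, h, w)
split_bn = nn.BatchNorm3d(c * num_splits)

y = x.view(n // num_splits, c * num_splits, t, h, w)   # (2, 4, 3, 5, 5)
y = split_bn(y)                                        # stats per (split, channel)
y = y.view(n, c, t, h, w)                              # restore original shape
print(y.shape)  # -> torch.Size([4, 2, 3, 5, 5])
# ---------------------------------------------------------------------------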
178 | Args: 179 | num_sync_devices (int): number of device to sync. 180 | args (list): other arguments. 181 | """ 182 | self.num_sync_devices = num_sync_devices 183 | if self.num_sync_devices > 0: 184 | assert du.get_local_size() % self.num_sync_devices == 0, ( 185 | du.get_local_size(), 186 | self.num_sync_devices, 187 | ) 188 | self.num_groups = du.get_local_size() // self.num_sync_devices 189 | else: 190 | self.num_sync_devices = du.get_local_size() 191 | self.num_groups = 1 192 | super(NaiveSyncBatchNorm3d, self).__init__(**args) 193 | 194 | def forward(self, input): 195 | if du.get_local_size() == 1 or not self.training: 196 | return super().forward(input) 197 | 198 | assert input.shape[0] > 0, "SyncBatchNorm does not support empty inputs" 199 | C = input.shape[1] 200 | mean = torch.mean(input, dim=[0, 2, 3, 4]) 201 | meansqr = torch.mean(input * input, dim=[0, 2, 3, 4]) 202 | 203 | vec = torch.cat([mean, meansqr], dim=0) 204 | vec = GroupGather.apply(vec, self.num_sync_devices, self.num_groups) * ( 205 | 1.0 / self.num_sync_devices 206 | ) 207 | 208 | mean, meansqr = torch.split(vec, C) 209 | var = meansqr - mean * mean 210 | self.running_mean += self.momentum * (mean.detach() - self.running_mean) 211 | self.running_var += self.momentum * (var.detach() - self.running_var) 212 | 213 | invstd = torch.rsqrt(var + self.eps) 214 | scale = self.weight * invstd 215 | bias = self.bias - mean * scale 216 | scale = scale.reshape(1, -1, 1, 1, 1) 217 | bias = bias.reshape(1, -1, 1, 1, 1) 218 | return input * scale + bias 219 | -------------------------------------------------------------------------------- /tools/visualization.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | import numpy as np 5 | import torch 6 | 7 | import slowfast.datasets.utils as data_utils 8 | import slowfast.utils.checkpoint as cu 9 | import slowfast.utils.distributed as du 10 | import slowfast.utils.logging as logging 11 | import slowfast.utils.misc as misc 12 | import slowfast.visualization.tensorboard_vis as tb 13 | from slowfast.datasets import loader 14 | from slowfast.models import build_model 15 | from slowfast.visualization.utils import ( 16 | GetWeightAndActivation, 17 | process_layer_index_data, 18 | ) 19 | from slowfast.visualization.video_visualizer import VideoVisualizer 20 | 21 | logger = logging.get_logger(__name__) 22 | 23 | 24 | def run_visualization(vis_loader, model, cfg, writer=None): 25 | """ 26 | Run model visualization (weights, activations and model inputs) and visualize 27 | them on Tensorboard. 28 | Args: 29 | vis_loader (loader): video visualization loader. 30 | model (model): the video model to visualize. 31 | cfg (CfgNode): configs. Details can be found in 32 | slowfast/config/defaults.py 33 | writer (TensorboardWriter, optional): TensorboardWriter object 34 | to writer Tensorboard log. 35 | """ 36 | n_devices = cfg.NUM_GPUS * cfg.NUM_SHARDS 37 | prefix = "module/" if n_devices > 1 else "" 38 | # Get a list of selected layer names and indexing. 39 | layer_ls, indexing_dict = process_layer_index_data( 40 | cfg.TENSORBOARD.MODEL_VIS.LAYER_LIST, layer_name_prefix=prefix 41 | ) 42 | logger.info("Start Model Visualization.") 43 | # Register hooks for activations. 
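# ---------------------------------------------------------------------------
# [Illustrative aside] NaiveSyncBatchNorm3d above only exchanges (mean,
# mean-of-squares) across GPUs; the variance is then recovered via
# E[x^2] - E[x]^2. A single-tensor check of that identity:
import torch

x = torch.randn(8, 4, 2, 6, 6)                       # (N, C, T, H, W)
mean = torch.mean(x, dim=[0, 2, 3, 4])
meansqr = torch.mean(x * x, dim=[0, 2, 3, 4])
var = meansqr - mean * mean

reference = x.permute(1, 0, 2, 3, 4).reshape(4, -1).var(dim=1, unbiased=False)
print(torch.allclose(var, reference, atol=1e-5))     # -> True
# ---------------------------------------------------------------------------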
44 | model_vis = GetWeightAndActivation(model, layer_ls) 45 | 46 | if writer is not None and cfg.TENSORBOARD.MODEL_VIS.MODEL_WEIGHTS: 47 | layer_weights = model_vis.get_weights() 48 | writer.plot_weights_and_activations( 49 | layer_weights, tag="Layer Weights/", heat_map=False 50 | ) 51 | 52 | video_vis = VideoVisualizer( 53 | cfg.MODEL.NUM_CLASSES, 54 | cfg.TENSORBOARD.CLASS_NAMES_PATH, 55 | cfg.TENSORBOARD.MODEL_VIS.TOPK_PREDS, 56 | cfg.TENSORBOARD.MODEL_VIS.COLORMAP, 57 | ) 58 | logger.info("Finish drawing weights.") 59 | global_idx = -1 60 | for inputs, _, _, meta in vis_loader: 61 | if cfg.NUM_GPUS: 62 | # Transfer the data to the current GPU device. 63 | if isinstance(inputs, (list,)): 64 | for i in range(len(inputs)): 65 | inputs[i] = inputs[i].cuda(non_blocking=True) 66 | else: 67 | inputs = inputs.cuda(non_blocking=True) 68 | for key, val in meta.items(): 69 | if isinstance(val, (list,)): 70 | for i in range(len(val)): 71 | val[i] = val[i].cuda(non_blocking=True) 72 | else: 73 | meta[key] = val.cuda(non_blocking=True) 74 | 75 | if cfg.DETECTION.ENABLE: 76 | activations, preds = model_vis.get_activations( 77 | inputs, meta["boxes"] 78 | ) 79 | else: 80 | activations, preds = model_vis.get_activations(inputs) 81 | if cfg.NUM_GPUS: 82 | inputs = du.all_gather_unaligned(inputs) 83 | activations = du.all_gather_unaligned(activations) 84 | preds = du.all_gather_unaligned(preds) 85 | if isinstance(inputs[0], list): 86 | for i in range(len(inputs)): 87 | for j in range(len(inputs[0])): 88 | inputs[i][j] = inputs[i][j].cpu() 89 | else: 90 | inputs = [inp.cpu() for inp in inputs] 91 | preds = [pred.cpu() for pred in preds] 92 | else: 93 | inputs, activations, preds = [inputs], [activations], [preds] 94 | 95 | boxes = [None] * max(n_devices, 1) 96 | if cfg.DETECTION.ENABLE and cfg.NUM_GPUS: 97 | boxes = du.all_gather_unaligned(meta["boxes"]) 98 | boxes = [box.cpu() for box in boxes] 99 | 100 | if writer is not None: 101 | total_vids = 0 102 | for i in range(max(n_devices, 1)): 103 | cur_input = inputs[i] 104 | cur_activations = activations[i] 105 | cur_batch_size = cur_input[0].shape[0] 106 | cur_preds = preds[i] 107 | cur_boxes = boxes[i] 108 | for cur_batch_idx in range(cur_batch_size): 109 | global_idx += 1 110 | total_vids += 1 111 | if cfg.TENSORBOARD.MODEL_VIS.INPUT_VIDEO: 112 | for path_idx, input_pathway in enumerate(cur_input): 113 | if cfg.TEST.DATASET == "ava" and cfg.AVA.BGR: 114 | video = input_pathway[ 115 | cur_batch_idx, [2, 1, 0], ... 116 | ] 117 | else: 118 | video = input_pathway[cur_batch_idx] 119 | # Permute to (T, H, W, C) from (C, T, H, W). 
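# ---------------------------------------------------------------------------
# [Illustrative aside] The "[2, 1, 0]" indexing used a few lines above flips
# BGR-ordered channels to RGB in one advanced-indexing step. A toy check on a
# tiny (N, C, T, H, W) clip:
import torch

clip = torch.arange(6, dtype=torch.float32).reshape(2, 3, 1, 1, 1)
rgb = clip[0, [2, 1, 0], ...]                 # batch item 0, channels reversed
print(clip[0, :, 0, 0, 0], rgb[:, 0, 0, 0])
# -> tensor([0., 1., 2.]) tensor([2., 1., 0.])
# ---------------------------------------------------------------------------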
120 | video = video.permute(1, 2, 3, 0) 121 | video = data_utils.revert_tensor_normalize( 122 | video, cfg.DATA.MEAN, cfg.DATA.STD 123 | ) 124 | bboxes = ( 125 | None if cur_boxes is None else cur_boxes[:, 1:] 126 | ) 127 | video = video_vis.draw_clip( 128 | video, cur_preds, bboxes=bboxes 129 | ) 130 | video = ( 131 | torch.Tensor(video) 132 | .permute(0, 3, 1, 2) 133 | .unsqueeze(0) 134 | ) 135 | writer.add_video( 136 | video, 137 | tag="Input {}/Input from pathway {}".format( 138 | global_idx, path_idx + 1 139 | ), 140 | ) 141 | if cfg.TENSORBOARD.MODEL_VIS.ACTIVATIONS: 142 | writer.plot_weights_and_activations( 143 | cur_activations, 144 | tag="Input {}/Activations: ".format(global_idx), 145 | batch_idx=cur_batch_idx, 146 | indexing_dict=indexing_dict, 147 | ) 148 | 149 | logger.info("Visualized {} videos...".format(total_vids)) 150 | 151 | 152 | def visualize(cfg): 153 | """ 154 | Perform layer weights and activations visualization on the model. 155 | Args: 156 | cfg (CfgNode): configs. Details can be found in 157 | slowfast/config/defaults.py 158 | """ 159 | if cfg.TENSORBOARD.ENABLE and cfg.TENSORBOARD.MODEL_VIS.ENABLE: 160 | # Set up environment. 161 | du.init_distributed_training(cfg) 162 | # Set random seed from configs. 163 | np.random.seed(cfg.RNG_SEED) 164 | torch.manual_seed(cfg.RNG_SEED) 165 | 166 | # Setup logging format. 167 | logging.setup_logging(cfg.OUTPUT_DIR) 168 | 169 | # Print config. 170 | logger.info("Model Visualization with config:") 171 | logger.info(cfg) 172 | 173 | # Build the video model and print model statistics. 174 | model = build_model(cfg) 175 | if du.is_master_proc() and cfg.LOG_MODEL_INFO: 176 | misc.log_model_info(model, cfg, use_train_input=False) 177 | 178 | cu.load_test_checkpoint(cfg, model) 179 | 180 | # Create video testing loaders. 181 | vis_loader = loader.construct_loader(cfg, "test") 182 | logger.info( 183 | "Visualize model for {} data points".format(len(vis_loader)) 184 | ) 185 | 186 | if cfg.DETECTION.ENABLE: 187 | assert cfg.NUM_GPUS == cfg.TEST.BATCH_SIZE or cfg.NUM_GPUS == 0 188 | 189 | # Set up writer for logging to Tensorboard format. 190 | if du.is_master_proc(cfg.NUM_GPUS * cfg.NUM_SHARDS): 191 | writer = tb.TensorboardWriter(cfg) 192 | else: 193 | writer = None 194 | 195 | # Run visualization on the model 196 | run_visualization(vis_loader, model, cfg, writer) 197 | 198 | if writer is not None: 199 | writer.close() 200 | -------------------------------------------------------------------------------- /tools/test_net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """Multi-view test a video classification model.""" 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | import cv2 10 | import os 11 | import time 12 | 13 | import slowfast.utils.checkpoint as cu 14 | import slowfast.utils.distributed as du 15 | import slowfast.utils.logging as logging 16 | import slowfast.utils.misc as misc 17 | import slowfast.visualization.tensorboard_vis as tb 18 | from slowfast.datasets import loader 19 | from slowfast.models import build_model 20 | from slowfast.utils.meters import AVAMeter, TestMeter, HieveMeter 21 | 22 | logger = logging.get_logger(__name__) 23 | 24 | 25 | def prepare_data(inputs, labels, video_idx, meta, FBs, BTs): 26 | # Transfer the data to the current GPU device. 
27 | if isinstance(inputs, (list,)): 28 | for i in range(len(inputs)): 29 | inputs[i] = inputs[i].cuda(non_blocking=True) 30 | else: 31 | inputs = inputs.cuda(non_blocking=True) 32 | 33 | # Transfer the data to the current GPU device. 34 | labels = labels.cuda() 35 | video_idx = video_idx.cuda() 36 | for key, val in meta.items(): 37 | if isinstance(val, (list,)): 38 | for i in range(len(val)): 39 | val[i] = val[i].cuda(non_blocking=True) 40 | else: 41 | meta[key] = val.cuda(non_blocking=True) 42 | 43 | for i in range(len(FBs)): 44 | FBs[i] = FBs[i].cuda(non_blocking=True) if FBs[i] is not None else None 45 | 46 | return inputs, labels, video_idx, meta, FBs, BTs 47 | 48 | 49 | @torch.no_grad() 50 | def perform_test(test_loader, model, test_meter, cfg, writer=None): 51 | """ 52 | For classification: 53 | Perform mutli-view testing that uniformly samples N clips from a video along 54 | its temporal axis. For each clip, it takes 3 crops to cover the spatial 55 | dimension, followed by averaging the softmax scores across all Nx3 views to 56 | form a video-level prediction. All video predictions are compared to 57 | ground-truth labels and the final testing performance is logged. 58 | For detection: 59 | Perform fully-convolutional testing on the full frames without crop. 60 | Args: 61 | test_loader (loader): video testing loader. 62 | model (model): the pretrained video model to test. 63 | test_meter (TestMeter): testing meters to log and ensemble the testing 64 | results. 65 | cfg (CfgNode): configs. Details can be found in 66 | slowfast/config/defaults.py 67 | writer (TensorboardWriter object, optional): TensorboardWriter object 68 | to writer Tensorboard log. 69 | """ 70 | # Enable eval mode. 71 | model.eval() 72 | test_meter.iter_tic() 73 | 74 | logger.info('online test for {} iters'.format(len(test_loader))) 75 | for cur_iter, (inputs, labels, video_idx, meta, FBs, BTs) in enumerate(test_loader): 76 | if cfg.NUM_GPUS: 77 | 78 | inputs, labels, video_idx, meta, FBs, BTs = \ 79 | prepare_data(inputs, labels, video_idx, meta, FBs, BTs) 80 | 81 | if cfg.DETECTION.ENABLE: 82 | # Compute the predictions. 83 | 84 | if cfg.LSTC.ENABLE: 85 | preds, sp, lp = model(inputs, bboxes=meta["boxes"], extract=False, FBs=FBs, BTs=BTs) 86 | else: 87 | preds = model(inputs, bboxes=meta["boxes"], extract=False, FBs=FBs, BTs=BTs) 88 | 89 | ori_boxes = meta["ori_boxes"].cpu() 90 | metadata = meta["metadata"].cpu() 91 | 92 | # imgs = inputs[-1][:, :, 16].cpu() 93 | # slow = inputs[0].cpu() 94 | # attn = attn.cpu() 95 | # visualize_tensor(slow, imgs, meta["boxes"].cpu(), metadata, attn, cfg) 96 | # visualize_results(imgs, meta["boxes"].cpu(), metadata, sp, lp, cfg, labelmap) 97 | 98 | if cfg.NUM_GPUS > 1: 99 | ori_boxes = torch.cat(du.all_gather_unaligned(ori_boxes), dim=0) 100 | metadata = torch.cat(du.all_gather_unaligned(metadata), dim=0) 101 | preds = torch.cat(du.all_gather_unaligned(preds), dim=0) 102 | 103 | preds = preds.detach().cpu() if cfg.NUM_GPUS else preds.detach() 104 | ori_boxes = ( 105 | ori_boxes.detach().cpu() if cfg.NUM_GPUS else ori_boxes.detach() 106 | ) 107 | metadata = ( 108 | metadata.detach().cpu() if cfg.NUM_GPUS else metadata.detach() 109 | ) 110 | 111 | test_meter.iter_toc() 112 | # Update and log stats. 113 | test_meter.update_stats(preds, ori_boxes, metadata) 114 | test_meter.log_iter_stats(cur_epoch=0, cur_iter=cur_iter) 115 | else: 116 | # Perform the forward pass. 117 | preds = model(inputs) 118 | 119 | # Gather all the predictions across all the devices to perform ensemble. 
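# ---------------------------------------------------------------------------
# [Illustrative aside] What the torch.cat(du.all_gather_unaligned(...)) calls
# above produce, simulated locally: each rank may contribute a different number
# of boxes, and the gathered pieces are concatenated along dim 0 into one flat
# batch. The two fake ranks and 80-way outputs are toy values.
import torch

per_rank_preds = [torch.randn(3, 80), torch.randn(5, 80)]
preds = torch.cat(per_rank_preds, dim=0)
print(preds.shape)  # -> torch.Size([8, 80])
# ---------------------------------------------------------------------------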
120 | if cfg.NUM_GPUS > 1: 121 | preds, labels, video_idx = du.all_gather( 122 | [preds, labels, video_idx] 123 | ) 124 | if cfg.NUM_GPUS: 125 | preds = preds.cpu() 126 | labels = labels.cpu() 127 | video_idx = video_idx.cpu() 128 | test_meter.iter_toc() 129 | # Update and log stats. 130 | test_meter.update_stats( 131 | preds.detach(), labels.detach(), video_idx.detach() 132 | ) 133 | test_meter.log_iter_stats(cur_iter) 134 | test_meter.iter_tic() 135 | # Log epoch stats and print the final testing results. 136 | if writer is not None and not cfg.DETECTION.ENABLE: 137 | all_preds = [pred.clone().detach() for pred in test_meter.video_preds] 138 | all_labels = [ 139 | label.clone().detach() for label in test_meter.video_labels 140 | ] 141 | if cfg.NUM_GPUS: 142 | all_preds = [pred.cpu() for pred in all_preds] 143 | all_labels = [label.cpu() for label in all_labels] 144 | writer.plot_eval(preds=all_preds, labels=all_labels) 145 | 146 | if du.is_master_proc(du.get_local_size()): 147 | test_meter.finalize_metrics() 148 | test_meter.reset() 149 | 150 | 151 | def test(cfg): 152 | """ 153 | Perform multi-view testing on the pretrained video model. 154 | Args: 155 | cfg (CfgNode): configs. Details can be found in 156 | slowfast/config/defaults.py 157 | """ 158 | # Set up environment. 159 | du.init_distributed_training(cfg) 160 | # Set random seed from configs. 161 | np.random.seed(cfg.RNG_SEED) 162 | torch.manual_seed(cfg.RNG_SEED) 163 | 164 | # Setup logging format. 165 | logging.setup_logging(cfg.OUTPUT_DIR) 166 | 167 | # Print config. 168 | logger.info("Test with config:") 169 | logger.info(cfg) 170 | 171 | # Build the video model and print model statistics. 172 | model = build_model(cfg) 173 | if du.is_master_proc() and cfg.LOG_MODEL_INFO: 174 | misc.log_model_info(model, cfg, use_train_input=False) 175 | 176 | cu.load_test_checkpoint(cfg, model) 177 | 178 | # Create video testing loaders. 179 | test_loader = loader.construct_loader(cfg, "test") 180 | du.synchronize() 181 | logger.info("Testing model for {} iterations".format(len(test_loader))) 182 | 183 | if cfg.DETECTION.ENABLE: 184 | assert cfg.NUM_GPUS == cfg.TEST.BATCH_SIZE or cfg.NUM_GPUS == 0 185 | if cfg.TEST.DATASET == 'ava': 186 | test_meter = AVAMeter(len(test_loader), cfg, mode="test") 187 | else: 188 | test_meter = HieveMeter(len(test_loader), cfg, mode="test") 189 | else: 190 | assert ( 191 | len(test_loader.dataset) 192 | % (cfg.TEST.NUM_ENSEMBLE_VIEWS * cfg.TEST.NUM_SPATIAL_CROPS) 193 | == 0 194 | ) 195 | # Create meters for multi-view testing. 196 | test_meter = TestMeter( 197 | len(test_loader.dataset) 198 | // (cfg.TEST.NUM_ENSEMBLE_VIEWS * cfg.TEST.NUM_SPATIAL_CROPS), 199 | cfg.TEST.NUM_ENSEMBLE_VIEWS * cfg.TEST.NUM_SPATIAL_CROPS, 200 | cfg.MODEL.NUM_CLASSES, 201 | len(test_loader), 202 | cfg.DATA.MULTI_LABEL, 203 | cfg.DATA.ENSEMBLE_METHOD, 204 | ) 205 | 206 | # Set up writer for logging to Tensorboard format. 207 | if cfg.TENSORBOARD.ENABLE and du.is_master_proc( 208 | cfg.NUM_GPUS * cfg.NUM_SHARDS 209 | ): 210 | writer = tb.TensorboardWriter(cfg) 211 | else: 212 | writer = None 213 | 214 | # # Perform multi-view test on the entire dataset. 
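# ---------------------------------------------------------------------------
# [Illustrative aside] The multi-view arithmetic behind the TestMeter
# construction above: every video contributes NUM_ENSEMBLE_VIEWS x
# NUM_SPATIAL_CROPS clips, so the dataset length must divide evenly into that
# many views per video. Numbers below are toy values.
num_ensemble_views = 10
num_spatial_crops = 3
dataset_len = 1200

clips_per_video = num_ensemble_views * num_spatial_crops   # 30
assert dataset_len % clips_per_video == 0
num_videos = dataset_len // clips_per_video
print(num_videos)  # -> 40
# ---------------------------------------------------------------------------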
215 | tic = time.time() 216 | perform_test(test_loader, model, test_meter, cfg, writer) 217 | toc = time.time() 218 | logger.info(f"total inference is {(toc - tic):.3f}s") 219 | test_loader.dataset.close() 220 | if writer is not None: 221 | writer.close() 222 | -------------------------------------------------------------------------------- /slowfast/utils/multigrid.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """Helper functions for multigrid training.""" 5 | 6 | import numpy as np 7 | 8 | import slowfast.utils.logging as logging 9 | 10 | logger = logging.get_logger(__name__) 11 | 12 | 13 | class MultigridSchedule(object): 14 | """ 15 | This class defines multigrid training schedule and update cfg accordingly. 16 | """ 17 | 18 | def init_multigrid(self, cfg): 19 | """ 20 | Update cfg based on multigrid settings. 21 | Args: 22 | cfg (configs): configs that contains training and multigrid specific 23 | hyperparameters. Details can be seen in 24 | slowfast/config/defaults.py. 25 | Returns: 26 | cfg (configs): the updated cfg. 27 | """ 28 | self.schedule = None 29 | # We may modify cfg.TRAIN.BATCH_SIZE, cfg.DATA.NUM_FRAMES, and 30 | # cfg.DATA.TRAIN_CROP_SIZE during training, so we store their original 31 | # value in cfg and use them as global variables. 32 | cfg.MULTIGRID.DEFAULT_B = cfg.TRAIN.BATCH_SIZE 33 | cfg.MULTIGRID.DEFAULT_T = cfg.DATA.NUM_FRAMES 34 | cfg.MULTIGRID.DEFAULT_S = cfg.DATA.TRAIN_CROP_SIZE 35 | 36 | if cfg.MULTIGRID.LONG_CYCLE: 37 | self.schedule = self.get_long_cycle_schedule(cfg) 38 | cfg.SOLVER.STEPS = [0] + [s[-1] for s in self.schedule] 39 | # Fine-tuning phase. 40 | cfg.SOLVER.STEPS[-1] = ( 41 | cfg.SOLVER.STEPS[-2] + cfg.SOLVER.STEPS[-1] 42 | ) // 2 43 | cfg.SOLVER.LRS = [ 44 | cfg.SOLVER.GAMMA ** s[0] * s[1][0] for s in self.schedule 45 | ] 46 | # Fine-tuning phase. 47 | cfg.SOLVER.LRS = cfg.SOLVER.LRS[:-1] + [ 48 | cfg.SOLVER.LRS[-2], 49 | cfg.SOLVER.LRS[-1], 50 | ] 51 | 52 | cfg.SOLVER.MAX_EPOCH = self.schedule[-1][-1] 53 | 54 | elif cfg.MULTIGRID.SHORT_CYCLE: 55 | cfg.SOLVER.STEPS = [ 56 | int(s * cfg.MULTIGRID.EPOCH_FACTOR) for s in cfg.SOLVER.STEPS 57 | ] 58 | cfg.SOLVER.MAX_EPOCH = int( 59 | cfg.SOLVER.MAX_EPOCH * cfg.MULTIGRID.EPOCH_FACTOR 60 | ) 61 | return cfg 62 | 63 | def update_long_cycle(self, cfg, cur_epoch): 64 | """ 65 | Before every epoch, check if long cycle shape should change. If it 66 | should, update cfg accordingly. 67 | Args: 68 | cfg (configs): configs that contains training and multigrid specific 69 | hyperparameters. Details can be seen in 70 | slowfast/config/defaults.py. 71 | cur_epoch (int): current epoch index. 72 | Returns: 73 | cfg (configs): the updated cfg. 74 | changed (bool): do we change long cycle shape at this epoch? 
75 | """ 76 | base_b, base_t, base_s = get_current_long_cycle_shape( 77 | self.schedule, cur_epoch 78 | ) 79 | if base_s != cfg.DATA.TRAIN_CROP_SIZE or base_t != cfg.DATA.NUM_FRAMES: 80 | 81 | cfg.DATA.NUM_FRAMES = base_t 82 | cfg.DATA.TRAIN_CROP_SIZE = base_s 83 | cfg.TRAIN.BATCH_SIZE = base_b * cfg.MULTIGRID.DEFAULT_B 84 | 85 | bs_factor = ( 86 | float(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS) 87 | / cfg.MULTIGRID.BN_BASE_SIZE 88 | ) 89 | 90 | if bs_factor < 1: 91 | cfg.BN.NORM_TYPE = "sync_batchnorm" 92 | cfg.BN.NUM_SYNC_DEVICES = int(1.0 / bs_factor) 93 | elif bs_factor > 1: 94 | cfg.BN.NORM_TYPE = "sub_batchnorm" 95 | cfg.BN.NUM_SPLITS = int(bs_factor) 96 | else: 97 | cfg.BN.NORM_TYPE = "batchnorm" 98 | 99 | cfg.MULTIGRID.LONG_CYCLE_SAMPLING_RATE = cfg.DATA.SAMPLING_RATE * ( 100 | cfg.MULTIGRID.DEFAULT_T // cfg.DATA.NUM_FRAMES 101 | ) 102 | logger.info("Long cycle updates:") 103 | logger.info("\tBN.NORM_TYPE: {}".format(cfg.BN.NORM_TYPE)) 104 | if cfg.BN.NORM_TYPE == "sync_batchnorm": 105 | logger.info( 106 | "\tBN.NUM_SYNC_DEVICES: {}".format(cfg.BN.NUM_SYNC_DEVICES) 107 | ) 108 | elif cfg.BN.NORM_TYPE == "sub_batchnorm": 109 | logger.info("\tBN.NUM_SPLITS: {}".format(cfg.BN.NUM_SPLITS)) 110 | logger.info("\tTRAIN.BATCH_SIZE: {}".format(cfg.TRAIN.BATCH_SIZE)) 111 | logger.info( 112 | "\tDATA.NUM_FRAMES x LONG_CYCLE_SAMPLING_RATE: {}x{}".format( 113 | cfg.DATA.NUM_FRAMES, cfg.MULTIGRID.LONG_CYCLE_SAMPLING_RATE 114 | ) 115 | ) 116 | logger.info( 117 | "\tDATA.TRAIN_CROP_SIZE: {}".format(cfg.DATA.TRAIN_CROP_SIZE) 118 | ) 119 | return cfg, True 120 | else: 121 | return cfg, False 122 | 123 | def get_long_cycle_schedule(self, cfg): 124 | """ 125 | Based on multigrid hyperparameters, define the schedule of a long cycle. 126 | Args: 127 | cfg (configs): configs that contains training and multigrid specific 128 | hyperparameters. Details can be seen in 129 | slowfast/config/defaults.py. 130 | Returns: 131 | schedule (list): Specifies a list long cycle base shapes and their 132 | corresponding training epochs. 133 | """ 134 | 135 | steps = cfg.SOLVER.STEPS 136 | 137 | default_size = float( 138 | cfg.DATA.NUM_FRAMES * cfg.DATA.TRAIN_CROP_SIZE ** 2 139 | ) 140 | default_iters = steps[-1] 141 | 142 | # Get shapes and average batch size for each long cycle shape. 143 | avg_bs = [] 144 | all_shapes = [] 145 | for t_factor, s_factor in cfg.MULTIGRID.LONG_CYCLE_FACTORS: 146 | base_t = int(round(cfg.DATA.NUM_FRAMES * t_factor)) 147 | base_s = int(round(cfg.DATA.TRAIN_CROP_SIZE * s_factor)) 148 | if cfg.MULTIGRID.SHORT_CYCLE: 149 | shapes = [ 150 | [ 151 | base_t, 152 | cfg.MULTIGRID.DEFAULT_S 153 | * cfg.MULTIGRID.SHORT_CYCLE_FACTORS[0], 154 | ], 155 | [ 156 | base_t, 157 | cfg.MULTIGRID.DEFAULT_S 158 | * cfg.MULTIGRID.SHORT_CYCLE_FACTORS[1], 159 | ], 160 | [base_t, base_s], 161 | ] 162 | else: 163 | shapes = [[base_t, base_s]] 164 | 165 | # (T, S) -> (B, T, S) 166 | shapes = [ 167 | [int(round(default_size / (s[0] * s[1] * s[1]))), s[0], s[1]] 168 | for s in shapes 169 | ] 170 | avg_bs.append(np.mean([s[0] for s in shapes])) 171 | all_shapes.append(shapes) 172 | 173 | # Get schedule regardless of cfg.MULTIGRID.EPOCH_FACTOR. 
174 | total_iters = 0 175 | schedule = [] 176 | for step_index in range(len(steps) - 1): 177 | step_epochs = steps[step_index + 1] - steps[step_index] 178 | 179 | for long_cycle_index, shapes in enumerate(all_shapes): 180 | cur_epochs = ( 181 | step_epochs * avg_bs[long_cycle_index] / sum(avg_bs) 182 | ) 183 | 184 | cur_iters = cur_epochs / avg_bs[long_cycle_index] 185 | total_iters += cur_iters 186 | schedule.append((step_index, shapes[-1], cur_epochs)) 187 | 188 | iter_saving = default_iters / total_iters 189 | 190 | final_step_epochs = cfg.SOLVER.MAX_EPOCH - steps[-1] 191 | 192 | # We define the fine-tuning phase to have the same amount of iteration 193 | # saving as the rest of the training. 194 | ft_epochs = final_step_epochs / iter_saving * avg_bs[-1] 195 | 196 | schedule.append((step_index + 1, all_shapes[-1][2], ft_epochs)) 197 | 198 | # Obtrain final schedule given desired cfg.MULTIGRID.EPOCH_FACTOR. 199 | x = ( 200 | cfg.SOLVER.MAX_EPOCH 201 | * cfg.MULTIGRID.EPOCH_FACTOR 202 | / sum(s[-1] for s in schedule) 203 | ) 204 | 205 | final_schedule = [] 206 | total_epochs = 0 207 | for s in schedule: 208 | epochs = s[2] * x 209 | total_epochs += epochs 210 | final_schedule.append((s[0], s[1], int(round(total_epochs)))) 211 | print_schedule(final_schedule) 212 | return final_schedule 213 | 214 | 215 | def print_schedule(schedule): 216 | """ 217 | Log schedule. 218 | """ 219 | logger.info("Long cycle index\tBase shape\tEpochs") 220 | for s in schedule: 221 | logger.info("{}\t{}\t{}".format(s[0], s[1], s[2])) 222 | 223 | 224 | def get_current_long_cycle_shape(schedule, epoch): 225 | """ 226 | Given a schedule and epoch index, return the long cycle base shape. 227 | Args: 228 | schedule (configs): configs that contains training and multigrid specific 229 | hyperparameters. Details can be seen in 230 | slowfast/config/defaults.py. 231 | cur_epoch (int): current epoch index. 232 | Returns: 233 | shapes (list): A list describing the base shape in a long cycle: 234 | [batch size relative to default, 235 | number of frames, spatial dimension]. 236 | """ 237 | for s in schedule: 238 | if epoch < s[-1]: 239 | return s[1] 240 | return schedule[-1][1] 241 | -------------------------------------------------------------------------------- /slowfast/datasets/loader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
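# ---------------------------------------------------------------------------
# [Illustrative aside] A tiny usage check of get_current_long_cycle_shape from
# the multigrid helpers above: the schedule stores cumulative end epochs, and
# the first stage whose end epoch lies ahead of the current epoch wins. The
# toy entries follow the (step index, [rel. batch, frames, crop], end epoch)
# layout built by get_long_cycle_schedule; values are made up.
from slowfast.utils.multigrid import get_current_long_cycle_shape

toy_schedule = [
    (0, [8, 8, 158], 30),
    (1, [1, 32, 224], 60),
]
print(get_current_long_cycle_shape(toy_schedule, 10))   # -> [8, 8, 158]
print(get_current_long_cycle_shape(toy_schedule, 45))   # -> [1, 32, 224]
print(get_current_long_cycle_shape(toy_schedule, 80))   # -> [1, 32, 224]
# ---------------------------------------------------------------------------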
3 | 4 | """Data loader.""" 5 | 6 | import itertools 7 | import functools 8 | import numpy as np 9 | import torch 10 | import random 11 | import math 12 | import functools 13 | from torch.utils.data._utils.collate import default_collate 14 | from torch.utils.data.distributed import DistributedSampler 15 | from torch.utils.data.sampler import RandomSampler 16 | 17 | from slowfast.datasets.multigrid_helper import ShortCycleBatchSampler 18 | from slowfast.utils.logging import get_logger 19 | 20 | from .build import build_dataset 21 | from collections import defaultdict 22 | 23 | logger = get_logger(__name__) 24 | 25 | class SequentialDistributedSampler(DistributedSampler): 26 | 27 | """ 28 | a video specific sampler to ensure that each process focus on 29 | processing a distinct video clip sequence 30 | """ 31 | 32 | def __init__(self, cfg, train, dataset): 33 | super(SequentialDistributedSampler, self).__init__(dataset) 34 | # make video samplers divisible 35 | 36 | if train: 37 | batch_size = cfg.TRAIN.BATCH_SIZE // self.num_replicas 38 | else: 39 | batch_size = cfg.TEST.BATCH_SIZE // self.num_replicas 40 | self.batch_size = batch_size 41 | 42 | # reorganize the clip index, so that each shot can be 43 | # indexed via a specific clip id 44 | self.shot_to_clip = dict() 45 | for vid in range(dataset.num_videos()): 46 | indexes = dataset.get_idx_sequence_from_video(vid) 47 | if len(indexes) % self.batch_size > 0: 48 | left = self.batch_size - len(indexes) % self.batch_size 49 | indexes += indexes[-left:] 50 | 51 | self.shot_to_clip.update( 52 | { 53 | indexes[idx] : indexes[idx:idx+self.batch_size] \ 54 | for idx in range(0, len(indexes), self.batch_size) 55 | } 56 | ) 57 | 58 | self.shot_keys = list(self.shot_to_clip.keys()) 59 | self.shot_keys.sort() 60 | self.num_shot = len(self.shot_to_clip) 61 | self.total_shot = int(math.ceil(self.num_shot * 1.0 / self.num_replicas)) \ 62 | * self.num_replicas 63 | self.num_clips = self.total_shot * self.batch_size 64 | self.num_samples = self.num_clips // self.num_replicas 65 | 66 | self.seed = 0 67 | logger.info('{} shots, {} clips'.format(self.num_shot, self.num_clips)) 68 | 69 | def __iter__(self): 70 | 71 | if self.shuffle: 72 | # deterministically shuffle based on epoch and seed 73 | g = torch.Generator() 74 | g.manual_seed(self.seed + self.epoch) 75 | random_indices = torch.randperm(len(self.shot_keys), generator=g).tolist() 76 | video_indices = [self.shot_keys[idx] for idx in random_indices] 77 | video_indices += video_indices[:(self.total_shot - self.num_shot)] 78 | shot_indices = video_indices[self.rank:self.total_shot:self.num_replicas] 79 | 80 | indices = functools.reduce( 81 | lambda x, y: x+y, 82 | [self.shot_to_clip[s] for s in shot_indices] 83 | ) 84 | else: 85 | indices = list(range(len(self.dataset))) 86 | 87 | # add extra samples to make it evenly divisible 88 | indices += indices[:(self.total_size - len(indices))] 89 | assert len(indices) == self.total_size 90 | 91 | # subsample 92 | indices = indices[self.rank:self.total_size:self.num_replicas] 93 | 94 | logger.info('num samples {} {}'.format(len(indices), self.num_samples)) 95 | assert len(indices) == self.num_samples, \ 96 | "{} vs. {}".format(len(indices), self.num_samples) 97 | 98 | return iter(indices) 99 | 100 | def detection_collate(batch, cfg): 101 | """ 102 | Collate function for detection task. Concatanate bboxes, labels and 103 | metadata from different samples in the first dimension instead of 104 | stacking them to have a batch-size dimension. 
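# ---------------------------------------------------------------------------
# [Illustrative aside] The shot-to-clip chunking performed in
# SequentialDistributedSampler above, run on a toy clip-index sequence: the
# tail is padded with repeated clips so each shot holds exactly batch_size
# consecutive clips from the same video.
batch_size = 4
indexes = [20, 21, 22, 23, 24, 25]           # clip ids of one video, toy values

if len(indexes) % batch_size > 0:
    left = batch_size - len(indexes) % batch_size
    indexes += indexes[-left:]               # repeat trailing clips as padding

shot_to_clip = {
    indexes[idx]: indexes[idx:idx + batch_size]
    for idx in range(0, len(indexes), batch_size)
}
print(shot_to_clip)
# -> {20: [20, 21, 22, 23], 24: [24, 25, 24, 25]}
# ---------------------------------------------------------------------------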
105 | Args: 106 | batch (tuple or list): data batch to collate. 107 | Returns: 108 | (tuple): collated detection data batch. 109 | """ 110 | inputs, labels, video_idx, extra_data, feat_banks, bank_times = zip(*batch) 111 | inputs, video_idx = default_collate(inputs), default_collate(video_idx) 112 | 113 | collated_extra_data = { 114 | "raw_labels": torch.tensor(np.concatenate(labels, axis=0)).float() 115 | } 116 | 117 | if not cfg.AVA.GATHER_BANK: 118 | new_labels = [] 119 | for label in labels: 120 | if label.shape[0] > 0: 121 | new_labels.append(np.max(label, axis=0, keepdims=True)) 122 | else: 123 | new_labels.append(label) 124 | 125 | labels = new_labels 126 | 127 | labels = torch.tensor(np.concatenate(labels, axis=0)).float() 128 | 129 | feat_banks = list(feat_banks) 130 | bank_times = list(bank_times) 131 | 132 | for key in extra_data[0].keys(): 133 | data = [d[key] for d in extra_data] 134 | 135 | if key == "boxes": 136 | has_box = [d.shape[0] > 0 for d in data] 137 | collated_extra_data["has_box"] = torch.Tensor(has_box).bool() 138 | 139 | if key == "boxes" or key == "ori_boxes": 140 | # Append idx info to the bboxes before concatenating them. 141 | bboxes = [ 142 | np.concatenate( 143 | [np.full((data[i].shape[0], 1), float(i)), data[i]], axis=1 144 | ) 145 | for i in range(len(data)) 146 | ] 147 | bboxes = np.concatenate(bboxes, axis=0) 148 | collated_extra_data[key] = torch.tensor(bboxes).float() 149 | elif key == "metadata": 150 | collated_extra_data[key] = torch.tensor( 151 | list(itertools.chain(*data)) 152 | ).view(-1, 2) 153 | else: 154 | collated_extra_data[key] = default_collate(data) 155 | 156 | return inputs, labels, video_idx, collated_extra_data, feat_banks, bank_times 157 | 158 | def construct_loader(cfg, split, is_precise_bn=False): 159 | """ 160 | Constructs the data loader for the given dataset. 161 | Args: 162 | cfg (CfgNode): configs. Details can be found in 163 | slowfast/config/defaults.py 164 | split (str): the split of the data loader. Options include `train`, 165 | `val`, and `test`. 
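# ---------------------------------------------------------------------------
# [Illustrative aside] The box collation trick used in detection_collate above,
# in isolation: each sample's boxes get a leading column holding the sample
# index, so boxes from the whole batch can be concatenated into a single
# (num_boxes, 5) array and still be attributed to their clip. Coordinates are
# toy values.
import numpy as np

data = [
    np.array([[10.0, 10.0, 50.0, 80.0]]),                          # sample 0: one box
    np.array([[5.0, 5.0, 40.0, 60.0], [20.0, 30.0, 90.0, 95.0]]),  # sample 1: two boxes
]
bboxes = [
    np.concatenate(
        [np.full((data[i].shape[0], 1), float(i)), data[i]], axis=1
    )
    for i in range(len(data))
]
bboxes = np.concatenate(bboxes, axis=0)
print(bboxes.shape)   # -> (3, 5); first column is the sample index
# ---------------------------------------------------------------------------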
166 | """ 167 | assert split in ["train", "val", "test"] 168 | if split in ["train"]: 169 | dataset_name = cfg.TRAIN.DATASET 170 | batch_size = int(cfg.TRAIN.BATCH_SIZE / max(1, cfg.NUM_GPUS)) 171 | shuffle = True 172 | drop_last = True 173 | elif split in ["val"]: 174 | dataset_name = cfg.TRAIN.DATASET 175 | batch_size = int(cfg.TEST.BATCH_SIZE / max(1, cfg.NUM_GPUS)) 176 | shuffle = False 177 | drop_last = False 178 | elif split in ["test"]: 179 | dataset_name = cfg.TEST.DATASET 180 | batch_size = int(cfg.TEST.BATCH_SIZE / max(1, cfg.NUM_GPUS)) 181 | shuffle = False 182 | drop_last = False 183 | 184 | # Construct the dataset 185 | dataset = build_dataset(dataset_name, cfg, split) 186 | 187 | if cfg.MULTIGRID.SHORT_CYCLE and split in ["train"] and not is_precise_bn: 188 | # Create a sampler for multi-process training 189 | sampler = ( 190 | DistributedSampler(dataset) 191 | if cfg.NUM_GPUS > 1 192 | else RandomSampler(dataset) 193 | ) 194 | batch_sampler = ShortCycleBatchSampler( 195 | sampler, batch_size=batch_size, drop_last=drop_last, cfg=cfg 196 | ) 197 | # Create a loader 198 | loader = torch.utils.data.DataLoader( 199 | dataset, 200 | batch_sampler=batch_sampler, 201 | num_workers=cfg.DATA_LOADER.NUM_WORKERS, 202 | pin_memory=cfg.DATA_LOADER.PIN_MEMORY, 203 | ) 204 | else: 205 | # Create a sampler for multi-process training 206 | if cfg.NUM_GPUS > 1: 207 | sampler = DistributedSampler(dataset, shuffle=shuffle) 208 | else: 209 | sampler = None 210 | # Create a loader 211 | collate_fn = functools.partial(detection_collate, cfg=cfg) 212 | loader = torch.utils.data.DataLoader( 213 | dataset, 214 | batch_size=batch_size, 215 | shuffle=(False if sampler else shuffle), 216 | sampler=sampler, 217 | num_workers=cfg.DATA_LOADER.NUM_WORKERS, 218 | pin_memory=cfg.DATA_LOADER.PIN_MEMORY, 219 | drop_last=drop_last, 220 | collate_fn=collate_fn if cfg.DETECTION.ENABLE else None, 221 | ) 222 | return loader 223 | 224 | 225 | def shuffle_dataset(loader, cur_epoch): 226 | """" 227 | Shuffles the data. 228 | Args: 229 | loader (loader): data loader to perform shuffle. 230 | cur_epoch (int): number of the current epoch. 231 | """ 232 | sampler = ( 233 | loader.batch_sampler.sampler 234 | if isinstance(loader.batch_sampler, ShortCycleBatchSampler) 235 | else loader.sampler 236 | ) 237 | assert isinstance( 238 | sampler, (RandomSampler, DistributedSampler, SequentialDistributedSampler) 239 | ), "Sampler type '{}' not supported".format(type(sampler)) 240 | # RandomSampler handles shuffling automatically 241 | if isinstance(sampler, DistributedSampler) or issubclass(sampler, DistributedSampler): 242 | # DistributedSampler shuffles data based on epoch 243 | sampler.set_epoch(cur_epoch) 244 | -------------------------------------------------------------------------------- /slowfast/utils/distributed.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """Distributed helpers.""" 5 | 6 | import functools 7 | import logging 8 | import pickle 9 | import torch 10 | import torch.distributed as dist 11 | 12 | _LOCAL_PROCESS_GROUP = None 13 | 14 | 15 | def all_gather(tensors): 16 | """ 17 | All gathers the provided tensors from all processes across machines. 18 | Args: 19 | tensors (list): tensors to perform all gather across all processes in 20 | all machines. 
21 |     """
22 | 
23 |     gather_list = []
24 |     output_tensor = []
25 |     world_size = dist.get_world_size()
26 |     for tensor in tensors:
27 |         tensor_placeholder = [
28 |             torch.ones_like(tensor) for _ in range(world_size)
29 |         ]
30 |         dist.all_gather(tensor_placeholder, tensor, async_op=False)
31 |         gather_list.append(tensor_placeholder)
32 |     for gathered_tensor in gather_list:
33 |         output_tensor.append(torch.cat(gathered_tensor, dim=0))
34 |     return output_tensor
35 | 
36 | 
37 | def all_reduce(tensors, average=True):
38 |     """
39 |     All reduce the provided tensors from all processes across machines.
40 |     Args:
41 |         tensors (list): tensors to perform all reduce across all processes in
42 |             all machines.
43 |         average (bool): if True, divide each reduced tensor by the total
44 |             number of processes across all machines.
45 |     """
46 | 
47 |     for tensor in tensors:
48 |         dist.all_reduce(tensor, async_op=False)
49 |     if average:
50 |         world_size = dist.get_world_size()
51 |         for tensor in tensors:
52 |             tensor.mul_(1.0 / world_size)
53 |     return tensors
54 | 
55 | 
56 | def init_process_group(
57 |     local_rank,
58 |     local_world_size,
59 |     shard_id,
60 |     num_shards,
61 |     init_method,
62 |     dist_backend="nccl",
63 | ):
64 |     """
65 |     Initializes the default process group.
66 |     Args:
67 |         local_rank (int): the rank on the current local machine.
68 |         local_world_size (int): the world size (number of processes running) on
69 |             the current local machine.
70 |         shard_id (int): the shard index (machine rank) of the current machine.
71 |         num_shards (int): number of shards for distributed training.
72 |         init_method (string): supporting three different methods for
73 |             initializing process groups:
74 |             "file": use shared file system to initialize the groups across
75 |             different processes.
76 |             "tcp": use tcp address to initialize the groups across different processes.
77 |         dist_backend (string): backend to use for distributed training. Options
78 |             include gloo, mpi and nccl, the details can be found here:
79 |             https://pytorch.org/docs/stable/distributed.html
80 |     """
81 |     # Sets the GPU to use.
82 |     torch.cuda.set_device(local_rank)
83 |     # Initialize the process group.
84 |     proc_rank = local_rank + shard_id * local_world_size
85 |     world_size = local_world_size * num_shards
86 |     dist.init_process_group(
87 |         backend=dist_backend,
88 |         init_method=init_method,
89 |         world_size=world_size,
90 |         rank=proc_rank,
91 |     )
92 | 
93 | 
94 | def is_master_proc(num_gpus=8):
95 |     """
96 |     Determines if the current process is the master process.
97 |     """
98 |     if torch.distributed.is_initialized():
99 |         return dist.get_rank() % num_gpus == 0
100 |     else:
101 |         return True
102 | 
103 | 
104 | def get_world_size():
105 |     """
106 |     Get the size of the world.
107 |     """
108 |     if not dist.is_available():
109 |         return 1
110 |     if not dist.is_initialized():
111 |         return 1
112 |     return dist.get_world_size()
113 | 
114 | 
115 | def get_rank():
116 |     """
117 |     Get the rank of the current process.
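
    A common (illustrative) pattern is to guard rank-specific side effects,
    e.g. `if get_rank() == 0: print("runs on one process only")`.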
118 |     """
119 |     if not dist.is_available():
120 |         return 0
121 |     if not dist.is_initialized():
122 |         return 0
123 |     return dist.get_rank()
124 | 
125 | 
126 | def synchronize():
127 |     """
128 |     Helper function to synchronize (barrier) among all processes when
129 |     using distributed training.
130 |     """
131 |     if not dist.is_available():
132 |         return
133 |     if not dist.is_initialized():
134 |         return
135 |     world_size = dist.get_world_size()
136 |     if world_size == 1:
137 |         return
138 |     dist.barrier()
139 | 
140 | 
141 | @functools.lru_cache()
142 | def _get_global_gloo_group():
143 |     """
144 |     Return a process group based on gloo backend, containing all the ranks.
145 |     The result is cached.
146 |     Returns:
147 |         (group): pytorch dist group.
148 |     """
149 |     if dist.get_backend() == "nccl":
150 |         return dist.new_group(backend="gloo")
151 |     else:
152 |         return dist.group.WORLD
153 | 
154 | 
155 | def _serialize_to_tensor(data, group):
156 |     """
157 |     Serialize the data to a ByteTensor. Note that only the `gloo` and `nccl`
158 |     backends are supported.
159 |     Args:
160 |         data (data): data to be serialized.
161 |         group (group): pytorch dist group.
162 |     Returns:
163 |         tensor (ByteTensor): the serialized tensor.
164 |     """
165 | 
166 |     backend = dist.get_backend(group)
167 |     assert backend in ["gloo", "nccl"]
168 |     device = torch.device("cpu" if backend == "gloo" else "cuda")
169 | 
170 |     buffer = pickle.dumps(data)
171 |     if len(buffer) > 1024 ** 3:
172 |         logger = logging.getLogger(__name__)
173 |         logger.warning(
174 |             "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
175 |                 get_rank(), len(buffer) / (1024 ** 3), device
176 |             )
177 |         )
178 |     storage = torch.ByteStorage.from_buffer(buffer)
179 |     tensor = torch.ByteTensor(storage).to(device=device)
180 |     return tensor
181 | 
182 | 
183 | def _pad_to_largest_tensor(tensor, group):
184 |     """
185 |     Pad all the tensors from different GPUs to the largest one.
186 |     Args:
187 |         tensor (tensor): tensor to pad.
188 |         group (group): pytorch dist group.
189 |     Returns:
190 |         list[int]: size of the tensor, on each rank
191 |         Tensor: padded tensor that has the max size
192 |     """
193 |     world_size = dist.get_world_size(group=group)
194 |     assert (
195 |         world_size >= 1
196 |     ), "comm.gather/all_gather must be called from ranks within the given group!"
197 |     local_size = torch.tensor(
198 |         [tensor.numel()], dtype=torch.int64, device=tensor.device
199 |     )
200 |     size_list = [
201 |         torch.zeros([1], dtype=torch.int64, device=tensor.device)
202 |         for _ in range(world_size)
203 |     ]
204 |     dist.all_gather(size_list, local_size, group=group)
205 |     size_list = [int(size.item()) for size in size_list]
206 | 
207 |     max_size = max(size_list)
208 | 
209 |     # we pad the tensor because torch all_gather does not support
210 |     # gathering tensors of different shapes
211 |     if local_size != max_size:
212 |         padding = torch.zeros(
213 |             (max_size - local_size,), dtype=torch.uint8, device=tensor.device
214 |         )
215 |         tensor = torch.cat((tensor, padding), dim=0)
216 |     return size_list, tensor
217 | 
218 | 
219 | def all_gather_unaligned(data, group=None):
220 |     """
221 |     Run all_gather on arbitrary picklable data (not necessarily tensors).
222 | 
223 |     Args:
224 |         data: any picklable object
225 |         group: a torch process group. By default, will use a group which
226 |             contains all ranks on gloo backend.
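
    For illustration (the object below is hypothetical), each rank can gather
    one picklable dict from every other rank:
        stats = {"rank": get_rank(), "num_samples": 128}
        all_stats = all_gather_unaligned(stats)  # one dict per rank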
227 | 228 | Returns: 229 | list[data]: list of data gathered from each rank 230 | """ 231 | if get_world_size() == 1: 232 | return [data] 233 | if group is None: 234 | group = _get_global_gloo_group() 235 | if dist.get_world_size(group) == 1: 236 | return [data] 237 | 238 | tensor = _serialize_to_tensor(data, group) 239 | 240 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 241 | max_size = max(size_list) 242 | 243 | # receiving Tensor from all ranks 244 | tensor_list = [ 245 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) 246 | for _ in size_list 247 | ] 248 | dist.all_gather(tensor_list, tensor, group=group) 249 | 250 | data_list = [] 251 | for size, tensor in zip(size_list, tensor_list): 252 | buffer = tensor.cpu().numpy().tobytes()[:size] 253 | data_list.append(pickle.loads(buffer)) 254 | 255 | return data_list 256 | 257 | 258 | def init_distributed_training(cfg): 259 | """ 260 | Initialize variables needed for distributed training. 261 | """ 262 | if cfg.NUM_GPUS <= 1: 263 | return 264 | num_gpus_per_machine = cfg.NUM_GPUS 265 | num_machines = dist.get_world_size() // num_gpus_per_machine 266 | for i in range(num_machines): 267 | ranks_on_i = list( 268 | range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine) 269 | ) 270 | pg = dist.new_group(ranks_on_i) 271 | if i == cfg.SHARD_ID: 272 | global _LOCAL_PROCESS_GROUP 273 | _LOCAL_PROCESS_GROUP = pg 274 | 275 | def get_main_proc_rank() -> int: 276 | return 0 277 | 278 | def get_local_size() -> int: 279 | """ 280 | Returns: 281 | The size of the per-machine process group, 282 | i.e. the number of processes per machine. 283 | """ 284 | if not dist.is_available(): 285 | return 1 286 | if not dist.is_initialized(): 287 | return 1 288 | return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) 289 | 290 | 291 | def get_local_rank() -> int: 292 | """ 293 | Returns: 294 | The rank of the current process within the local (per-machine) process group. 295 | """ 296 | if not dist.is_available(): 297 | return 0 298 | if not dist.is_initialized(): 299 | return 0 300 | assert _LOCAL_PROCESS_GROUP is not None 301 | return dist.get_rank(group=_LOCAL_PROCESS_GROUP) 302 | -------------------------------------------------------------------------------- /slowfast/utils/ava_evaluation/standard_fields.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Contains classes specifying naming conventions used for object detection. 17 | 18 | 19 | Specifies: 20 | InputDataFields: standard fields used by reader/preprocessor/batcher. 21 | DetectionResultFields: standard fields returned by object detector. 22 | BoxListFields: standard field used by BoxList 23 | TfExampleFields: standard fields for tf-example data format (go/tf-example). 
24 | """ 25 | 26 | 27 | from __future__ import ( 28 | absolute_import, 29 | division, 30 | print_function, 31 | unicode_literals, 32 | ) 33 | 34 | 35 | class InputDataFields(object): 36 | """Names for the input tensors. 37 | 38 | Holds the standard data field names to use for identifying input tensors. This 39 | should be used by the decoder to identify keys for the returned tensor_dict 40 | containing input tensors. And it should be used by the model to identify the 41 | tensors it needs. 42 | 43 | Attributes: 44 | image: image. 45 | original_image: image in the original input size. 46 | key: unique key corresponding to image. 47 | source_id: source of the original image. 48 | filename: original filename of the dataset (without common path). 49 | groundtruth_image_classes: image-level class labels. 50 | groundtruth_boxes: coordinates of the ground truth boxes in the image. 51 | groundtruth_classes: box-level class labels. 52 | groundtruth_label_types: box-level label types (e.g. explicit negative). 53 | groundtruth_is_crowd: [DEPRECATED, use groundtruth_group_of instead] 54 | is the groundtruth a single object or a crowd. 55 | groundtruth_area: area of a groundtruth segment. 56 | groundtruth_difficult: is a `difficult` object 57 | groundtruth_group_of: is a `group_of` objects, e.g. multiple objects of the 58 | same class, forming a connected group, where instances are heavily 59 | occluding each other. 60 | proposal_boxes: coordinates of object proposal boxes. 61 | proposal_objectness: objectness score of each proposal. 62 | groundtruth_instance_masks: ground truth instance masks. 63 | groundtruth_instance_boundaries: ground truth instance boundaries. 64 | groundtruth_instance_classes: instance mask-level class labels. 65 | groundtruth_keypoints: ground truth keypoints. 66 | groundtruth_keypoint_visibilities: ground truth keypoint visibilities. 67 | groundtruth_label_scores: groundtruth label scores. 68 | groundtruth_weights: groundtruth weight factor for bounding boxes. 69 | num_groundtruth_boxes: number of groundtruth boxes. 70 | true_image_shapes: true shapes of images in the resized images, as resized 71 | images can be padded with zeros. 72 | """ 73 | 74 | image = "image" 75 | original_image = "original_image" 76 | key = "key" 77 | source_id = "source_id" 78 | filename = "filename" 79 | groundtruth_image_classes = "groundtruth_image_classes" 80 | groundtruth_boxes = "groundtruth_boxes" 81 | groundtruth_classes = "groundtruth_classes" 82 | groundtruth_label_types = "groundtruth_label_types" 83 | groundtruth_is_crowd = "groundtruth_is_crowd" 84 | groundtruth_area = "groundtruth_area" 85 | groundtruth_difficult = "groundtruth_difficult" 86 | groundtruth_group_of = "groundtruth_group_of" 87 | proposal_boxes = "proposal_boxes" 88 | proposal_objectness = "proposal_objectness" 89 | groundtruth_instance_masks = "groundtruth_instance_masks" 90 | groundtruth_instance_boundaries = "groundtruth_instance_boundaries" 91 | groundtruth_instance_classes = "groundtruth_instance_classes" 92 | groundtruth_keypoints = "groundtruth_keypoints" 93 | groundtruth_keypoint_visibilities = "groundtruth_keypoint_visibilities" 94 | groundtruth_label_scores = "groundtruth_label_scores" 95 | groundtruth_weights = "groundtruth_weights" 96 | num_groundtruth_boxes = "num_groundtruth_boxes" 97 | true_image_shape = "true_image_shape" 98 | 99 | 100 | class DetectionResultFields(object): 101 | """Naming conventions for storing the output of the detector. 
102 | 103 | Attributes: 104 | source_id: source of the original image. 105 | key: unique key corresponding to image. 106 | detection_boxes: coordinates of the detection boxes in the image. 107 | detection_scores: detection scores for the detection boxes in the image. 108 | detection_classes: detection-level class labels. 109 | detection_masks: contains a segmentation mask for each detection box. 110 | detection_boundaries: contains an object boundary for each detection box. 111 | detection_keypoints: contains detection keypoints for each detection box. 112 | num_detections: number of detections in the batch. 113 | """ 114 | 115 | source_id = "source_id" 116 | key = "key" 117 | detection_boxes = "detection_boxes" 118 | detection_scores = "detection_scores" 119 | detection_classes = "detection_classes" 120 | detection_masks = "detection_masks" 121 | detection_boundaries = "detection_boundaries" 122 | detection_keypoints = "detection_keypoints" 123 | num_detections = "num_detections" 124 | 125 | 126 | class BoxListFields(object): 127 | """Naming conventions for BoxLists. 128 | 129 | Attributes: 130 | boxes: bounding box coordinates. 131 | classes: classes per bounding box. 132 | scores: scores per bounding box. 133 | weights: sample weights per bounding box. 134 | objectness: objectness score per bounding box. 135 | masks: masks per bounding box. 136 | boundaries: boundaries per bounding box. 137 | keypoints: keypoints per bounding box. 138 | keypoint_heatmaps: keypoint heatmaps per bounding box. 139 | """ 140 | 141 | boxes = "boxes" 142 | classes = "classes" 143 | scores = "scores" 144 | weights = "weights" 145 | objectness = "objectness" 146 | masks = "masks" 147 | boundaries = "boundaries" 148 | keypoints = "keypoints" 149 | keypoint_heatmaps = "keypoint_heatmaps" 150 | 151 | 152 | class TfExampleFields(object): 153 | """TF-example proto feature names for object detection. 154 | 155 | Holds the standard feature names to load from an Example proto for object 156 | detection. 157 | 158 | Attributes: 159 | image_encoded: JPEG encoded string 160 | image_format: image format, e.g. "JPEG" 161 | filename: filename 162 | channels: number of channels of image 163 | colorspace: colorspace, e.g. "RGB" 164 | height: height of image in pixels, e.g. 462 165 | width: width of image in pixels, e.g. 581 166 | source_id: original source of the image 167 | object_class_text: labels in text format, e.g. ["person", "cat"] 168 | object_class_label: labels in numbers, e.g. [16, 8] 169 | object_bbox_xmin: xmin coordinates of groundtruth box, e.g. 10, 30 170 | object_bbox_xmax: xmax coordinates of groundtruth box, e.g. 50, 40 171 | object_bbox_ymin: ymin coordinates of groundtruth box, e.g. 40, 50 172 | object_bbox_ymax: ymax coordinates of groundtruth box, e.g. 80, 70 173 | object_view: viewpoint of object, e.g. ["frontal", "left"] 174 | object_truncated: is object truncated, e.g. [true, false] 175 | object_occluded: is object occluded, e.g. [true, false] 176 | object_difficult: is object difficult, e.g. [true, false] 177 | object_group_of: is object a single object or a group of objects 178 | object_depiction: is object a depiction 179 | object_is_crowd: [DEPRECATED, use object_group_of instead] 180 | is the object a single object or a crowd 181 | object_segment_area: the area of the segment. 182 | object_weight: a weight factor for the object's bounding box. 183 | instance_masks: instance segmentation masks. 184 | instance_boundaries: instance boundaries. 
185 | instance_classes: Classes for each instance segmentation mask. 186 | detection_class_label: class label in numbers. 187 | detection_bbox_ymin: ymin coordinates of a detection box. 188 | detection_bbox_xmin: xmin coordinates of a detection box. 189 | detection_bbox_ymax: ymax coordinates of a detection box. 190 | detection_bbox_xmax: xmax coordinates of a detection box. 191 | detection_score: detection score for the class label and box. 192 | """ 193 | 194 | image_encoded = "image/encoded" 195 | image_format = "image/format" # format is reserved keyword 196 | filename = "image/filename" 197 | channels = "image/channels" 198 | colorspace = "image/colorspace" 199 | height = "image/height" 200 | width = "image/width" 201 | source_id = "image/source_id" 202 | object_class_text = "image/object/class/text" 203 | object_class_label = "image/object/class/label" 204 | object_bbox_ymin = "image/object/bbox/ymin" 205 | object_bbox_xmin = "image/object/bbox/xmin" 206 | object_bbox_ymax = "image/object/bbox/ymax" 207 | object_bbox_xmax = "image/object/bbox/xmax" 208 | object_view = "image/object/view" 209 | object_truncated = "image/object/truncated" 210 | object_occluded = "image/object/occluded" 211 | object_difficult = "image/object/difficult" 212 | object_group_of = "image/object/group_of" 213 | object_depiction = "image/object/depiction" 214 | object_is_crowd = "image/object/is_crowd" 215 | object_segment_area = "image/object/segment/area" 216 | object_weight = "image/object/weight" 217 | instance_masks = "image/segmentation/object" 218 | instance_boundaries = "image/boundaries/object" 219 | instance_classes = "image/segmentation/object/class" 220 | detection_class_label = "image/detection/label" 221 | detection_bbox_ymin = "image/detection/bbox/ymin" 222 | detection_bbox_xmin = "image/detection/bbox/xmin" 223 | detection_bbox_ymax = "image/detection/bbox/ymax" 224 | detection_bbox_xmax = "image/detection/bbox/xmax" 225 | detection_score = "image/detection/score" 226 | -------------------------------------------------------------------------------- /slowfast/datasets/ava_helper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | import logging 5 | import os 6 | import torch 7 | import pickle 8 | 9 | from collections import defaultdict 10 | from fvcore.common.file_io import PathManager 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | FPS = 30 15 | AVA_VALID_FRAMES = range(902, 1799) 16 | 17 | 18 | def load_image_lists(cfg, is_train): 19 | """ 20 | Loading image paths from corresponding files. 21 | 22 | Args: 23 | cfg (CfgNode): config. 24 | is_train (bool): if it is training dataset or not. 25 | 26 | Returns: 27 | image_paths (list[list]): a list of items. Each item (also a list) 28 | corresponds to one video and contains the paths of images for 29 | this video. 30 | video_idx_to_name (list): a list which stores video names. 
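
    Each line of a frame list file (after the header) follows the format
    `original_vido_id video_id frame_id path labels`, for example the
    hypothetical row:
        my_video 0 0 my_video/my_video_000001.jpg ""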
31 |     """
32 |     list_filenames = [
33 |         os.path.join(cfg.AVA.FRAME_LIST_DIR, filename)
34 |         for filename in (
35 |             cfg.AVA.TRAIN_LISTS if is_train else cfg.AVA.TEST_LISTS
36 |         )
37 |     ]
38 |     image_paths = defaultdict(list)
39 |     video_name_to_idx = {}
40 |     video_idx_to_name = []
41 |     for list_filename in list_filenames:
42 |         with PathManager.open(list_filename, "r") as f:
43 |             f.readline()
44 |             for line in f:
45 |                 row = line.split()
46 |                 # The format of each row should follow:
47 |                 # original_vido_id video_id frame_id path labels.
48 |                 assert len(row) >= 4
49 |                 video_name = row[0]
50 | 
51 |                 if video_name not in video_name_to_idx:
52 |                     idx = len(video_name_to_idx)
53 |                     video_name_to_idx[video_name] = idx
54 |                     video_idx_to_name.append(video_name)
55 | 
56 |                 data_key = video_name_to_idx[video_name]
57 | 
58 |                 image_paths[data_key].append(
59 |                     os.path.join(cfg.AVA.FRAME_DIR, row[3])
60 |                 )
61 | 
62 |     image_paths = [image_paths[i] for i in range(len(image_paths))]
63 | 
64 |     logger.info(
65 |         "Finished loading image paths from: %s" % ", ".join(list_filenames)
66 |     )
67 | 
68 |     return image_paths, video_idx_to_name
69 | 
70 | 
71 | def load_feature_bank(cfg, mode):
72 | 
73 |     bank_path = cfg.AVA.FEATURE_BANK_PATH
74 |     if not os.path.exists(bank_path):
75 |         return None, 0
76 | 
77 |     bank = torch.load(bank_path)[mode]
78 |     num = 0
79 |     for _, saved in bank.items():
80 |         for sec in saved:
81 |             num += len(saved[sec]['feature'])
82 | 
83 |     return bank, num
84 | 
85 | 
86 | def draw_feature_from_sliding_window(cfg, vid, sec, split, db=None):
87 | 
88 |     """
89 |     Draw saved features from the feature bank within a sliding window centered at `sec` (excluding `sec` itself).
90 |     Args:
91 |         cfg: (CfgNode) global configuration
92 |         vid: (int) video instance number
93 |         sec: (int) second number
94 |         split: (str) dataset split, 'train', 'val' or 'test'
95 |         db: handle to the opened feature-bank database (read transactions are created via `db.begin`)
96 |     Returns:
97 |         output_feat: (torch.Tensor, list or None) features drawn from the window; None if no feature was found
98 |         time_stamp: (list) the second index each returned feature row was drawn from
99 |     """
100 | 
101 |     window_size = cfg.AVA.SLIDING_WINDOW_SIZE
102 |     valid_range_list = list(AVA_VALID_FRAMES)
103 |     min_f, max_f = valid_range_list[0], valid_range_list[-1]
104 |     start = max(sec - window_size, min_f)
105 |     end = min(sec + window_size, max_f)
106 | 
107 |     bank_mode = "train" if split == "train" else "test"
108 |     bank_path = "{}/{}".format(bank_mode, vid)
109 | 
110 |     output_feat = []
111 |     time_stamp = []
112 | 
113 |     assert db is not None, 'empty feature bank'
114 |     with db.begin(write=False) as txn:
115 | 
116 |         for s in range(start, end+1):
117 |             if s == sec:
118 |                 continue
119 | 
120 |             key_path = '{}/{}'.format(bank_path, s)
121 |             # logger.info(key_path)
122 |             feat_db_val = txn.get(os.path.join(key_path, 'feature').encode())
123 |             if feat_db_val is not None:
124 |                 fb = pickle.loads(feat_db_val)
125 |                 output_feat.append(fb)
126 |                 time_stamp += [s] * fb.shape[0]
127 | 
128 |     if len(output_feat) == 0:
129 |         output_feat = None
130 |     elif cfg.AVA.GATHER_BANK:
131 |         output_feat = torch.cat(output_feat, dim=0)
132 | 
133 |     # if output_feat is not None:
134 |     #     print(start, end, output_feat.shape[0], output_context.shape[0])
135 |     return output_feat, time_stamp
136 | 
137 | 
138 | def load_boxes_and_labels(cfg, mode):
139 |     """
140 |     Loading boxes and labels from csv files.
141 | 
142 |     Args:
143 |         cfg (CfgNode): config.
144 |         mode (str): 'train', 'val', or 'test' mode.
145 |     Returns:
146 |         all_boxes (dict): a dict which maps from `video_name` and
147 |             `frame_sec` to a list of `box`. Each `box` is a
148 |             [`box_coord`, `box_labels`] where `box_coord` is the
149 |             coordinates of box and `box_labels` are the corresponding
150 |             labels for the box.
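
            For example (illustrative values), one entry could look like:
                all_boxes["some_video"][902] = [
                    [[0.31, 0.20, 0.65, 0.98], [12, 17]],
                ]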
151 | """ 152 | gt_lists = cfg.AVA.TRAIN_GT_BOX_LISTS if mode == "train" else cfg.AVA.TEST_GT_BOX_LISTS 153 | pred_lists = ( 154 | cfg.AVA.TRAIN_PREDICT_BOX_LISTS 155 | if mode == "train" 156 | else cfg.AVA.TEST_PREDICT_BOX_LISTS 157 | ) 158 | ann_filenames = [ 159 | os.path.join(cfg.AVA.ANNOTATION_DIR, filename) 160 | for filename in gt_lists + pred_lists 161 | ] 162 | ann_is_gt_box = [True] * len(gt_lists) + [False] * len(pred_lists) 163 | 164 | detect_thresh = cfg.AVA.DETECTION_SCORE_THRESH 165 | # Only select frame_sec % 4 = 0 samples for validation if not 166 | # set FULL_TEST_ON_VAL. 167 | boxes_sample_rate = ( 168 | 4 if mode == "val" and not cfg.AVA.FULL_TEST_ON_VAL else 1 169 | ) 170 | all_boxes, count, unique_box_count = parse_bboxes_file( 171 | ann_filenames=ann_filenames, 172 | ann_is_gt_box=ann_is_gt_box, 173 | detect_thresh=detect_thresh, 174 | boxes_sample_rate=boxes_sample_rate, 175 | ) 176 | logger.info( 177 | "Finished loading annotations from: %s" % ", ".join(ann_filenames) 178 | ) 179 | logger.info("Detection threshold: {}".format(detect_thresh)) 180 | logger.info("Number of unique boxes: %d" % unique_box_count) 181 | logger.info("Number of annotations: %d" % count) 182 | 183 | return all_boxes 184 | 185 | 186 | def get_keyframe_data(boxes_and_labels): 187 | """ 188 | Getting keyframe indices, boxes and labels in the dataset. 189 | 190 | Args: 191 | boxes_and_labels (list[dict]): a list which maps from video_idx to a dict. 192 | Each dict `frame_sec` to a list of boxes and corresponding labels. 193 | 194 | Returns: 195 | keyframe_indices (list): a list of indices of the keyframes. 196 | keyframe_boxes_and_labels (list[list[list]]): a list of list which maps from 197 | video_idx and sec_idx to a list of boxes and corresponding labels. 198 | """ 199 | 200 | def sec_to_frame(sec): 201 | """ 202 | Convert time index (in second) to frame index. 203 | 0: 900 204 | 30: 901 205 | """ 206 | return (sec - 900) * FPS 207 | 208 | keyframe_indices = [] 209 | keyframe_boxes_and_labels = [] 210 | count = 0 211 | for video_idx in range(len(boxes_and_labels)): 212 | sec_idx = 0 213 | keyframe_boxes_and_labels.append([]) 214 | for sec in boxes_and_labels[video_idx].keys(): 215 | if sec not in AVA_VALID_FRAMES: 216 | continue 217 | 218 | if len(boxes_and_labels[video_idx][sec]) > 0: 219 | keyframe_indices.append( 220 | (video_idx, sec_idx, sec, sec_to_frame(sec)) 221 | ) 222 | keyframe_boxes_and_labels[video_idx].append( 223 | boxes_and_labels[video_idx][sec] 224 | ) 225 | sec_idx += 1 226 | count += 1 227 | logger.info("%d keyframes used." % count) 228 | 229 | return keyframe_indices, keyframe_boxes_and_labels 230 | 231 | 232 | def get_num_boxes_used(keyframe_indices, keyframe_boxes_and_labels): 233 | """ 234 | Get total number of used boxes. 235 | 236 | Args: 237 | keyframe_indices (list): a list of indices of the keyframes. 238 | keyframe_boxes_and_labels (list[list[list]]): a list of list which maps from 239 | video_idx and sec_idx to a list of boxes and corresponding labels. 240 | 241 | Returns: 242 | count (int): total number of used boxes. 243 | """ 244 | 245 | count = 0 246 | for video_idx, sec_idx, _, _ in keyframe_indices: 247 | count += len(keyframe_boxes_and_labels[video_idx][sec_idx]) 248 | return count 249 | 250 | 251 | def parse_bboxes_file( 252 | ann_filenames, ann_is_gt_box, detect_thresh, boxes_sample_rate=1 253 | ): 254 | """ 255 | Parse AVA bounding boxes files. 256 | Args: 257 | ann_filenames (list of str(s)): a list of AVA bounding boxes annotation files. 
258 | ann_is_gt_box (list of bools): a list of boolean to indicate whether the corresponding 259 | ann_file is ground-truth. `ann_is_gt_box[i]` correspond to `ann_filenames[i]`. 260 | detect_thresh (float): threshold for accepting predicted boxes, range [0, 1]. 261 | boxes_sample_rate (int): sample rate for test bounding boxes. Get 1 every `boxes_sample_rate`. 262 | """ 263 | all_boxes = {} 264 | count = 0 265 | unique_box_count = 0 266 | for filename, is_gt_box in zip(ann_filenames, ann_is_gt_box): 267 | with PathManager.open(filename, "r") as f: 268 | for line in f: 269 | row = line.strip().split(",") 270 | # When we use predicted boxes to train/eval, we need to 271 | # ignore the boxes whose scores are below the threshold. 272 | if not is_gt_box: 273 | score = float(row[7]) 274 | if score < detect_thresh: 275 | continue 276 | 277 | video_name, frame_sec = row[0], int(row[1]) 278 | if frame_sec % boxes_sample_rate != 0: 279 | continue 280 | 281 | # Box with format [x1, y1, x2, y2] with a range of [0, 1] as float. 282 | box_key = ",".join(row[2:6]) 283 | box = list(map(float, row[2:6])) 284 | label = -1 if row[6] == "" else int(row[6]) 285 | 286 | if video_name not in all_boxes: 287 | all_boxes[video_name] = {} 288 | for sec in AVA_VALID_FRAMES: 289 | all_boxes[video_name][sec] = {} 290 | 291 | if box_key not in all_boxes[video_name][frame_sec]: 292 | all_boxes[video_name][frame_sec][box_key] = [box, []] 293 | unique_box_count += 1 294 | 295 | all_boxes[video_name][frame_sec][box_key][1].append(label) 296 | if label != -1: 297 | count += 1 298 | 299 | for video_name in all_boxes.keys(): 300 | for frame_sec in all_boxes[video_name].keys(): 301 | # Save in format of a list of [box_i, box_i_labels]. 302 | all_boxes[video_name][frame_sec] = list( 303 | all_boxes[video_name][frame_sec].values() 304 | ) 305 | 306 | return all_boxes, count, unique_box_count 307 | -------------------------------------------------------------------------------- /slowfast/datasets/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import logging 4 | import numpy as np 5 | import os 6 | import random 7 | import time 8 | from collections import defaultdict 9 | import cv2 10 | import torch 11 | from fvcore.common.file_io import PathManager 12 | 13 | from . import transform as transform 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | def retry_load_images(image_paths, retry=10, backend="pytorch"): 19 | """ 20 | This function is to load images with support of retrying for failed load. 21 | 22 | Args: 23 | image_paths (list): paths of images needed to be loaded. 24 | retry (int, optional): maximum time of loading retrying. Defaults to 10. 25 | backend (str): `pytorch` or `cv2`. 26 | 27 | Returns: 28 | imgs (list): list of loaded images. 29 | """ 30 | for i in range(retry): 31 | imgs = [] 32 | for image_path in image_paths: 33 | with PathManager.open(image_path, "rb") as f: 34 | img_str = np.frombuffer(f.read(), np.uint8) 35 | img = cv2.imdecode(img_str, flags=cv2.IMREAD_COLOR) 36 | imgs.append(img) 37 | 38 | if all(img is not None for img in imgs): 39 | if backend == "pytorch": 40 | imgs = torch.as_tensor(np.stack(imgs)) 41 | return imgs 42 | else: 43 | logger.warn("Reading failed. 
Will retry.") 44 | time.sleep(1.0) 45 | if i == retry - 1: 46 | raise Exception("Failed to load images {}".format(image_paths)) 47 | 48 | 49 | def get_sequence(center_idx, half_len, sample_rate, num_frames): 50 | """ 51 | Sample frames among the corresponding clip. 52 | 53 | Args: 54 | center_idx (int): center frame idx for current clip 55 | half_len (int): half of the clip length 56 | sample_rate (int): sampling rate for sampling frames inside of the clip 57 | num_frames (int): number of expected sampled frames 58 | 59 | Returns: 60 | seq (list): list of indexes of sampled frames in this clip. 61 | """ 62 | seq = list(range(center_idx - half_len, center_idx + half_len, sample_rate)) 63 | 64 | for seq_idx in range(len(seq)): 65 | if seq[seq_idx] < 0: 66 | seq[seq_idx] = 0 67 | elif seq[seq_idx] >= num_frames: 68 | seq[seq_idx] = num_frames - 1 69 | return seq 70 | 71 | 72 | def pack_pathway_output(cfg, frames): 73 | """ 74 | Prepare output as a list of tensors. Each tensor corresponding to a 75 | unique pathway. 76 | Args: 77 | frames (tensor): frames of images sampled from the video. The 78 | dimension is `channel` x `num frames` x `height` x `width`. 79 | Returns: 80 | frame_list (list): list of tensors with the dimension of 81 | `channel` x `num frames` x `height` x `width`. 82 | """ 83 | if cfg.DATA.REVERSE_INPUT_CHANNEL: 84 | frames = frames[[2, 1, 0], :, :, :] 85 | if cfg.MODEL.ARCH in cfg.MODEL.SINGLE_PATHWAY_ARCH: 86 | frame_list = [frames] 87 | elif cfg.MODEL.ARCH in cfg.MODEL.MULTI_PATHWAY_ARCH: 88 | fast_pathway = frames 89 | # Perform temporal sampling from the fast pathway. 90 | slow_pathway = torch.index_select( 91 | frames, 92 | 1, 93 | torch.linspace( 94 | 0, frames.shape[1] - 1, frames.shape[1] // cfg.SLOWFAST.ALPHA 95 | ).long(), 96 | ) 97 | frame_list = [slow_pathway, fast_pathway] 98 | else: 99 | raise NotImplementedError( 100 | "Model arch {} is not in {}".format( 101 | cfg.MODEL.ARCH, 102 | cfg.MODEL.SINGLE_PATHWAY_ARCH + cfg.MODEL.MULTI_PATHWAY_ARCH, 103 | ) 104 | ) 105 | return frame_list 106 | 107 | 108 | def spatial_sampling( 109 | frames, 110 | spatial_idx=-1, 111 | min_scale=256, 112 | max_scale=320, 113 | crop_size=224, 114 | random_horizontal_flip=True, 115 | inverse_uniform_sampling=False, 116 | ): 117 | """ 118 | Perform spatial sampling on the given video frames. If spatial_idx is 119 | -1, perform random scale, random crop, and random flip on the given 120 | frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling 121 | with the given spatial_idx. 122 | Args: 123 | frames (tensor): frames of images sampled from the video. The 124 | dimension is `num frames` x `height` x `width` x `channel`. 125 | spatial_idx (int): if -1, perform random spatial sampling. If 0, 1, 126 | or 2, perform left, center, right crop if width is larger than 127 | height, and perform top, center, buttom crop if height is larger 128 | than width. 129 | min_scale (int): the minimal size of scaling. 130 | max_scale (int): the maximal size of scaling. 131 | crop_size (int): the size of height and width used to crop the 132 | frames. 133 | inverse_uniform_sampling (bool): if True, sample uniformly in 134 | [1 / max_scale, 1 / min_scale] and take a reciprocal to get the 135 | scale. If False, take a uniform sample from [min_scale, 136 | max_scale]. 137 | Returns: 138 | frames (tensor): spatially sampled frames. 
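
    Example (illustrative training-time call; the values shown are simply
    this function's defaults):
        frames = spatial_sampling(
            frames,
            spatial_idx=-1,  # random scale, random crop and random flip
            min_scale=256,
            max_scale=320,
            crop_size=224,
        )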
139 |     """
140 |     assert spatial_idx in [-1, 0, 1, 2]
141 |     if spatial_idx == -1:
142 |         frames, _ = transform.random_short_side_scale_jitter(
143 |             images=frames,
144 |             min_size=min_scale,
145 |             max_size=max_scale,
146 |             inverse_uniform_sampling=inverse_uniform_sampling,
147 |         )
148 |         frames, _ = transform.random_crop(frames, crop_size)
149 |         if random_horizontal_flip:
150 |             frames, _ = transform.horizontal_flip(0.5, frames)
151 |     else:
152 |         # The testing is deterministic and no jitter should be performed.
153 |         # min_scale, max_scale, and crop_size are expected to be the same.
154 |         assert len({min_scale, max_scale, crop_size}) == 1
155 |         frames, _ = transform.random_short_side_scale_jitter(
156 |             frames, min_scale, max_scale
157 |         )
158 |         frames, _ = transform.uniform_crop(frames, crop_size, spatial_idx)
159 |     return frames
160 | 
161 | 
162 | def as_binary_vector(labels, num_classes):
163 |     """
164 |     Construct binary label vector given a list of label indices.
165 |     Args:
166 |         labels (list): The input label list.
167 |         num_classes (int): Number of classes of the label vector.
168 |     Returns:
169 |         labels (numpy array): the resulting binary vector.
170 |     """
171 |     label_arr = np.zeros((num_classes,))
172 | 
173 |     for lbl in set(labels):
174 |         label_arr[lbl] = 1.0
175 |     return label_arr
176 | 
177 | 
178 | def aggregate_labels(label_list):
179 |     """
180 |     Join a list of label lists into a single list of unique labels.
181 |     Args:
182 |         label_list (list): the input list of label lists.
183 |     Returns:
184 |         labels (list): the joint list of all unique labels in the input.
185 |     """
186 |     all_labels = []
187 |     for labels in label_list:
188 |         for l in labels:
189 |             all_labels.append(l)
190 |     return list(set(all_labels))
191 | 
192 | 
193 | def convert_to_video_level_labels(labels):
194 |     """
195 |     Aggregate annotations from all frames of a video to form video-level labels.
196 |     Args:
197 |         labels (list): The input label list.
198 |     Returns:
199 |         labels (list): Same as input, but with each label replaced by
200 |             a video-level one.
201 |     """
202 |     for video_id in range(len(labels)):
203 |         video_level_labels = aggregate_labels(labels[video_id])
204 |         for i in range(len(labels[video_id])):
205 |             labels[video_id][i] = video_level_labels
206 |     return labels
207 | 
208 | 
209 | def load_image_lists(frame_list_file, prefix="", return_list=False):
210 |     """
211 |     Load image paths and labels from a "frame list".
212 |     Each line of the frame list contains:
213 |     `original_vido_id video_id frame_id path labels`
214 |     Args:
215 |         frame_list_file (string): path to the frame list.
216 |         prefix (str): the prefix for the path.
217 |         return_list (bool): if True, return a list. If False, return a dict.
218 |     Returns:
219 |         image_paths (list or dict): list of list containing path to each frame.
220 |             If return_list is False, then return in a dict form.
221 |         labels (list or dict): list of list containing label of each frame.
222 |             If return_list is False, then return in a dict form.
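
    Example (illustrative; the file name and prefix are hypothetical):
        image_paths, labels = load_image_lists(
            "train.csv", prefix="/path/to/frames", return_list=True
        )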
223 | """ 224 | image_paths = defaultdict(list) 225 | labels = defaultdict(list) 226 | with PathManager.open(frame_list_file, "r") as f: 227 | assert f.readline().startswith("original_vido_id") 228 | for line in f: 229 | row = line.split() 230 | # original_vido_id video_id frame_id path labels 231 | assert len(row) == 5 232 | video_name = row[0] 233 | if prefix == "": 234 | path = row[3] 235 | else: 236 | path = os.path.join(prefix, row[3]) 237 | image_paths[video_name].append(path) 238 | frame_labels = row[-1].replace('"', "") 239 | if frame_labels != "": 240 | labels[video_name].append( 241 | [int(x) for x in frame_labels.split(",")] 242 | ) 243 | else: 244 | labels[video_name].append([]) 245 | 246 | if return_list: 247 | keys = image_paths.keys() 248 | image_paths = [image_paths[key] for key in keys] 249 | labels = [labels[key] for key in keys] 250 | return image_paths, labels 251 | return dict(image_paths), dict(labels) 252 | 253 | 254 | def tensor_normalize(tensor, mean, std): 255 | """ 256 | Normalize a given tensor by subtracting the mean and dividing the std. 257 | Args: 258 | tensor (tensor): tensor to normalize. 259 | mean (tensor or list): mean value to subtract. 260 | std (tensor or list): std to divide. 261 | """ 262 | if tensor.dtype == torch.uint8: 263 | tensor = tensor.float() 264 | tensor = tensor / 255.0 265 | if type(mean) == list: 266 | mean = torch.tensor(mean) 267 | if type(std) == list: 268 | std = torch.tensor(std) 269 | tensor = tensor - mean 270 | tensor = tensor / std 271 | return tensor 272 | 273 | 274 | def get_random_sampling_rate(long_cycle_sampling_rate, sampling_rate): 275 | """ 276 | When multigrid training uses a fewer number of frames, we randomly 277 | increase the sampling rate so that some clips cover the original span. 278 | """ 279 | if long_cycle_sampling_rate > 0: 280 | assert long_cycle_sampling_rate >= sampling_rate 281 | return random.randint(sampling_rate, long_cycle_sampling_rate) 282 | else: 283 | return sampling_rate 284 | 285 | 286 | def revert_tensor_normalize(tensor, mean, std): 287 | """ 288 | Revert normalization for a given tensor by multiplying by the std and adding the mean. 289 | Args: 290 | tensor (tensor): tensor to revert normalization. 291 | mean (tensor or list): mean value to add. 292 | std (tensor or list): std to multiply. 293 | """ 294 | if type(mean) == list: 295 | mean = torch.tensor(mean) 296 | if type(std) == list: 297 | std = torch.tensor(std) 298 | tensor = tensor * std 299 | tensor = tensor + mean 300 | return tensor 301 | --------------------------------------------------------------------------------