├── .gitignore ├── Dockerfile ├── LICENCE ├── README.md ├── cuda_ops ├── python │ └── __init__.py ├── setup.py └── src │ ├── common.cpp │ ├── common.hpp │ ├── optimal_state_change.cu │ └── optimal_state_change_indices.cu ├── download_requirements.sh ├── method ├── dataset.py ├── model.py └── utils.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .ipynb_checkpoints 3 | CLIP 4 | weights 5 | videos 6 | ChangeIt 7 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:1.10.0-cuda11.3-cudnn8-devel 2 | # using pytorch:1.12.0-cuda11.3-cudnn8-devel results in training being 2x slower for some weird reason 3 | 4 | RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub 5 | 6 | RUN apt-get update \ 7 | && apt-get install ffmpeg wget git -y 8 | 9 | RUN pip install \ 10 | opencv-python \ 11 | pillow \ 12 | matplotlib \ 13 | scikit-learn \ 14 | scipy \ 15 | tqdm \ 16 | pandas \ 17 | ffmpeg-python \ 18 | ftfy \ 19 | regex \ 20 | imgaug 21 | 22 | ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH 23 | 24 | COPY cuda_ops /tmp 25 | 26 | RUN cd /tmp \ 27 | && TORCH_CUDA_ARCH_LIST="6.1;7.0;7.5;8.0;8.6" python setup.py install \ 28 | && rm -rf * 29 | -------------------------------------------------------------------------------- /LICENCE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Tomáš Souček 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | 
copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multi-Task Learning of Object States and State-Modifying Actions from Web Videos 2 | 3 | ### [[Project Website :dart:]](https://soczech.github.io/multi-task-object-states/)   [[Paper (Arxiv) :page_with_curl:]](https://arxiv.org/abs/2211.13500)   [[Paper (TPAMI) :page_with_curl:]](https://ieeexplore.ieee.org/abstract/document/10420504)   [Code :octocat:] 4 | 5 | This repository contains code for the TPAMI paper [Multi-Task Learning of Object States and State-Modifying Actions from Web Videos](https://ieeexplore.ieee.org/abstract/document/10420504). 6 | 7 | 8 | 9 | 10 | ## Train the model on ChangeIt dataset 11 | 1. **Set up the environment** 12 | - Our code can be run in a Docker container. Build the image by running the following command. 13 | Note that by default, we compile custom CUDA code for architectures 6.1, 7.0, 7.5, 8.0, and 8.6. 14 | You may need to update the Dockerfile with your GPU architecture. 15 | ``` 16 | docker build -t multi-task-object-states . 17 | ``` 18 | - Start a container from the built image. 
19 | ``` 20 | docker run -it --rm --gpus all -v $(pwd):$(pwd) -w $(pwd) --user=$(id -u $USER):$(id -g $USER) multi-task-object-states bash 21 | ``` 22 | 23 | 2. **Download requirements** 24 | - Our code requires the CLIP repository, the CLIP model weights, and the ChangeIt dataset annotations. 25 | Run `./download_requirements.sh` to obtain those dependencies or download them yourself. 26 | 27 | 3. **Download dataset** 28 | - To replicate our experiments on the ChangeIt dataset, the dataset videos are required. 29 | Please download them and put them inside the `videos/*category*` folders. 30 | See the [ChangeIt GitHub page](https://github.com/soCzech/ChangeIt) for download instructions. 31 | 32 | 4. **Train a model** 33 | - Run the training. 34 | ``` 35 | python train.py --video_roots ./videos 36 | --dataset_root ./ChangeIt 37 | --train_backbone 38 | --augment 39 | --local_batch_size 2 40 | ``` 41 | - We trained the model on 32 GPUs, i.e., with a total batch size of 64. 42 | - To run the code on multiple GPUs, simply run the code on a machine with multiple GPUs. 43 | - To run the code on multiple nodes, run the code once on each node. 44 | If you are not running on Slurm, you also need to set the environment variable `SLURM_NPROCS` 45 | to the total number of nodes and the variable `SLURM_PROCID` to the node id, starting from zero. 46 | Make sure you also set `SLURM_JOBID` to some unique value. 47 | 48 | 49 | ## Train the model on your dataset 50 | - To train the model on your dataset, complete steps **1.** and **2.** from above. 51 | - Put your videos into `*dir*/*category*` for every video category `*category*`. 52 | - Put your annotations for selected videos into `*dataset*/annotations/*category*`. 53 | Use the same [format](https://github.com/soCzech/ChangeIt/tree/main/annotations) as the ChangeIt dataset. 54 | - Run the training. 
55 | ``` 56 | python train.py --video_roots *dir* 57 | --dataset_root *dataset* 58 | --train_backbone 59 | --augment 60 | --local_batch_size 2 61 | --ignore_video_weight 62 | ``` 63 | - The `--ignore_video_weight` option disables the noise-adaptive weighting used for the noisy ChangeIt dataset. 64 | To use the noise-adaptive weighting, you need to provide `*dataset*/categories.csv` and `*dataset*/videos/*category*.csv` files as well. 65 | 66 | 67 | ## Use a trained model 68 | Here is example code for running inference with a trained model. 69 | ```python 70 | checkpoint = torch.load("path/to/saved/model.pth", map_location="cpu") 71 | model = ClipClassifier(params=checkpoint["args"], 72 | n_classes=checkpoint["n_classes"], 73 | hidden_mlp_layers=checkpoint["hidden_mlp_layers"]).cuda() 74 | model.load_state_dict({k.replace("module.", ""): v for k, v in checkpoint["state_dict"].items()}) 75 | 76 | video_frames = torch.from_numpy( 77 | extract_frames(video_fn, fps=1, size=(398, 224), crop=(398 - 224, 0))) 78 | 79 | with torch.no_grad(): 80 | predictions = model(video_frames.cuda()) 81 | state_pred, action_pred = torch.softmax(predictions["state"], -1), torch.softmax(predictions["action"], -1) 82 | ``` 83 | 84 | 85 | ## Citation 86 | ```bibtex 87 | @article{soucek2024multitask, 88 | title={Multi-Task Learning of Object States and State-Modifying Actions from Web Videos}, 89 | author={Sou\v{c}ek, Tom\'{a}\v{s} and Alayrac, Jean-Baptiste and Miech, Antoine and Laptev, Ivan and Sivic, Josef}, 90 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 91 | year={2024}, 92 | doi={10.1109/TPAMI.2024.3362288} 93 | } 94 | ``` 95 | 96 | 97 | ## Acknowledgements 98 | This work was partly supported by the European Regional Development Fund under the project IMPACT (reg. no. 
CZ.02.1.01/0.0/0.0/15_003/0000468), the Ministry of Education, Youth and Sports of the Czech Republic through the e-INFRA CZ (ID:90140), the French government under management of Agence Nationale de la Recherche as part of the “Investissements d’avenir” program, reference ANR19-P3IA-0001 (PRAIRIE 3IA Institute), and Louis Vuitton ENS Chair on Artificial Intelligence. 99 | 100 | The ordering constraint code has been adapted from the CVPR 2022 paper 101 | [Look for the Change: Learning Object States and State-Modifying Actions from Untrimmed Web Videos](https://soczech.github.io/look-for-the-change/) 102 | available on [GitHub](https://github.com/soCzech/LookForTheChange). 103 | -------------------------------------------------------------------------------- /cuda_ops/python/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import _lookforthechange_ops 3 | 4 | 5 | def optimal_state_change(state_tensor, action_tensor, lens, delta, kappa, max_action_state_distance=500): 6 | return _lookforthechange_ops.optimal_state_change( 7 | state_tensor.contiguous(), action_tensor.contiguous(), lens, delta, kappa, max_action_state_distance) 8 | 9 | 10 | def optimal_state_change_indices(state_tensor, action_tensor, lens, max_action_state_distance=500): 11 | return _lookforthechange_ops.optimal_state_change_indices( 12 | state_tensor.contiguous(), action_tensor.contiguous(), lens, max_action_state_distance) 13 | -------------------------------------------------------------------------------- /cuda_ops/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils import cpp_extension 3 | 4 | library_dirs = cpp_extension.library_paths(cuda=True) 5 | include_dirs = cpp_extension.include_paths(cuda=True) 6 | 7 | print("library_dirs:", library_dirs) 8 | print("include_dirs:", include_dirs) 9 | 10 | setup( 11 | name="lookforthechange", 12 | 
version="2.0", 13 | install_requires=[ 14 | "numpy", 15 | "torch" 16 | ], 17 | ext_modules=[ 18 | cpp_extension.CUDAExtension( 19 | name='_lookforthechange_ops', 20 | sources=[ 21 | 'src/common.cpp', 22 | 'src/optimal_state_change.cu', 23 | 'src/optimal_state_change_indices.cu', 24 | ], 25 | library_dirs=library_dirs, 26 | include_dirs=include_dirs 27 | ) 28 | ], 29 | packages=['lookforthechange'], 30 | package_dir={'lookforthechange': './python'}, 31 | cmdclass={'build_ext': cpp_extension.BuildExtension} 32 | ) 33 | -------------------------------------------------------------------------------- /cuda_ops/src/common.cpp: -------------------------------------------------------------------------------- 1 | #include "common.hpp" 2 | 3 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 4 | m.def("optimal_state_change", &optimal_state_change, "Optimal State1-Action-State2 Sequence (GPU)"); 5 | m.def("optimal_state_change_indices", &optimal_state_change_indices, "Optimal State1-Action-State2 Sequence (GPU)"); 6 | } 7 | -------------------------------------------------------------------------------- /cuda_ops/src/common.hpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | // C++ interface 4 | #define CHECK_CPU(x) TORCH_CHECK(!x.type().is_cuda(), #x " must be a CPU tensor") 5 | #define CHECK_GPU(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a GPU tensor") 6 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 7 | #define CHECK_INPUT(x) CHECK_CPU(x); CHECK_CONTIGUOUS(x) 8 | #define CHECK_CUDA_INPUT(x) CHECK_GPU(x); CHECK_CONTIGUOUS(x) 9 | 10 | std::vector<torch::Tensor> optimal_state_change( 11 | torch::Tensor state_tensor, torch::Tensor action_tensor, torch::Tensor lens, int delta, int kappa, int max_action_state_distance); 12 | 13 | torch::Tensor optimal_state_change_indices( 14 | torch::Tensor state_tensor, torch::Tensor action_tensor, torch::Tensor lens, int max_action_state_distance); 15 | 
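For orientation, the search that `optimal_state_change_indices` runs on the GPU can be mirrored by a brute-force CPU loop. The sketch below is not part of the repository: the function name `optimal_state_change_indices_ref`, its single-video NumPy signature, and the `None` return for very short videos are assumptions made for illustration. It follows the loop bounds of the CUDA kernels in the `.cu` files and assumes the scores are positive (e.g. probabilities), in which case scanning all triples gives the same maximum as the per-thread search followed by the reduction.

```python
import numpy as np


def optimal_state_change_indices_ref(state, action, length, max_dist=500):
    """Hypothetical CPU reference for a single video.

    state:  array of shape [T, 2] (column 0 = initial state, column 1 = end state)
    action: array of shape [T]
    length: number of valid frames (<= T)
    Returns (state1_pos, state2_pos, action_pos), or None if length < 3.
    """
    best, best_score = None, -np.inf
    # state1 must leave room for one action frame and one state2 frame.
    for s1 in range(length - 2):
        # The action lies strictly after state1, at most max_dist frames away.
        for a in range(s1 + 1, min(s1 + max_dist, length - 2) + 1):
            # state2 lies strictly after the action, at most max_dist frames away.
            for s2 in range(a + 1, min(a + max_dist, length - 1) + 1):
                score = state[s1, 0] * action[a] * state[s2, 1]
                if score > best_score:  # strict '>' keeps the first maximum
                    best_score, best = score, (s1, s2, a)
    return best


if __name__ == "__main__":
    state = np.array([[0.9, 0.1], [0.5, 0.5], [0.2, 0.8], [0.1, 0.9]])
    action = np.array([0.1, 0.2, 0.9, 0.1])
    print(optimal_state_change_indices_ref(state, action, length=4))
```

The returned order (state1, state2, action) matches the layout the indices kernel writes into `out_indices`. A loop like this is only practical for short videos, but it can serve as a sanity check against the compiled extension.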
-------------------------------------------------------------------------------- /cuda_ops/src/optimal_state_change.cu: -------------------------------------------------------------------------------- 1 | #include "common.hpp" 2 | 3 | __global__ void SingleStateChangeKernel( 4 | const float* state_tensor, 5 | const float* action_tensor, 6 | const int* lens, 7 | int* state_targets, 8 | int* action_targets, 9 | const int delta, 10 | const int kappa, 11 | const int max_action_state_distance 12 | ) { 13 | const int batch_idx = blockIdx.x; 14 | const int video_len = blockDim.x; 15 | const int state1_pos = threadIdx.x; 16 | const int actual_len = lens[batch_idx]; 17 | 18 | // get pointer to shared memory 19 | extern __shared__ char shared_mem[]; 20 | int* state1_to_action_pos = reinterpret_cast<int*>(shared_mem); 21 | int* state1_to_state2_pos = state1_to_action_pos + video_len; 22 | float* state1_to_score = reinterpret_cast<float*>(state1_to_state2_pos + video_len); 23 | float* action_tensor_shared = state1_to_score + video_len; 24 | float* state_tensor_shared = action_tensor_shared + video_len; 25 | 26 | // load action and state tensors into shared memory 27 | action_tensor_shared[state1_pos] = action_tensor[batch_idx * video_len + state1_pos]; 28 | state_tensor_shared[2 * state1_pos + 0] = state_tensor[batch_idx * video_len * 2 + state1_pos * 2 + 0]; 29 | state_tensor_shared[2 * state1_pos + 1] = state_tensor[batch_idx * video_len * 2 + state1_pos * 2 + 1]; 30 | 31 | __syncthreads(); 32 | 33 | float best_score = -std::numeric_limits<float>::infinity(); 34 | int best_action_pos = 0, best_state2_pos = 0; // position of states/action for videos shorter than 3 35 | 36 | for (int action_pos = state1_pos + 1; action_pos <= state1_pos + max_action_state_distance && action_pos < actual_len - 1; ++action_pos) { // -1: need at least one position for state2 37 | float action_score = action_tensor_shared[action_pos]; 38 | 39 | for (int state2_pos = action_pos + 1; state2_pos <= action_pos + 
max_action_state_distance && state2_pos < actual_len; ++state2_pos) { 40 | float state2_score = state_tensor_shared[2 * state2_pos + 1]; // 2 states, +1 for second state 41 | 42 | float score = action_score * state2_score; 43 | if (score > best_score) { 44 | best_score = score; 45 | best_action_pos = action_pos; 46 | best_state2_pos = state2_pos; 47 | } 48 | } 49 | } 50 | 51 | state1_to_action_pos[state1_pos] = best_action_pos; 52 | state1_to_state2_pos[state1_pos] = best_state2_pos; 53 | state1_to_score[state1_pos] = best_score * state_tensor_shared[2 * state1_pos + 0]; 54 | 55 | __syncthreads(); 56 | 57 | if (state1_pos == 0) { // compute reduction only on the first thread 58 | best_score = state1_to_score[0]; 59 | int best_state1_pos = 0; 60 | for (int i = 1; i < actual_len - 2; ++i) { // -2: need at least one position for action and one for state2 61 | if (best_score < state1_to_score[i]) { 62 | best_state1_pos = i; 63 | best_score = state1_to_score[i]; 64 | } 65 | } 66 | best_action_pos = state1_to_action_pos[best_state1_pos]; 67 | best_state2_pos = state1_to_state2_pos[best_state1_pos]; 68 | 69 | // FILL state_targets TENSOR 70 | // 0 .. default - no label 71 | // 1 .. initial state label 72 | // 2 .. end state label 73 | for (int i = best_state1_pos - delta; i <= best_state1_pos + delta; ++i) { 74 | if (i < 0 || i >= actual_len) continue; 75 | state_targets[batch_idx * video_len + i] = 1; 76 | } 77 | for (int i = best_state2_pos - delta; i <= best_state2_pos + delta; ++i) { 78 | if (i < 0 || i >= actual_len) continue; 79 | state_targets[batch_idx * video_len + i] = 2; 80 | } 81 | 82 | // FILL action_targets TENSOR 83 | // 0 .. default - no label 84 | // 1 .. no-action label 85 | // 2 .. 
action label 86 | for (int i = 0; i <= delta; ++i) { 87 | int j = best_action_pos - i - kappa; 88 | if (j < 0) { 89 | action_targets[batch_idx * video_len + 0] = 1; 90 | } else { 91 | action_targets[batch_idx * video_len + j] = 1; 92 | } 93 | 94 | int k = best_action_pos + i + kappa; 95 | if (k >= actual_len) { 96 | action_targets[batch_idx * video_len + actual_len - 1] = 1; 97 | } else { 98 | action_targets[batch_idx * video_len + k] = 1; 99 | } 100 | } 101 | for (int i = best_action_pos - delta; i <= best_action_pos + delta; ++i) { 102 | if (i < 0 || i >= actual_len) continue; 103 | action_targets[batch_idx * video_len + i] = 2; 104 | } 105 | } 106 | } 107 | 108 | std::vector<torch::Tensor> optimal_state_change( 109 | torch::Tensor state_tensor, torch::Tensor action_tensor, torch::Tensor lens, int delta, int kappa, int max_action_state_distance) { 110 | 111 | CHECK_CUDA_INPUT(state_tensor); 112 | CHECK_CUDA_INPUT(action_tensor); 113 | CHECK_CUDA_INPUT(lens); 114 | 115 | int batch_size = state_tensor.size(0); 116 | int video_len = state_tensor.size(1); 117 | 118 | TORCH_CHECK(state_tensor.size(2) == 2, "state_tensor must be of shape [batch, video_len, 2]") 119 | TORCH_CHECK(action_tensor.size(2) == 1, "action_tensor must be of shape [batch, video_len, 1]") 120 | 121 | auto options = torch::TensorOptions().dtype(torch::kInt).device(torch::kCUDA); 122 | auto state_targets = torch::zeros({batch_size, video_len}, options); 123 | auto action_targets = torch::zeros({batch_size, video_len}, options); 124 | 125 | const int threads = video_len; 126 | const int blocks = batch_size; 127 | // store in shared memory: 128 | // best action position for each state1 position (1x int) 129 | // best state2 position for each state1 position (1x int) 130 | // best score for each state1 position (1x float) 131 | // action tensor (1x float) 132 | // state tensor (2x float) 133 | const int shared_mem = video_len * (2 * sizeof(int) + 4 * sizeof(float)); 134 | SingleStateChangeKernel<<<blocks, threads, shared_mem>>>( 135 | 
state_tensor.data_ptr<float>(), 136 | action_tensor.data_ptr<float>(), 137 | lens.data_ptr<int>(), 138 | state_targets.data_ptr<int>(), 139 | action_targets.data_ptr<int>(), 140 | delta, 141 | kappa, 142 | max_action_state_distance); 143 | 144 | return std::vector<torch::Tensor>{state_targets, action_targets}; 145 | } 146 | -------------------------------------------------------------------------------- /cuda_ops/src/optimal_state_change_indices.cu: -------------------------------------------------------------------------------- 1 | #include "common.hpp" 2 | 3 | __global__ void SingleStateChangeIndicesKernel( 4 | const float* state_tensor, 5 | const float* action_tensor, 6 | const int* lens, 7 | int* out_indices, 8 | const int max_action_state_distance 9 | ) { 10 | const int batch_idx = blockIdx.x; 11 | const int video_len = blockDim.x; 12 | const int state1_pos = threadIdx.x; 13 | const int actual_len = lens[batch_idx]; 14 | 15 | // get pointer to shared memory 16 | extern __shared__ char shared_mem[]; 17 | int* state1_to_action_pos = reinterpret_cast<int*>(shared_mem); 18 | int* state1_to_state2_pos = state1_to_action_pos + video_len; 19 | float* state1_to_score = reinterpret_cast<float*>(state1_to_state2_pos + video_len); 20 | float* action_tensor_shared = state1_to_score + video_len; 21 | float* state_tensor_shared = action_tensor_shared + video_len; 22 | 23 | // load action and state tensors into shared memory 24 | action_tensor_shared[state1_pos] = action_tensor[batch_idx * video_len + state1_pos]; 25 | state_tensor_shared[2 * state1_pos + 0] = state_tensor[batch_idx * video_len * 2 + state1_pos * 2 + 0]; 26 | state_tensor_shared[2 * state1_pos + 1] = state_tensor[batch_idx * video_len * 2 + state1_pos * 2 + 1]; 27 | 28 | __syncthreads(); 29 | 30 | float best_score = -std::numeric_limits<float>::infinity(); 31 | int best_action_pos = 0, best_state2_pos = 0; // position of states/action for videos shorter than 3 32 | 33 | for (int action_pos = state1_pos + 1; action_pos <= state1_pos + max_action_state_distance && 
action_pos < actual_len - 1; ++action_pos) { // -1: need at least one position for state2 34 | float action_score = action_tensor_shared[action_pos]; 35 | 36 | for (int state2_pos = action_pos + 1; state2_pos <= action_pos + max_action_state_distance && state2_pos < actual_len; ++state2_pos) { 37 | float state2_score = state_tensor_shared[2 * state2_pos + 1]; // 2 states, +1 for second state 38 | 39 | float score = action_score * state2_score; 40 | if (score > best_score) { 41 | best_score = score; 42 | best_action_pos = action_pos; 43 | best_state2_pos = state2_pos; 44 | } 45 | } 46 | } 47 | 48 | state1_to_action_pos[state1_pos] = best_action_pos; 49 | state1_to_state2_pos[state1_pos] = best_state2_pos; 50 | state1_to_score[state1_pos] = best_score * state_tensor_shared[2 * state1_pos + 0]; 51 | 52 | __syncthreads(); 53 | 54 | if (state1_pos == 0) { // compute reduction only on the first thread 55 | best_score = state1_to_score[0]; 56 | int best_state1_pos = 0; 57 | for (int i = 1; i < actual_len - 2; ++i) { // -2: need at least one position for action and one for state2 58 | if (best_score < state1_to_score[i]) { 59 | best_state1_pos = i; 60 | best_score = state1_to_score[i]; 61 | } 62 | } 63 | best_action_pos = state1_to_action_pos[best_state1_pos]; 64 | best_state2_pos = state1_to_state2_pos[best_state1_pos]; 65 | 66 | out_indices[batch_idx * 3 + 0] = best_state1_pos; 67 | out_indices[batch_idx * 3 + 1] = best_state2_pos; 68 | out_indices[batch_idx * 3 + 2] = best_action_pos; 69 | } 70 | } 71 | 72 | torch::Tensor optimal_state_change_indices( 73 | torch::Tensor state_tensor, torch::Tensor action_tensor, torch::Tensor lens, int max_action_state_distance) { 74 | 75 | CHECK_CUDA_INPUT(state_tensor); 76 | CHECK_CUDA_INPUT(action_tensor); 77 | CHECK_CUDA_INPUT(lens); 78 | 79 | int batch_size = state_tensor.size(0); 80 | int video_len = state_tensor.size(1); 81 | 82 | TORCH_CHECK(state_tensor.size(2) == 2, "state_tensor must be of shape [batch, video_len, 2]") 83 | 
TORCH_CHECK(action_tensor.size(2) == 1, "action_tensor must be of shape [batch, video_len, 1]") 84 | 85 | auto options = torch::TensorOptions().dtype(torch::kInt).device(torch::kCUDA); 86 | auto out_indices = torch::zeros({batch_size, 3}, options); 87 | 88 | const int threads = video_len; 89 | const int blocks = batch_size; 90 | // store in shared memory: 91 | // best action position for each state1 position (1x int) 92 | // best state2 position for each state1 position (1x int) 93 | // best score for each state1 position (1x float) 94 | // action tensor (1x float) 95 | // state tensor (2x float) 96 | const int shared_mem = video_len * (2 * sizeof(int) + 4 * sizeof(float)); 97 | SingleStateChangeIndicesKernel<<<blocks, threads, shared_mem>>>( 98 | state_tensor.data_ptr<float>(), 99 | action_tensor.data_ptr<float>(), 100 | lens.data_ptr<int>(), 101 | out_indices.data_ptr<int>(), 102 | max_action_state_distance); 103 | 104 | return out_indices; 105 | } 106 | -------------------------------------------------------------------------------- /download_requirements.sh: -------------------------------------------------------------------------------- 1 | echo -n "Downloading and patching clip repository ... " 2 | git clone --quiet https://github.com/openai/CLIP.git || exit 1 3 | cd CLIP || exit 1 4 | git checkout --quiet d50d76daa670286dd6cacf3bcd80b5e4823fc8e1 || exit 1 5 | sed -i /self.proj/d clip/model.py || exit 1 6 | cd .. 7 | echo "OK ✓" 8 | 9 | echo -n "Downloading clip weights ... " 10 | mkdir -p weights 11 | wget https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt -q -O weights/ViT-L-14.pt 12 | SHA256SUM=$(sha256sum weights/ViT-L-14.pt | cut -d' ' -f1) 13 | 14 | if [[ ${SHA256SUM} == "b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836" ]]; then 15 | echo "OK ✓" 16 | else 17 | echo "ERROR ✗" 18 | exit 1 19 | fi 20 | 21 | echo -n "Downloading ChangeIt annotations ... 
" 22 | mkdir -p videos 23 | git clone --quiet https://github.com/soCzech/ChangeIt.git || exit 1 24 | echo "OK ✓" 25 | 26 | echo "To replicate our experiments, please download ChangeIt videos into \`videos/*category*\` folders." 27 | echo "More details on how to download the videos at https://github.com/soCzech/ChangeIt." 28 | echo "If you wish to train the model on your data, please see the README file." 29 | -------------------------------------------------------------------------------- /method/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import time 4 | import torch 5 | import ffmpeg 6 | import random 7 | import itertools 8 | import numpy as np 9 | import torch.distributed as dist 10 | from torch.utils.data import Dataset 11 | from typing import TypeVar, Optional, Iterator 12 | 13 | import imgaug as ia 14 | import imgaug.augmenters as iaa 15 | 16 | 17 | def extract_frames(video_path, fps, size=None, crop=None, start=None, duration=None): 18 | if start is not None: 19 | cmd = ffmpeg.input(video_path, ss=start, t=duration) 20 | else: 21 | cmd = ffmpeg.input(video_path) 22 | 23 | if size is None: 24 | info = [s for s in ffmpeg.probe(video_path)["streams"] if s["codec_type"] == "video"][0] 25 | size = (info["width"], info["height"]) 26 | elif isinstance(size, int): 27 | size = (size, size) 28 | 29 | if fps is not None: 30 | cmd = cmd.filter('fps', fps=fps) 31 | cmd = cmd.filter('scale', size[0], size[1]) 32 | 33 | if crop is not None: 34 | cmd = cmd.filter('crop', f'in_w-{crop[0]}', f'in_h-{crop[1]}') 35 | size = (size[0] - crop[0], size[1] - crop[1]) 36 | 37 | for i in range(5): 38 | try: 39 | out, _ = ( 40 | cmd.output('pipe:', format='rawvideo', pix_fmt='rgb24') 41 | .run(capture_stdout=True, quiet=True) 42 | ) 43 | break 44 | except Exception as e: 45 | time.sleep(random.random() * 5.) 
46 | if i < 4: 47 | continue 48 | print(f"W: FFMPEG file {video_path} read failed!", flush=True) 49 | if isinstance(e, ffmpeg.Error): 50 | print("STDOUT:", e.stdout, flush=True) 51 | print("STDERR:", e.stderr, flush=True) 52 | raise 53 | 54 | video = np.frombuffer(out, np.uint8).reshape([-1, size[1], size[0], 3]) 55 | return video 56 | 57 | 58 | class ChangeItVideoDataset(Dataset): 59 | 60 | def __init__(self, 61 | video_roots, 62 | annotation_root=None, 63 | file_mode="unannotated", # "unannotated", "annotated", "all" 64 | noise_adapt_weight_root=None, 65 | noise_adapt_weight_threshold_file=None, 66 | augment=False): 67 | 68 | self.classes = {x: i for i, x in enumerate(sorted(set([os.path.basename(fn) for fn in itertools.chain(*[ 69 | glob.glob(os.path.join(root, "*")) for root in video_roots 70 | ]) if os.path.isdir(fn)])))} 71 | 72 | self.files = {key: sorted(itertools.chain(*[ 73 | glob.glob(os.path.join(root, key, "*.mp4")) + glob.glob(os.path.join(root, key, "*.webm")) for root in 74 | video_roots 75 | ])) for key in self.classes.keys()} 76 | 77 | self.annotations = {key: { 78 | os.path.basename(fn).split(".")[0]: np.uint8( 79 | [int(line.strip().split(",")[1]) for line in open(fn).readlines()]) 80 | for fn in glob.glob(os.path.join(annotation_root, key, "*.csv")) 81 | } for key in self.classes.keys()} if annotation_root is not None else None 82 | 83 | if file_mode == "unannotated": 84 | for key in self.classes.keys(): 85 | for fn in self.files[key].copy(): 86 | if os.path.basename(fn).split(".")[0] in self.annotations[key]: 87 | self.files[key].remove(fn) 88 | elif file_mode == "annotated": 89 | for key in self.classes.keys(): 90 | for fn in self.files[key].copy(): 91 | if os.path.basename(fn).split(".")[0] not in self.annotations[key]: 92 | self.files[key].remove(fn) 93 | elif file_mode == "all": 94 | pass 95 | else: 96 | raise NotImplementedError() 97 | 98 | self.flattened_files = [] 99 | for key in self.classes.keys(): 100 | 
self.flattened_files.extend([(key, fn) for fn in self.files[key]]) 101 | 102 | self.augment = augment 103 | 104 | # Noise adaptive weighting 105 | if noise_adapt_weight_root is None: 106 | return 107 | 108 | self.noise_adapt_weight = {} 109 | for key in self.classes.keys(): 110 | with open(os.path.join(noise_adapt_weight_root, f"{key}.csv"), "r") as f: 111 | for line in f.readlines(): 112 | vid_id, score = line.strip().split(",") 113 | self.noise_adapt_weight[f"{key}/{vid_id}"] = float(score) 114 | 115 | self.noise_adapt_weight_thr = {line.split(",")[0]: float(line.split(",")[2].strip()) 116 | for line in open(noise_adapt_weight_threshold_file, "r").readlines()[1:]} 117 | 118 | def __getitem__(self, idx): 119 | class_name, video_fn = self.flattened_files[idx] 120 | file_id = os.path.basename(video_fn).split(".")[0] 121 | 122 | video_frames = extract_frames(video_fn, fps=1, size=(398, 224)).copy() 123 | 124 | if self.augment: 125 | video_frames = ChangeItVideoDataset.augment_fc(video_frames) 126 | video_frames = video_frames[:, :, random.randint(0, 398 - 224 - 1):][:, :, :224] 127 | else: 128 | video_frames = video_frames[:, :, (398 - 224) // 2:][:, :, :224] 129 | video_frames = torch.from_numpy(video_frames.copy()) 130 | 131 | annotation = self.annotations[class_name][file_id] \ 132 | if self.annotations is not None and file_id in self.annotations[class_name] else None 133 | video_level_score = self.noise_adapt_weight[f"{class_name}/{file_id}"] - self.noise_adapt_weight_thr[class_name] \ 134 | if hasattr(self, "noise_adapt_weight") else None 135 | 136 | return class_name + "/" + file_id, self.classes[class_name], video_frames, annotation, video_level_score 137 | 138 | @property 139 | def n_classes(self): 140 | return len(self.classes) 141 | 142 | def __len__(self): 143 | return len(self.flattened_files) 144 | 145 | def __repr__(self): 146 | string = f"ChangeItVideoDataset(n_classes: {self.n_classes}, n_samples: {self.__len__()}, " \ 147 | f"augment: 
{self.augment})" 148 | for key in sorted(self.classes.keys()): 149 | string += f"\n> {key:20} {len(self.files[key]):4d}" 150 | if hasattr(self, "noise_adapt_weight_thr"): 151 | len_ = len([ 152 | fn for fn in self.files[key] 153 | if self.noise_adapt_weight[f"{key}/{os.path.basename(fn).split('.')[0]}"] > 154 | self.noise_adapt_weight_thr[key] 155 | ]) 156 | string += f" (above threshold {self.noise_adapt_weight_thr[key]:.3f}: {len_:4d})" 157 | return string 158 | 159 | @staticmethod 160 | def augment_fc(video_frames): 161 | seq = iaa.Sequential( 162 | [ 163 | # apply the following augmenters to most images 164 | iaa.Fliplr(0.5), # horizontally flip 50% of all images 165 | iaa.Flipud(0.1), # vertically flip 10% of all images 166 | # crop images by -5% to 10% of their height/width 167 | iaa.Sometimes(0.5, iaa.CropAndPad( 168 | percent=(-0.05, 0.1), 169 | pad_mode=ia.ALL, 170 | pad_cval=(0, 255) 171 | )), 172 | iaa.Sometimes(0.5, iaa.Affine( 173 | scale={"x": (0.8, 1.2), "y": (0.8, 1.2)}, 174 | # scale images to 80-120% of their size, individually per axis 175 | translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)}, 176 | # translate by -20 to +20 percent (per axis) 177 | rotate=(-15, 15), # rotate by -15 to +15 degrees 178 | shear=(-16, 16), # shear by -16 to +16 degrees 179 | order=[0, 1], # use nearest neighbour or bilinear interpolation (fast) 180 | cval=(0, 255), # if mode is constant, use a cval between 0 and 255 181 | mode=ia.ALL # use any of scikit-image's warping modes (see 2nd image from the top for examples) 182 | )), 183 | # execute 0 to 4 of the following (less important) augmenters per image 184 | # don't execute all of them, as that would often be way too strong 185 | iaa.SomeOf((0, 4), [ 186 | iaa.OneOf([ 187 | iaa.GaussianBlur((0, 3.0)), # blur images with a sigma between 0 and 3.0 188 | iaa.AverageBlur(k=(2, 3)), 189 | # blur image using local means with kernel sizes between 2 and 3 190 | iaa.MedianBlur(k=(3, 5)), 191 | # blur image using local 
medians with kernel sizes between 3 and 5 192 | ]), 193 | iaa.Sharpen(alpha=(0, 1.0), lightness=(0.75, 1.5)), # sharpen images 194 | iaa.Add((-10, 10), per_channel=0.5), 195 | # change brightness of images (by -10 to 10 of original value) 196 | iaa.AddToHueAndSaturation((-20, 20)), # change hue and saturation 197 | ], random_order=True) 198 | ], 199 | random_order=True 200 | ) 201 | seq_det = seq.to_deterministic() 202 | 203 | video_frames_augmented = np.empty_like(video_frames) 204 | for i in range(len(video_frames)): 205 | video_frames_augmented[i] = seq_det.augment_image(video_frames[i]) 206 | return video_frames_augmented 207 | 208 | 209 | def identity_collate(items): 210 | return items 211 | 212 | 213 | T_co = TypeVar('T_co', covariant=True) 214 | 215 | 216 | class DistributedDropFreeSampler: 217 | 218 | def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None, 219 | rank: Optional[int] = None, shuffle: bool = True, 220 | seed: int = 0) -> None: 221 | if num_replicas is None: 222 | if not dist.is_available(): 223 | raise RuntimeError("Requires distributed package to be available") 224 | num_replicas = dist.get_world_size() 225 | if rank is None: 226 | if not dist.is_available(): 227 | raise RuntimeError("Requires distributed package to be available") 228 | rank = dist.get_rank() 229 | if rank >= num_replicas or rank < 0: 230 | raise ValueError( 231 | "Invalid rank {}, rank should be in the interval" 232 | " [0, {}]".format(rank, num_replicas - 1)) 233 | self.dataset = dataset 234 | self.num_replicas = num_replicas 235 | self.rank = rank 236 | self.epoch = 0 237 | 238 | self.num_samples = len(self.dataset) // self.num_replicas # type: ignore[arg-type] 239 | if self.num_samples * self.num_replicas < len(self.dataset): # type: ignore[arg-type] 240 | if self.rank < len(self.dataset) % self.num_replicas: # type: ignore[arg-type] 241 | self.num_samples += 1 242 | self.total_size = len(self.dataset) # type: ignore[arg-type] 243 | self.shuffle = shuffle 
244 | self.seed = seed 245 | 246 | def __iter__(self) -> Iterator[T_co]: 247 | if self.shuffle: 248 | # deterministically shuffle based on epoch and seed 249 | g = torch.Generator() 250 | g.manual_seed(self.seed + self.epoch) 251 | indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] 252 | else: 253 | indices = list(range(len(self.dataset))) # type: ignore[arg-type] 254 | assert len(indices) == self.total_size 255 | 256 | # subsample 257 | indices = indices[self.rank:self.total_size:self.num_replicas] 258 | assert len(indices) == self.num_samples 259 | 260 | return iter(indices) 261 | 262 | def __len__(self) -> int: 263 | return self.num_samples 264 | 265 | def set_epoch(self, epoch: int) -> None: 266 | self.epoch = epoch 267 | -------------------------------------------------------------------------------- /method/model.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | sys.path.append("./CLIP") 4 | from clip.model import VisionTransformer 5 | 6 | 7 | class ClassifierHeads(torch.nn.Module): 8 | 9 | def __init__(self, layers, n_classes=1): 10 | super(ClassifierHeads, self).__init__() 11 | 12 | self.state = torch.nn.ModuleDict({ 13 | f"l{i:d}": torch.nn.Linear(layers[i], layers[i + 1]) 14 | for i in range(len(layers) - 1) 15 | }) 16 | self.action = torch.nn.ModuleDict({ 17 | f"l{i:d}": torch.nn.Linear(layers[i], layers[i + 1]) 18 | for i in range(len(layers) - 1) 19 | }) 20 | 21 | self.state_layer = torch.nn.Linear(layers[-1], 2 * n_classes + 1, bias=True) 22 | self.action_layer = torch.nn.Linear(layers[-1], 1 * n_classes + 1, bias=True) 23 | 24 | def forward(self, inputs): 25 | x = inputs 26 | for i in range(len(self.state)): 27 | x = self.state[f"l{i:d}"](x) 28 | x = torch.relu(x) 29 | state = self.state_layer(x) 30 | 31 | x = inputs 32 | for i in range(len(self.action)): 33 | x = self.action[f"l{i:d}"](x) 34 | x = torch.relu(x) 35 | action = 
self.action_layer(x) 36 | 37 | return {"state": state, "action": action} 38 | 39 | 40 | class ClipClassifier(torch.nn.Module): 41 | 42 | def __init__(self, hidden_mlp_layers, params, n_classes=1, train_backbone=True): 43 | super(ClipClassifier, self).__init__() 44 | 45 | self.args = params 46 | self.n_classes = n_classes 47 | self.hidden_mlp_layers = hidden_mlp_layers 48 | 49 | if "visual.conv1.weight" in params: 50 | vision_width = params["visual.conv1.weight"].shape[0] 51 | vision_heads = vision_width // 64 52 | vision_layers = len([k for k in params.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 53 | vision_patch_size = params["visual.conv1.weight"].shape[-1] 54 | grid_size = round((params["visual.positional_embedding"].shape[0] - 1) ** 0.5) 55 | image_resolution = vision_patch_size * grid_size 56 | self.args = dict( 57 | input_resolution=image_resolution, 58 | patch_size=vision_patch_size, 59 | width=vision_width, 60 | layers=vision_layers, 61 | heads=vision_heads, 62 | output_dim=None 63 | ) 64 | 65 | self.backbone = VisionTransformer(**self.args) 66 | 67 | if "visual.conv1.weight" in params: 68 | self.backbone.load_state_dict({ 69 | k[len("visual."):]: v for k, v in params.items() 70 | if k.startswith("visual.") and k != "visual.proj" 71 | }) 72 | 73 | if not train_backbone: 74 | for param in self.backbone.parameters(): 75 | param.requires_grad_(False) 76 | 77 | self.heads = ClassifierHeads([self.args["width"]] + hidden_mlp_layers, n_classes) 78 | 79 | @staticmethod 80 | def preprocess(imgs): 81 | # CLIP image preprocessing 82 | imgs = imgs.permute((0, 3, 1, 2)).float().div_(255) 83 | mean = torch.as_tensor((0.48145466, 0.4578275, 0.40821073), 84 | dtype=torch.float32, device=imgs.device).view(1, -1, 1, 1) 85 | std = torch.as_tensor((0.26862954, 0.26130258, 0.27577711), 86 | dtype=torch.float32, device=imgs.device).view(1, -1, 1, 1) 87 | return imgs.sub_(mean).div_(std) 88 | 89 | def forward(self, inputs): 90 | x = 
self.preprocess(inputs) 91 | x = self.backbone(x) 92 | x = self.heads(x) 93 | return x 94 | -------------------------------------------------------------------------------- /method/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import math 4 | import torch 5 | import pickle 6 | import numpy as np 7 | 8 | 9 | def select_correct_classes(predictions, classes, n_classes=1): 10 | if n_classes == 1: 11 | return predictions 12 | # predictions: B, L, n_classes * third_dim 13 | B, L, dim = predictions.shape 14 | third_dim = dim // n_classes 15 | 16 | x = torch.arange(0, B).view(-1, 1, 1).repeat(1, L, third_dim) 17 | y = torch.arange(0, L).view(1, -1, 1).repeat(B, 1, third_dim) 18 | z = classes.view(-1, 1, 1).repeat(1, L, third_dim) * third_dim + \ 19 | torch.arange(0, third_dim, device=classes.device).view(1, 1, third_dim) 20 | 21 | return predictions[x, y, z] # B, L, third_dim 22 | 23 | 24 | def constrained_argmax(pred_action, pred_state, also_only_state=False): 25 | max_val_, best_idx_ = -1, (0, 0, 0) 26 | for i in range(len(pred_state)): 27 | for j in range(i + 2, len(pred_state)): 28 | val_ = pred_state[i, 0] * pred_state[j, 1] 29 | k = np.argmax(pred_action[i + 1:j]) 30 | val_ *= pred_action[i + 1 + k] 31 | if val_ > max_val_: 32 | best_idx_ = i, j, i + 1 + k 33 | max_val_ = val_ 34 | if not also_only_state: 35 | return best_idx_ 36 | 37 | max_val_state_, best_idx_state_ = -1, (0, 0) 38 | for i in range(len(pred_state)): 39 | for j in range(i + 1, len(pred_state)): 40 | val_ = pred_state[i, 0] * pred_state[j, 1] 41 | if val_ > max_val_state_: 42 | best_idx_state_ = i, j 43 | max_val_state_ = val_ 44 | return best_idx_, best_idx_state_ 45 | 46 | 47 | def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1): 48 | 49 | def lr_lambda(current_step): 50 | if current_step < num_warmup_steps: 51 | return float(current_step) / float(max(1, 
num_warmup_steps)) 52 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 53 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) 54 | 55 | return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch) 56 | 57 | 58 | class AverageMeter: 59 | 60 | def __init__(self): 61 | self.reset() 62 | 63 | def reset(self): 64 | self._sum = 0 65 | self._count = 0 66 | 67 | def update(self, val, n=1): 68 | self._sum += val * n 69 | self._count += n 70 | 71 | @property 72 | def value(self): 73 | if self._count == 0: 74 | return 0 75 | return self._sum / self._count 76 | 77 | 78 | class JointMeter: 79 | def __init__(self, n_classes): 80 | self._dict = { 81 | "sp": [[] for _ in range(n_classes)], 82 | "ap": [[] for _ in range(n_classes)], 83 | "jsp": [[] for _ in range(n_classes)], 84 | "jap": [[] for _ in range(n_classes)], 85 | "acc": [[] for _ in range(n_classes)] 86 | } 87 | 88 | def log(self, pred_action, pred_state, annotations, category): 89 | assert len(annotations) == len(pred_action) == len(pred_state) 90 | 91 | # state accuracy 92 | pred_state_idx = np.argmax(pred_state, axis=-1) 93 | n_gt_states = np.logical_or(annotations == 1, annotations == 3).sum() 94 | state_acc = ((pred_state_idx[annotations == 1] == 0).sum() + 95 | (pred_state_idx[annotations == 3] == 1).sum()) / n_gt_states if n_gt_states > 0 else 0. 96 | if n_gt_states > 0: 97 | self._dict["acc"][category].append(state_acc) 98 | 99 | # action precision 100 | self._dict["ap"][category].append(1. if annotations[np.argmax(pred_action)] == 2 else 0.) 101 | 102 | # state and joint precision 103 | joint, state_only = constrained_argmax(pred_action, pred_state, also_only_state=True) 104 | self._dict["sp"][category].append((0.5 if annotations[state_only[0]] == 1 else 0.0) + \ 105 | (0.5 if annotations[state_only[1]] == 3 else 0.0)) 106 | self._dict["jap"][category].append(1. if annotations[joint[2]] == 2 else 0.) 
107 | self._dict["jsp"][category].append((0.5 if annotations[joint[0]] == 1 else 0.0) + \ 108 | (0.5 if annotations[joint[1]] == 3 else 0.0)) 109 | 110 | def __getattr__(self, item): 111 | if item in self._dict: 112 | return np.mean([np.mean(x) for x in self._dict[item]]) * 100 113 | raise NotImplementedError() 114 | 115 | def __getitem__(self, item): 116 | return [np.mean(self._dict[k][item]) * 100 for k in ["acc", "sp", "jsp", "ap", "jap"]] 117 | 118 | def dump(self, path, global_id): 119 | with open(f"{path}.{global_id}.pickle", "wb") as f: 120 | pickle.dump(self._dict, f) 121 | 122 | def load(self, path): 123 | n_classes = len(self._dict["sp"]) 124 | self._dict = { 125 | "sp": [[] for _ in range(n_classes)], 126 | "ap": [[] for _ in range(n_classes)], 127 | "jsp": [[] for _ in range(n_classes)], 128 | "jap": [[] for _ in range(n_classes)], 129 | "acc": [[] for _ in range(n_classes)] 130 | } 131 | for fn in glob.glob(f"{path}.*.pickle"): 132 | dict_ = pickle.load(open(fn, "rb")) 133 | for k in self._dict.keys(): 134 | for i in range(n_classes): 135 | self._dict[k][i] += dict_[k][i] 136 | 137 | for fn in glob.glob(f"{path}.*.pickle"): 138 | os.remove(fn) 139 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import time 4 | import tqdm 5 | import torch 6 | import random 7 | import socket 8 | import argparse 9 | import torch.nn.functional as F 10 | import torch.distributed as dist 11 | import torch.multiprocessing as mp 12 | 13 | import lookforthechange 14 | from method.dataset import ChangeItVideoDataset, identity_collate, DistributedDropFreeSampler 15 | from method.model import ClipClassifier 16 | from method.utils import get_cosine_schedule_with_warmup, AverageMeter, select_correct_classes, JointMeter 17 | 18 | cv2.setNumThreads(1) # do not spawn multiple threads for augmentation (ffmpeg then raises an 
error) 19 | 20 | 21 | def main(args): 22 | ngpus_per_node = torch.cuda.device_count() 23 | node_count = int(os.environ.get("SLURM_NPROCS", "1")) 24 | node_rank = int(os.environ.get("SLURM_PROCID", "0")) 25 | job_id = os.environ.get("SLURM_JOBID", "0") 26 | 27 | if node_count == 1: # for PBS/PMI clusters 28 | node_count = int(os.environ.get("PMI_SIZE", "1")) 29 | node_rank = int(os.environ.get("PMI_RANK", "0")) 30 | job_id = os.environ.get("PBS_JOBID", "".join([str(random.randint(0, 9)) for _ in range(5)])) 31 | 32 | dist_url = "file://{}.{}".format(os.path.realpath("distfile"), job_id) 33 | print(f"Hi from node {socket.gethostname()} ({node_rank}/{node_count} with {ngpus_per_node} GPUs)!", flush=True) 34 | 35 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=({ 36 | "ngpus_per_node": ngpus_per_node, 37 | "node_count": node_count, 38 | "node_rank": node_rank, 39 | "dist_url": dist_url, 40 | "job_id": job_id 41 | }, args)) 42 | 43 | 44 | def main_worker(local_rank, cluster_args, args): 45 | world_size = cluster_args["node_count"] * cluster_args["ngpus_per_node"] 46 | global_rank = cluster_args["node_rank"] * cluster_args["ngpus_per_node"] + local_rank 47 | dist.init_process_group( 48 | backend="nccl", 49 | init_method=cluster_args["dist_url"], 50 | world_size=world_size, 51 | rank=global_rank, 52 | ) 53 | 54 | if global_rank == 0: 55 | for k, v in sorted(vars(args).items(), key=lambda x: x[0]): 56 | print(f"# {k}: {v}") 57 | print(f"# effective_batch_size: {world_size * args.local_batch_size}", flush=True) 58 | 59 | ############### 60 | # DATASET 61 | ############### 62 | train_ds = ChangeItVideoDataset( 63 | video_roots=args.video_roots, annotation_root=os.path.join(args.dataset_root, "annotations"), 64 | file_mode="unannotated", noise_adapt_weight_root=None if args.ignore_video_weight else os.path.join(args.dataset_root, "videos"), 65 | noise_adapt_weight_threshold_file=None if args.ignore_video_weight else os.path.join(args.dataset_root, "categories.csv"), 
augment=args.augment 66 | ) 67 | test_ds = ChangeItVideoDataset( 68 | video_roots=args.video_roots, annotation_root=os.path.join(args.dataset_root, "annotations"), 69 | file_mode="annotated", noise_adapt_weight_root=None if args.ignore_video_weight else os.path.join(args.dataset_root, "videos"), 70 | noise_adapt_weight_threshold_file=None if args.ignore_video_weight else os.path.join(args.dataset_root, "categories.csv"), augment=False 71 | ) 72 | 73 | if global_rank == 0: 74 | print(train_ds, test_ds, sep="\n", flush=True) 75 | 76 | train_sampler = torch.utils.data.distributed.DistributedSampler( 77 | train_ds, shuffle=True, drop_last=True) if world_size > 1 else None 78 | train_ds_iter = torch.utils.data.DataLoader( 79 | train_ds, batch_size=args.local_batch_size, shuffle=world_size == 1, drop_last=True, num_workers=2, 80 | pin_memory=True, sampler=train_sampler, collate_fn=identity_collate) 81 | 82 | test_sampler = DistributedDropFreeSampler(test_ds, shuffle=False) if world_size > 1 else None 83 | test_ds_iter = torch.utils.data.DataLoader( 84 | test_ds, batch_size=1, shuffle=False, drop_last=False, num_workers=2, 85 | pin_memory=True, sampler=test_sampler, collate_fn=identity_collate) 86 | 87 | ############### 88 | # MODEL 89 | ############### 90 | weights = torch.jit.load(args.clip_weights, map_location="cpu").state_dict() 91 | model = ClipClassifier(hidden_mlp_layers=[4096], 92 | params=weights, 93 | n_classes=train_ds.n_classes, 94 | train_backbone=args.train_backbone) 95 | assert model.backbone.input_resolution == 224 96 | 97 | torch.cuda.set_device(local_rank) 98 | model.cuda(local_rank) 99 | model_parallel = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank]) 100 | 101 | ############### 102 | # OPTIMIZER 103 | ############### 104 | head_params = model_parallel.module.heads.parameters() 105 | optim_head = torch.optim.SGD(head_params, lr=args.lr, momentum=0.9, weight_decay=args.weight_decay) 106 | scheduler_head = 
get_cosine_schedule_with_warmup(optim_head, 5 * len(train_ds_iter), len(train_ds_iter) * args.n_epochs) 107 | 108 | if args.train_backbone: 109 | backbone_params = model_parallel.module.backbone.parameters() 110 | optim_backbone = torch.optim.AdamW(backbone_params, lr=args.lr_backbone, weight_decay=args.weight_decay_backbone) 111 | scheduler_backbone = get_cosine_schedule_with_warmup(optim_backbone, 5 * len(train_ds_iter), len(train_ds_iter) * args.n_epochs) 112 | 113 | ############### 114 | # TRAINING 115 | ############### 116 | n_frames_per_gt = args.n_frames_per_gt 117 | kappa_dist = 60 118 | 119 | loss_metric = AverageMeter() 120 | loss_norm_metric = AverageMeter() 121 | unsup_state_loss_metric = AverageMeter() 122 | unsup_action_loss_metric = AverageMeter() 123 | 124 | for epoch in range(1, args.n_epochs + 1): 125 | if world_size > 1: train_sampler.set_epoch(epoch) 126 | loss_metric.reset() 127 | loss_norm_metric.reset() 128 | unsup_state_loss_metric.reset() 129 | unsup_action_loss_metric.reset() 130 | 131 | iterator = tqdm.tqdm(train_ds_iter) if global_rank == 0 else train_ds_iter 132 | for batch in iterator: # id, class, video, annotation/None, weight 133 | 134 | optim_head.zero_grad() 135 | if args.train_backbone: optim_backbone.zero_grad() 136 | 137 | # COMPUTE GT FOR ALL VIDEOS IN BATCH 138 | batch_for_training = [] 139 | for _, class_, inputs, _, weight in batch: 140 | classes = torch.LongTensor([class_]) 141 | 142 | # PREDICT 143 | with torch.no_grad(): 144 | predictions = [] 145 | for i in range(0, len(inputs), 256): 146 | predictions += [model(inputs[i:i + 256].cuda(local_rank))] 147 | predictions = { 148 | "state": torch.cat([p["state"] for p in predictions], dim=0), 149 | "action": torch.cat([p["action"] for p in predictions], dim=0) 150 | } 151 | 152 | st_probs = select_correct_classes( 153 | torch.softmax(predictions["state"].unsqueeze(0), -1), classes, n_classes=train_ds.n_classes) 154 | ac_probs = select_correct_classes( 155 | 
torch.softmax(predictions["action"].unsqueeze(0), -1), classes, n_classes=train_ds.n_classes + 1) 156 | 157 | # COMPUTE GROUND TRUTH 158 | indices = lookforthechange.optimal_state_change_indices( 159 | st_probs, ac_probs, lens=torch.tensor([st_probs.shape[1]], dtype=torch.int32, device=st_probs.device)) 160 | indices = indices.view(1, 3).cpu() # [S1idx, S2idx, ACidx] 161 | 162 | positives = indices.repeat(n_frames_per_gt, 1) + \ 163 | torch.arange(-(n_frames_per_gt // 2), (n_frames_per_gt // 2) + 1, 1, device=indices.device).unsqueeze_(1) 164 | indices_extended = torch.cat([ 165 | positives.transpose(1, 0).reshape(-1), positives[:, 2] - kappa_dist, positives[:, 2] + kappa_dist 166 | ], 0).clamp_(0, len(inputs) - 1) 167 | # [ S1idx - 1, S1idx, S1idx + 1, 168 | # S2idx - 1, S2idx, S2idx + 1, 169 | # ACidx - 1, ACidx, ACidx + 1, 170 | # ACidx - 61, ACidx - 60, ACidx - 59, 171 | # ACidx + 59, ACidx + 60, ACidx + 61] 172 | 173 | bg_class_index = train_ds.n_classes 174 | action_targets = torch.LongTensor([bg_class_index] * n_frames_per_gt * 2 + 175 | [class_] * n_frames_per_gt + 176 | [bg_class_index] * n_frames_per_gt * 2) 177 | # [ BG, BG, BG, 178 | # BG, BG, BG, 179 | # CLS, CLS, CLS, 180 | # BG, BG, BG, 181 | # BG, BG, BG] 182 | 183 | bg_class_index = train_ds.n_classes * 2 184 | state_targets = torch.LongTensor([class_ * 2 + 0] * n_frames_per_gt + 185 | [class_ * 2 + 1] * n_frames_per_gt + 186 | [bg_class_index] * n_frames_per_gt + 187 | [bg_class_index] * n_frames_per_gt * 2) 188 | state_target_mask = torch.FloatTensor([1.] * n_frames_per_gt * 3 + [0.] 
* n_frames_per_gt * 2) 189 | # [ CS1, CS1, CS1, 190 | # CS2, CS2, CS2, 191 | # BG, BG, BG, 192 | # *, *, *, 193 | # *, *, *] 194 | 195 | batch_for_training.append(( 196 | inputs[indices_extended], action_targets, state_targets, state_target_mask, weight)) 197 | 198 | # FORWARD + BACKWARD PASS 199 | predictions = model_parallel(torch.cat([x[0] for x in batch_for_training], 0).cuda(local_rank)) 200 | 201 | if batch_for_training[0][4] is None: 202 | video_loss_weight = torch.FloatTensor([1. for _ in batch_for_training]).view(-1, 1).cuda(local_rank) 203 | else: 204 | video_loss_weight = torch.FloatTensor([x[4] for x in batch_for_training]) * (-1 / 0.001) 205 | video_loss_weight = 1 / (1 + torch.exp(video_loss_weight)) 206 | video_loss_weight = video_loss_weight.view(-1, 1).cuda(local_rank) 207 | 208 | state_gt = torch.cat([x[2] for x in batch_for_training], 0).cuda(local_rank) 209 | state_gt_mask = torch.cat([x[3] for x in batch_for_training], 0).cuda(local_rank) 210 | action_gt = torch.cat([x[1] for x in batch_for_training], 0).cuda(local_rank) 211 | 212 | state_loss = F.cross_entropy(predictions["state"], state_gt, reduction="none") * state_gt_mask 213 | state_loss = state_loss.view(-1, n_frames_per_gt * 5) * video_loss_weight 214 | action_loss = F.cross_entropy(predictions["action"], action_gt, reduction="none") 215 | action_loss = action_loss.view(-1, n_frames_per_gt * 5) * video_loss_weight 216 | 217 | state_loss = torch.sum(state_loss) 218 | action_loss = 0.2 * torch.sum(action_loss) 219 | loss = state_loss + action_loss 220 | 221 | # DistributedDataParallel does gradient averaging, i.e. loss is x-times smaller than in Look for the Change. 222 | # When training with frozen backbone, make it somewhat equivalent to the Look for the Change setup. 
223 | if not args.train_backbone: 224 | loss = loss * world_size 225 | loss.backward() 226 | 227 | optim_head.step() 228 | scheduler_head.step() 229 | if args.train_backbone: 230 | optim_backbone.step() 231 | scheduler_backbone.step() 232 | 233 | loss_metric.update(loss.item(), len(batch_for_training)) 234 | unsup_state_loss_metric.update(state_loss.item(), len(batch_for_training)) 235 | unsup_action_loss_metric.update(action_loss.item(), len(batch_for_training)) 236 | 237 | ############### 238 | # VALIDATION 239 | ############### 240 | joint_meter = JointMeter(train_ds.n_classes) 241 | for batch in test_ds_iter: 242 | _, class_, inputs, annot, _ = batch[0] 243 | classes = torch.LongTensor([class_]) 244 | 245 | with torch.no_grad(): 246 | predictions = model(inputs.cuda(local_rank)) 247 | st_probs = select_correct_classes( 248 | torch.softmax(predictions["state"].unsqueeze(0), -1), classes, n_classes=train_ds.n_classes) 249 | ac_probs = select_correct_classes( 250 | torch.softmax(predictions["action"].unsqueeze(0), -1), classes, n_classes=train_ds.n_classes + 1) 251 | 252 | joint_meter.log(ac_probs[0, :, 0].cpu().numpy(), st_probs[0].cpu().numpy(), annot, category=class_) 253 | 254 | vallog_fn = "{}.{}".format(os.path.realpath("vallog"), cluster_args["job_id"]) 255 | joint_meter.dump(vallog_fn, global_rank) 256 | dist.barrier() 257 | 258 | if global_rank == 0: 259 | time.sleep(10) 260 | joint_meter.load(vallog_fn) 261 | dir_ = f'logs/{cluster_args["job_id"]}' 262 | os.makedirs(dir_, exist_ok=True) 263 | torch.save({"state_dict": model_parallel.state_dict(), "args": model.args, "n_classes": model.n_classes, 264 | "hidden_mlp_layers": model.hidden_mlp_layers}, f"{dir_}/epoch{epoch:03d}.pth") 265 | 266 | print(f"Epoch {epoch} (" 267 | f"T loss: {loss_metric.value:.3f}, " 268 | f"T lr: {scheduler_head.get_last_lr()[0]:.6f}, " 269 | f"T grad norm: {loss_norm_metric.value:.1f}, " 270 | f"T unsup state loss: {unsup_state_loss_metric.value:.3f}, " 271 | f"T unsup action 
loss: {unsup_action_loss_metric.value:.3f}, " 272 | f"V state acc: {joint_meter.acc:.1f}%, " 273 | f"V state prec: {joint_meter.sp:.1f}%, " 274 | f"V state joint prec: {joint_meter.jsp:.1f}%, " 275 | f"V action prec: {joint_meter.ap:.1f}%, " 276 | f"V action joint prec: {joint_meter.jap:.1f}%)", flush=True) 277 | 278 | print("> {:20} {:>6} {:>6} {:>6} {:>6} {:>6}".format("CATEGORY", "SAcc", "SP", "JtSP", "AP", "JtAP")) 279 | print("\n".join([ 280 | "> {:20}{:6.1f}%{:6.1f}%{:6.1f}%{:6.1f}%{:6.1f}%".format(cls_name, *joint_meter[train_ds.classes[cls_name]]) 281 | for cls_name in sorted(train_ds.classes.keys()) 282 | ]), flush=True) 283 | 284 | 285 | if __name__ == '__main__': 286 | parser = argparse.ArgumentParser() 287 | parser.add_argument("--video_roots", type=str, nargs="+", default=["./videos"]) 288 | parser.add_argument("--dataset_root", type=str, default="./ChangeIt") 289 | parser.add_argument("--lr", default=0.0001, type=float) 290 | parser.add_argument("--lr_backbone", default=0.00001, type=float) 291 | parser.add_argument("--weight_decay", type=float, default=0.001) 292 | parser.add_argument("--weight_decay_backbone", type=float, default=0.) 293 | parser.add_argument("--train_backbone", action="store_true") 294 | parser.add_argument("--n_frames_per_gt", type=int, default=3) 295 | parser.add_argument("--local_batch_size", type=int, default=2) 296 | parser.add_argument("--clip_weights", type=str, default="./weights/ViT-L-14.pt") 297 | parser.add_argument("--n_epochs", type=int, default=20) 298 | parser.add_argument("--augment", action="store_true") 299 | parser.add_argument("--ignore_video_weight", action="store_true") 300 | main(parser.parse_args()) 301 | --------------------------------------------------------------------------------
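The training loop in train.py turns the three pseudo-label indices (initial state, end state, action) into a fixed layout of frames and targets per video. The sketch below mirrors that layout in plain Python so it can be inspected without PyTorch or a GPU. The helper name `build_training_indices` and its arguments are hypothetical and not part of the repository; the window logic assumes an odd `n_frames_per_gt`, as the default of 3 is.

```python
def build_training_indices(s1, s2, ac, n_frames, class_, n_classes,
                           n_frames_per_gt=3, kappa_dist=60):
    """Hypothetical pure-Python mirror of the index/target layout in train.py."""
    # The symmetric window around each centre only lines up for odd sizes.
    assert n_frames_per_gt % 2 == 1, "window size is assumed to be odd"
    n = n_frames_per_gt
    half = n // 2

    # Window centres: state 1, state 2, action, and two background
    # positions kappa_dist frames before and after the action.
    centers = [s1, s2, ac, ac - kappa_dist, ac + kappa_dist]
    # Clamp every index into the valid frame range [0, n_frames - 1].
    indices = [min(max(c + o, 0), n_frames - 1)
               for c in centers for o in range(-half, half + 1)]

    # Action head: background is the extra class with index n_classes.
    action_bg = n_classes
    action_targets = [action_bg] * (2 * n) + [class_] * n + [action_bg] * (2 * n)

    # State head: two outputs per class plus a shared background index.
    state_bg = 2 * n_classes
    state_targets = [2 * class_] * n + [2 * class_ + 1] * n + [state_bg] * (3 * n)
    # The last two windows are masked out of the state loss.
    state_mask = [1.0] * (3 * n) + [0.0] * (2 * n)
    return indices, action_targets, state_targets, state_mask


# Illustrative values only: a 120-frame video of class 2 out of 5 classes.
idx, act, st, mask = build_training_indices(
    s1=10, s2=100, ac=50, n_frames=120, class_=2, n_classes=5)
print(idx)
```

Note that when the action lies within `kappa_dist` frames of either end of the video, the clamped background windows collapse onto the first or last frame, so the same frame is used several times as a background example.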