├── .gitignore
├── Dockerfile
├── LICENCE
├── README.md
├── cuda_ops
│   ├── python
│   │   └── __init__.py
│   ├── setup.py
│   └── src
│       ├── common.cpp
│       ├── common.hpp
│       ├── optimal_state_change.cu
│       └── optimal_state_change_indices.cu
├── download_requirements.sh
├── method
│   ├── dataset.py
│   ├── model.py
│   └── utils.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | .ipynb_checkpoints
3 | CLIP
4 | weights
5 | videos
6 | ChangeIt
7 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM pytorch/pytorch:1.10.0-cuda11.3-cudnn8-devel
2 | # using pytorch:1.12.0-cuda11.3-cudnn8-devel results in training being 2x slower for some weird reason
3 |
4 | RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
5 |
6 | RUN apt-get update \
7 | && apt-get install ffmpeg wget git -y
8 |
9 | RUN pip install \
10 | opencv-python \
11 | pillow \
12 | matplotlib \
13 | scikit-learn \
14 | scipy \
15 | tqdm \
16 | pandas \
17 | ffmpeg-python \
18 | ftfy \
19 | regex \
20 | imgaug
21 |
22 | ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
23 |
24 | COPY cuda_ops /tmp
25 |
26 | RUN cd /tmp \
27 | && TORCH_CUDA_ARCH_LIST="6.1;7.0;7.5;8.0;8.6" python setup.py install \
28 | && rm -rf *
29 |
--------------------------------------------------------------------------------
/LICENCE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Tomáš Souček
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
23 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Multi-Task Learning of Object States and State-Modifying Actions from Web Videos
2 |
3 | ### [[Project Website :dart:]](https://soczech.github.io/multi-task-object-states/) [[Paper (Arxiv) :page_with_curl:]](https://arxiv.org/abs/2211.13500) [[Paper (TPAMI) :page_with_curl:]](https://ieeexplore.ieee.org/abstract/document/10420504) [Code :octocat:]
4 |
5 | This repository contains code for the TPAMI paper [Multi-Task Learning of Object States and State-Modifying Actions from Web Videos](https://ieeexplore.ieee.org/abstract/document/10420504).
6 |
7 |
8 |
9 |
10 | ## Train the model on ChangeIt dataset
11 | 1. **Setup the environment**
12 | - Our code can be run in a docker container. Build it by running the following command.
13 | Note that by default, we compile custom CUDA code for architectures 6.1, 7.0, 7.5, 8.0, and 8.6.
14 | You may need to update the Dockerfile with your GPU architecture.
15 | ```
16 | docker build -t multi-task-object-states .
17 | ```
18 | - Go into the docker image.
19 | ```
20 | docker run -it --rm --gpus all -v $(pwd):$(pwd) -w $(pwd) --user=$(id -u $USER):$(id -g $USER) multi-task-object-states bash
21 | ```
22 |
23 | 2. **Download requirements**
24 | - Our code requires the CLIP repository, the CLIP model weights, and the ChangeIt dataset annotations.
25 | Run `./download_requirements.sh` to obtain these dependencies or download them yourself.
26 |
27 | 3. **Download dataset**
28 | - To replicate our experiments on the ChangeIt dataset, the dataset videos are required.
29 | Please download them and put them inside the corresponding `videos/*category*` folders.
30 | See [ChangeIt GitHub page](https://github.com/soCzech/ChangeIt) on how to download them.
31 |
32 | 4. **Train a model**
33 | - Run the training.
34 | ```
35 | python train.py --video_roots ./videos \
36 |                 --dataset_root ./ChangeIt \
37 |                 --train_backbone \
38 |                 --augment \
39 |                 --local_batch_size 2
40 | ```
41 | - We trained the model on 32 GPUs, i.e., with a total batch size of 64 (32 GPUs × local batch size 2).
42 | - To run the code on multiple GPUs, simply run it on a machine with multiple GPUs; all available GPUs are used automatically.
43 | - To run the code on multiple nodes, run the code once on each node.
44 | If you are not running under SLURM, you also need to set the environment variable `SLURM_NPROCS`
45 | to the total number of nodes and the variable `SLURM_PROCID` to the node id, starting from zero.
46 | Make sure you also set `SLURM_JOBID` to a value that is unique to the run and identical on all nodes (see the sketch below).
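A minimal sketch of the fallback variables `train.py` reads when SLURM is not present; the values below are placeholders for a hand-started two-node run and are equivalent to exporting the same variables in the shell before launching `python train.py` on each node.
```python
import os

# Placeholder values for a hand-started 2-node run; train.py falls back to these
# variables when SLURM does not provide them.
os.environ["SLURM_NPROCS"] = "2"   # total number of nodes
os.environ["SLURM_PROCID"] = "0"   # id of this node, starting from zero
os.environ["SLURM_JOBID"] = "42"   # any unique value, identical on all nodes
```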
47 |
48 |
49 | ## Train the model on your dataset
50 | - To train the model on your dataset, complete steps **1.** and **2.** from above.
51 | - Put your videos into `*dir*/*category*` for every video category `*category*`.
52 | - Put your annotations for selected videos into `*dataset*/annotations/*category*`.
53 | Use the same [format](https://github.com/soCzech/ChangeIt/tree/main/annotations) as in the case of the ChangeIt dataset.
54 | - Run the training.
55 | ```
56 | python train.py --video_roots *dir* \
57 |                 --dataset_root *dataset* \
58 |                 --train_backbone \
59 |                 --augment \
60 |                 --local_batch_size 2 \
61 |                 --ignore_video_weight
62 | ```
63 | - The `--ignore_video_weight` option disables the noise-adaptive weighting used for the noisy ChangeIt dataset.
64 | To use the noise-adaptive weighting, you also need to provide the `*dataset*/categories.csv` and `*dataset*/videos/*category*.csv` files (their format is sketched below).
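For reference, `method/dataset.py` parses these two files as sketched below; the dataset path and the category name `my_category` are hypothetical.
```python
# Sketch of the expected file formats (paths and category name are hypothetical).

# *dataset*/videos/my_category.csv -- one "video_id,score" pair per line.
noise_adapt_weight = {}
with open("my_dataset/videos/my_category.csv") as f:
    for line in f:
        vid_id, score = line.strip().split(",")
        noise_adapt_weight[f"my_category/{vid_id}"] = float(score)

# *dataset*/categories.csv -- a header line followed by rows whose first column is the
# category name and whose third column is the per-category noise threshold.
with open("my_dataset/categories.csv") as f:
    thresholds = {line.split(",")[0]: float(line.split(",")[2].strip())
                  for line in f.readlines()[1:]}
```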
65 |
66 |
67 | ## Use a trained model
68 | Here is example code for running inference with a trained model. `ClipClassifier` is defined in `method/model.py`, `extract_frames` in `method/dataset.py`, and `video_fn` is the path to a video file.
69 | ```python
70 | checkpoint = torch.load("path/to/saved/model.pth", map_location="cpu")
71 | model = ClipClassifier(params=checkpoint["args"],
72 | n_classes=checkpoint["n_classes"],
73 | hidden_mlp_layers=checkpoint["hidden_mlp_layers"]).cuda()
74 | model.load_state_dict({k.replace("module.", ""): v for k, v in checkpoint["state_dict"].items()})
75 |
76 | video_frames = torch.from_numpy(
77 | extract_frames(video_fn, fps=1, size=(398, 224), crop=(398 - 224, 0)))
78 |
79 | with torch.no_grad():
80 | predictions = model(video_frames.cuda())
81 | state_pred, action_pred = torch.softmax(predictions["state"], -1), torch.softmax(predictions["action"], -1)
82 | ```
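If the model was trained on more than one category (`n_classes > 1`), the `state` and `action` heads output logits for all categories at once. Below is a minimal sketch of selecting the probabilities of a single category, mirroring the validation code in `train.py`; the category index `0` is a placeholder.
```python
from method.utils import select_correct_classes

category = torch.LongTensor([0])  # placeholder: index of the video's category
state_probs = select_correct_classes(
    torch.softmax(predictions["state"].unsqueeze(0), -1), category,
    n_classes=checkpoint["n_classes"])      # [1, T, 2]: initial- and end-state probabilities
action_probs = select_correct_classes(
    torch.softmax(predictions["action"].unsqueeze(0), -1), category,
    n_classes=checkpoint["n_classes"] + 1)  # [1, T, 1]: action probability of that category
```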
83 |
84 |
85 | ## Citation
86 | ```bibtex
87 | @article{soucek2024multitask,
88 | title={Multi-Task Learning of Object States and State-Modifying Actions from Web Videos},
89 | author={Sou\v{c}ek, Tom\'{a}\v{s} and Alayrac, Jean-Baptiste and Miech, Antoine and Laptev, Ivan and Sivic, Josef},
90 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
91 | year={2024},
92 | doi={10.1109/TPAMI.2024.3362288}
93 | }
94 | ```
95 |
96 |
97 | ## Acknowledgements
98 | This work was partly supported by the European Regional Development Fund under the project IMPACT (reg. no. CZ.02.1.01/0.0/0.0/15_003/0000468), the Ministry of Education, Youth and Sports of the Czech Republic through the e-INFRA CZ (ID:90140), the French government under management of Agence Nationale de la Recherche as part of the “Investissements d’avenir” program, reference ANR19-P3IA-0001 (PRAIRIE 3IA Institute), and Louis Vuitton ENS Chair on Artificial Intelligence.
99 |
100 | The ordering constraint code has been adapted from the CVPR 2022 paper
101 | [Look for the Change: Learning Object States and State-Modifying Actions from Untrimmed Web Videos](https://soczech.github.io/look-for-the-change/)
102 | available on [GitHub](https://github.com/soCzech/LookForTheChange).
103 |
--------------------------------------------------------------------------------
/cuda_ops/python/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import _lookforthechange_ops
3 |
4 |
5 | def optimal_state_change(state_tensor, action_tensor, lens, delta, kappa, max_action_state_distance=500):
6 | return _lookforthechange_ops.optimal_state_change(
7 | state_tensor.contiguous(), action_tensor.contiguous(), lens, delta, kappa, max_action_state_distance)
8 |
9 |
10 | def optimal_state_change_indices(state_tensor, action_tensor, lens, max_action_state_distance=500):
11 | return _lookforthechange_ops.optimal_state_change_indices(
12 | state_tensor.contiguous(), action_tensor.contiguous(), lens, max_action_state_distance)
13 |
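14 | # Usage sketch (comments only; the delta/kappa values below are illustrative):
15 | #   state_tensor  -- float CUDA tensor [batch, video_len, 2] with per-frame initial/end-state scores
16 | #   action_tensor -- float CUDA tensor [batch, video_len, 1] with per-frame action scores
17 | #   lens          -- int32 CUDA tensor [batch] holding the true length of each video
18 | #
19 | #   indices = optimal_state_change_indices(state_tensor, action_tensor, lens)
20 | #   # -> int tensor [batch, 3] with (state1, state2, action) frame indices
21 | #   state_targets, action_targets = optimal_state_change(state_tensor, action_tensor, lens, delta=2, kappa=60)
22 | #   # -> two int tensors [batch, video_len]; 0 = no label, 1 = initial state / no-action, 2 = end state / action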
--------------------------------------------------------------------------------
/cuda_ops/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils import cpp_extension
3 |
4 | library_dirs = cpp_extension.library_paths(cuda=True)
5 | include_dirs = cpp_extension.include_paths(cuda=True)
6 |
7 | print("library_dirs:", library_dirs)
8 | print("include_dirs:", include_dirs)
9 |
10 | setup(
11 | name="lookforthechange",
12 | version="2.0",
13 | install_requires=[
14 | "numpy",
15 | "torch"
16 | ],
17 | ext_modules=[
18 | cpp_extension.CUDAExtension(
19 | name='_lookforthechange_ops',
20 | sources=[
21 | 'src/common.cpp',
22 | 'src/optimal_state_change.cu',
23 | 'src/optimal_state_change_indices.cu',
24 | ],
25 | library_dirs=library_dirs,
26 | include_dirs=include_dirs
27 | )
28 | ],
29 | packages=['lookforthechange'],
30 | package_dir={'lookforthechange': './python'},
31 | cmdclass={'build_ext': cpp_extension.BuildExtension}
32 | )
33 |
--------------------------------------------------------------------------------
/cuda_ops/src/common.cpp:
--------------------------------------------------------------------------------
1 | #include "common.hpp"
2 |
3 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
4 | m.def("optimal_state_change", &optimal_state_change, "Optimal State1-Action-State2 Sequence (GPU)");
5 | m.def("optimal_state_change_indices", &optimal_state_change_indices, "Optimal State1-Action-State2 Sequence (GPU)");
6 | }
7 |
--------------------------------------------------------------------------------
/cuda_ops/src/common.hpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | // C++ interface
4 | #define CHECK_CPU(x) TORCH_CHECK(!x.type().is_cuda(), #x " must be a CPU tensor")
5 | #define CHECK_GPU(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a GPU tensor")
6 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
7 | #define CHECK_INPUT(x) CHECK_CPU(x); CHECK_CONTIGUOUS(x)
8 | #define CHECK_CUDA_INPUT(x) CHECK_GPU(x); CHECK_CONTIGUOUS(x)
9 |
10 | std::vector<torch::Tensor> optimal_state_change(
11 | torch::Tensor state_tensor, torch::Tensor action_tensor, torch::Tensor lens, int delta, int kappa, int max_action_state_distance);
12 |
13 | torch::Tensor optimal_state_change_indices(
14 | torch::Tensor state_tensor, torch::Tensor action_tensor, torch::Tensor lens, int max_action_state_distance);
15 |
--------------------------------------------------------------------------------
/cuda_ops/src/optimal_state_change.cu:
--------------------------------------------------------------------------------
1 | #include "common.hpp"
2 |
3 | __global__ void SingleStateChangeKernel(
4 | const float* state_tensor,
5 | const float* action_tensor,
6 | const int* lens,
7 | int* state_targets,
8 | int* action_targets,
9 | const int delta,
10 | const int kappa,
11 | const int max_action_state_distance
12 | ) {
13 | const int batch_idx = blockIdx.x;
14 | const int video_len = blockDim.x;
15 | const int state1_pos = threadIdx.x;
16 | const int actual_len = lens[batch_idx];
17 |
18 | // get pointer to shared memory
19 | extern __shared__ char shared_mem[];
20 | int* state1_to_action_pos = reinterpret_cast<int*>(shared_mem);
21 | int* state1_to_state2_pos = state1_to_action_pos + video_len;
22 | float* state1_to_score = reinterpret_cast<float*>(state1_to_state2_pos + video_len);
23 | float* action_tensor_shared = state1_to_score + video_len;
24 | float* state_tensor_shared = action_tensor_shared + video_len;
25 |
26 | // load action and state tensors into shared memory
27 | action_tensor_shared[state1_pos] = action_tensor[batch_idx * video_len + state1_pos];
28 | state_tensor_shared[2 * state1_pos + 0] = state_tensor[batch_idx * video_len * 2 + state1_pos * 2 + 0];
29 | state_tensor_shared[2 * state1_pos + 1] = state_tensor[batch_idx * video_len * 2 + state1_pos * 2 + 1];
30 |
31 | __syncthreads();
32 |
33 | float best_score = -std::numeric_limits<float>::infinity();
34 | int best_action_pos = 0, best_state2_pos = 0; // position of states/action for videos shorter than 3
35 |
36 | for (int action_pos = state1_pos + 1; action_pos <= state1_pos + max_action_state_distance && action_pos < actual_len - 1; ++action_pos) { // -1: need at least one position for state2
37 | float action_score = action_tensor_shared[action_pos];
38 |
39 | for (int state2_pos = action_pos + 1; state2_pos <= action_pos + max_action_state_distance && state2_pos < actual_len; ++state2_pos) {
40 | float state2_score = state_tensor_shared[2 * state2_pos + 1]; // 2 states, +1 for second state
41 |
42 | float score = action_score * state2_score;
43 | if (score > best_score) {
44 | best_score = score;
45 | best_action_pos = action_pos;
46 | best_state2_pos = state2_pos;
47 | }
48 | }
49 | }
50 |
51 | state1_to_action_pos[state1_pos] = best_action_pos;
52 | state1_to_state2_pos[state1_pos] = best_state2_pos;
53 | state1_to_score[state1_pos] = best_score * state_tensor_shared[2 * state1_pos + 0];
54 |
55 | __syncthreads();
56 |
57 | if (state1_pos == 0) { // compute reduction only on the first thread
58 | best_score = state1_to_score[0];
59 | int best_state1_pos = 0;
60 | for (int i = 1; i < actual_len - 2; ++i) { // -2: need at least one position for action and one for state2
61 | if (best_score < state1_to_score[i]) {
62 | best_state1_pos = i;
63 | best_score = state1_to_score[i];
64 | }
65 | }
66 | best_action_pos = state1_to_action_pos[best_state1_pos];
67 | best_state2_pos = state1_to_state2_pos[best_state1_pos];
68 |
69 | // FILL state_targets TENSOR
70 | // 0 .. default - no label
71 | // 1 .. initial state label
72 | // 2 .. end state label
73 | for (int i = best_state1_pos - delta; i <= best_state1_pos + delta; ++i) {
74 | if (i < 0 || i >= actual_len) continue;
75 | state_targets[batch_idx * video_len + i] = 1;
76 | }
77 | for (int i = best_state2_pos - delta; i <= best_state2_pos + delta; ++i) {
78 | if (i < 0 || i >= actual_len) continue;
79 | state_targets[batch_idx * video_len + i] = 2;
80 | }
81 |
82 | // FILL action_targets TENSOR
83 | // 0 .. default - no label
84 | // 1 .. no-action label
85 | // 2 .. action label
86 | for (int i = 0; i <= delta; ++i) {
87 | int j = best_action_pos - i - kappa;
88 | if (j < 0) {
89 | action_targets[batch_idx * video_len + 0] = 1;
90 | } else {
91 | action_targets[batch_idx * video_len + j] = 1;
92 | }
93 |
94 | int k = best_action_pos + i + kappa;
95 | if (k >= actual_len) {
96 | action_targets[batch_idx * video_len + actual_len - 1] = 1;
97 | } else {
98 | action_targets[batch_idx * video_len + k] = 1;
99 | }
100 | }
101 | for (int i = best_action_pos - delta; i <= best_action_pos + delta; ++i) {
102 | if (i < 0 || i >= actual_len) continue;
103 | action_targets[batch_idx * video_len + i] = 2;
104 | }
105 | }
106 | }
107 |
108 | std::vector<torch::Tensor> optimal_state_change(
109 | torch::Tensor state_tensor, torch::Tensor action_tensor, torch::Tensor lens, int delta, int kappa, int max_action_state_distance) {
110 |
111 | CHECK_CUDA_INPUT(state_tensor);
112 | CHECK_CUDA_INPUT(action_tensor);
113 | CHECK_CUDA_INPUT(lens);
114 |
115 | int batch_size = state_tensor.size(0);
116 | int video_len = state_tensor.size(1);
117 |
118 | TORCH_CHECK(state_tensor.size(2) == 2, "state_tensor must be of shape [batch, video_len, 2]")
119 | TORCH_CHECK(action_tensor.size(2) == 1, "action_tensor must be of shape [batch, video_len, 1]")
120 |
121 | auto options = torch::TensorOptions().dtype(torch::kInt).device(torch::kCUDA);
122 | auto state_targets = torch::zeros({batch_size, video_len}, options);
123 | auto action_targets = torch::zeros({batch_size, video_len}, options);
124 |
125 | const int threads = video_len;
126 | const int blocks = batch_size;
127 | // store in shared memory:
128 | // best action position for each state1 position (1x int)
129 | // best state2 position for each state1 position (1x int)
130 | // best score for each state1 position (1x float)
131 | // action tensor (1x float)
132 | // state tensor (2x float)
133 | const int shared_mem = video_len * (2 * sizeof(int) + 4 * sizeof(float));
134 | SingleStateChangeKernel<<<blocks, threads, shared_mem>>>(
135 | state_tensor.data_ptr<float>(),
136 | action_tensor.data_ptr<float>(),
137 | lens.data_ptr<int>(),
138 | state_targets.data_ptr<int>(),
139 | action_targets.data_ptr<int>(),
140 | delta,
141 | kappa,
142 | max_action_state_distance);
143 |
144 | return std::vector<torch::Tensor>{state_targets, action_targets};
145 | }
146 |
--------------------------------------------------------------------------------
/cuda_ops/src/optimal_state_change_indices.cu:
--------------------------------------------------------------------------------
1 | #include "common.hpp"
2 |
3 | __global__ void SingleStateChangeIndicesKernel(
4 | const float* state_tensor,
5 | const float* action_tensor,
6 | const int* lens,
7 | int* out_indices,
8 | const int max_action_state_distance
9 | ) {
10 | const int batch_idx = blockIdx.x;
11 | const int video_len = blockDim.x;
12 | const int state1_pos = threadIdx.x;
13 | const int actual_len = lens[batch_idx];
14 |
15 | // get pointer to shared memory
16 | extern __shared__ char shared_mem[];
17 | int* state1_to_action_pos = reinterpret_cast<int*>(shared_mem);
18 | int* state1_to_state2_pos = state1_to_action_pos + video_len;
19 | float* state1_to_score = reinterpret_cast<float*>(state1_to_state2_pos + video_len);
20 | float* action_tensor_shared = state1_to_score + video_len;
21 | float* state_tensor_shared = action_tensor_shared + video_len;
22 |
23 | // load action and state tensors into shared memory
24 | action_tensor_shared[state1_pos] = action_tensor[batch_idx * video_len + state1_pos];
25 | state_tensor_shared[2 * state1_pos + 0] = state_tensor[batch_idx * video_len * 2 + state1_pos * 2 + 0];
26 | state_tensor_shared[2 * state1_pos + 1] = state_tensor[batch_idx * video_len * 2 + state1_pos * 2 + 1];
27 |
28 | __syncthreads();
29 |
30 | float best_score = -std::numeric_limits<float>::infinity();
31 | int best_action_pos = 0, best_state2_pos = 0; // position of states/action for videos shorter than 3
32 |
33 | for (int action_pos = state1_pos + 1; action_pos <= state1_pos + max_action_state_distance && action_pos < actual_len - 1; ++action_pos) { // -1: need at least one position for state2
34 | float action_score = action_tensor_shared[action_pos];
35 |
36 | for (int state2_pos = action_pos + 1; state2_pos <= action_pos + max_action_state_distance && state2_pos < actual_len; ++state2_pos) {
37 | float state2_score = state_tensor_shared[2 * state2_pos + 1]; // 2 states, +1 for second state
38 |
39 | float score = action_score * state2_score;
40 | if (score > best_score) {
41 | best_score = score;
42 | best_action_pos = action_pos;
43 | best_state2_pos = state2_pos;
44 | }
45 | }
46 | }
47 |
48 | state1_to_action_pos[state1_pos] = best_action_pos;
49 | state1_to_state2_pos[state1_pos] = best_state2_pos;
50 | state1_to_score[state1_pos] = best_score * state_tensor_shared[2 * state1_pos + 0];
51 |
52 | __syncthreads();
53 |
54 | if (state1_pos == 0) { // compute reduction only on the first thread
55 | best_score = state1_to_score[0];
56 | int best_state1_pos = 0;
57 | for (int i = 1; i < actual_len - 2; ++i) { // -2: need at least one position for action and one for state2
58 | if (best_score < state1_to_score[i]) {
59 | best_state1_pos = i;
60 | best_score = state1_to_score[i];
61 | }
62 | }
63 | best_action_pos = state1_to_action_pos[best_state1_pos];
64 | best_state2_pos = state1_to_state2_pos[best_state1_pos];
65 |
66 | out_indices[batch_idx * 3 + 0] = best_state1_pos;
67 | out_indices[batch_idx * 3 + 1] = best_state2_pos;
68 | out_indices[batch_idx * 3 + 2] = best_action_pos;
69 | }
70 | }
71 |
72 | torch::Tensor optimal_state_change_indices(
73 | torch::Tensor state_tensor, torch::Tensor action_tensor, torch::Tensor lens, int max_action_state_distance) {
74 |
75 | CHECK_CUDA_INPUT(state_tensor);
76 | CHECK_CUDA_INPUT(action_tensor);
77 | CHECK_CUDA_INPUT(lens);
78 |
79 | int batch_size = state_tensor.size(0);
80 | int video_len = state_tensor.size(1);
81 |
82 | TORCH_CHECK(state_tensor.size(2) == 2, "state_tensor must be of shape [batch, video_len, 2]")
83 | TORCH_CHECK(action_tensor.size(2) == 1, "action_tensor must be of shape [batch, video_len, 1]")
84 |
85 | auto options = torch::TensorOptions().dtype(torch::kInt).device(torch::kCUDA);
86 | auto out_indices = torch::zeros({batch_size, 3}, options);
87 |
88 | const int threads = video_len;
89 | const int blocks = batch_size;
90 | // store in shared memory:
91 | // best action position for each state1 position (1x int)
92 | // best state2 position for each state1 position (1x int)
93 | // best score for each state1 position (1x float)
94 | // action tensor (1x float)
95 | // state tensor (2x float)
96 | const int shared_mem = video_len * (2 * sizeof(int) + 4 * sizeof(float));
97 | SingleStateChangeIndicesKernel<<<blocks, threads, shared_mem>>>(
98 | state_tensor.data_ptr<float>(),
99 | action_tensor.data_ptr<float>(),
100 | lens.data_ptr<int>(),
101 | out_indices.data_ptr<int>(),
102 | max_action_state_distance);
103 |
104 | return out_indices;
105 | }
106 |
--------------------------------------------------------------------------------
/download_requirements.sh:
--------------------------------------------------------------------------------
1 | echo -n "Downloading and patching clip repository ... "
2 | git clone --quiet https://github.com/openai/CLIP.git || exit 1
3 | cd CLIP || exit 1
4 | git checkout --quiet d50d76daa670286dd6cacf3bcd80b5e4823fc8e1 || exit 1
5 | sed -i /self.proj/d clip/model.py || exit 1
6 | cd ..
7 | echo "OK ✓"
8 |
9 | echo -n "Downloading clip weights ... "
10 | mkdir -p weights
11 | wget https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt -q -O weights/ViT-L-14.pt
12 | SHA256SUM=$(sha256sum weights/ViT-L-14.pt | cut -d' ' -f1)
13 |
14 | if [[ ${SHA256SUM} == "b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836" ]]; then
15 | echo "OK ✓"
16 | else
17 | echo "ERROR ✗"
18 | exit 1
19 | fi
20 |
21 | echo -n "Downloading ChangeIt annotations ... "
22 | mkdir -p videos
23 | git clone --quiet https://github.com/soCzech/ChangeIt.git || exit 1
24 | echo "OK ✓"
25 |
26 | echo "To replicate our experiments, please download ChangeIt videos into \`videos/*category*\` folders."
27 | echo "More details on how to download the videos at https://github.com/soCzech/ChangeIt."
28 | echo "If you wish to train the model on your data, please see the README file."
29 |
--------------------------------------------------------------------------------
/method/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import time
4 | import torch
5 | import ffmpeg
6 | import random
7 | import itertools
8 | import numpy as np
9 | import torch.distributed as dist
10 | from torch.utils.data import Dataset
11 | from typing import TypeVar, Optional, Iterator
12 |
13 | import imgaug as ia
14 | import imgaug.augmenters as iaa
15 |
16 |
17 | def extract_frames(video_path, fps, size=None, crop=None, start=None, duration=None):
18 | if start is not None:
19 | cmd = ffmpeg.input(video_path, ss=start, t=duration)
20 | else:
21 | cmd = ffmpeg.input(video_path)
22 |
23 | if size is None:
24 | info = [s for s in ffmpeg.probe(video_path)["streams"] if s["codec_type"] == "video"][0]
25 | size = (info["width"], info["height"])
26 | elif isinstance(size, int):
27 | size = (size, size)
28 |
29 | if fps is not None:
30 | cmd = cmd.filter('fps', fps=fps)
31 | cmd = cmd.filter('scale', size[0], size[1])
32 |
33 | if crop is not None:
34 | cmd = cmd.filter('crop', f'in_w-{crop[0]}', f'in_h-{crop[1]}')
35 | size = (size[0] - crop[0], size[1] - crop[1])
36 |
37 | for i in range(5):
38 | try:
39 | out, _ = (
40 | cmd.output('pipe:', format='rawvideo', pix_fmt='rgb24')
41 | .run(capture_stdout=True, quiet=True)
42 | )
43 | break
44 | except Exception as e:
45 | time.sleep(random.random() * 5.)
46 | if i < 4:
47 | continue
48 | print(f"W: FFMPEG file {video_path} read failed!", flush=True)
49 | if isinstance(e, ffmpeg.Error):
50 | print("STDOUT:", e.stdout, flush=True)
51 | print("STDERR:", e.stderr, flush=True)
52 | raise
53 |
54 | video = np.frombuffer(out, np.uint8).reshape([-1, size[1], size[0], 3])
55 | return video
56 |
57 |
58 | class ChangeItVideoDataset(Dataset):
59 |
60 | def __init__(self,
61 | video_roots,
62 | annotation_root=None,
63 | file_mode="unannotated", # "unannotated", "annotated", "all"
64 | noise_adapt_weight_root=None,
65 | noise_adapt_weight_threshold_file=None,
66 | augment=False):
67 |
68 | self.classes = {x: i for i, x in enumerate(sorted(set([os.path.basename(fn) for fn in itertools.chain(*[
69 | glob.glob(os.path.join(root, "*")) for root in video_roots
70 | ]) if os.path.isdir(fn)])))}
71 |
72 | self.files = {key: sorted(itertools.chain(*[
73 | glob.glob(os.path.join(root, key, "*.mp4")) + glob.glob(os.path.join(root, key, "*.webm")) for root in
74 | video_roots
75 | ])) for key in self.classes.keys()}
76 |
77 | self.annotations = {key: {
78 | os.path.basename(fn).split(".")[0]: np.uint8(
79 | [int(line.strip().split(",")[1]) for line in open(fn).readlines()])
80 | for fn in glob.glob(os.path.join(annotation_root, key, "*.csv"))
81 | } for key in self.classes.keys()} if annotation_root is not None else None
82 |
83 | if file_mode == "unannotated":
84 | for key in self.classes.keys():
85 | for fn in self.files[key].copy():
86 | if os.path.basename(fn).split(".")[0] in self.annotations[key]:
87 | self.files[key].remove(fn)
88 | elif file_mode == "annotated":
89 | for key in self.classes.keys():
90 | for fn in self.files[key].copy():
91 | if os.path.basename(fn).split(".")[0] not in self.annotations[key]:
92 | self.files[key].remove(fn)
93 | elif file_mode == "all":
94 | pass
95 | else:
96 | raise NotImplementedError()
97 |
98 | self.flattened_files = []
99 | for key in self.classes.keys():
100 | self.flattened_files.extend([(key, fn) for fn in self.files[key]])
101 |
102 | self.augment = augment
103 |
104 | # Noise adaptive weighting
105 | if noise_adapt_weight_root is None:
106 | return
107 |
108 | self.noise_adapt_weight = {}
109 | for key in self.classes.keys():
110 | with open(os.path.join(noise_adapt_weight_root, f"{key}.csv"), "r") as f:
111 | for line in f.readlines():
112 | vid_id, score = line.strip().split(",")
113 | self.noise_adapt_weight[f"{key}/{vid_id}"] = float(score)
114 |
115 | self.noise_adapt_weight_thr = {line.split(",")[0]: float(line.split(",")[2].strip())
116 | for line in open(noise_adapt_weight_threshold_file, "r").readlines()[1:]}
117 |
118 | def __getitem__(self, idx):
119 | class_name, video_fn = self.flattened_files[idx]
120 | file_id = os.path.basename(video_fn).split(".")[0]
121 |
122 | video_frames = extract_frames(video_fn, fps=1, size=(398, 224)).copy()
123 |
124 | if self.augment:
125 | video_frames = ChangeItVideoDataset.augment_fc(video_frames)
126 | video_frames = video_frames[:, :, random.randint(0, 398 - 224 - 1):][:, :, :224]
127 | else:
128 | video_frames = video_frames[:, :, (398 - 224) // 2:][:, :, :224]
129 | video_frames = torch.from_numpy(video_frames.copy())
130 |
131 | annotation = self.annotations[class_name][file_id] \
132 | if self.annotations is not None and file_id in self.annotations[class_name] else None
133 | video_level_score = self.noise_adapt_weight[f"{class_name}/{file_id}"] - self.noise_adapt_weight_thr[class_name] \
134 | if hasattr(self, "noise_adapt_weight") else None
135 |
136 | return class_name + "/" + file_id, self.classes[class_name], video_frames, annotation, video_level_score
137 |
138 | @property
139 | def n_classes(self):
140 | return len(self.classes)
141 |
142 | def __len__(self):
143 | return len(self.flattened_files)
144 |
145 | def __repr__(self):
146 | string = f"ChangeItVideoDataset(n_classes: {self.n_classes}, n_samples: {self.__len__()}, " \
147 | f"augment: {self.augment})"
148 | for key in sorted(self.classes.keys()):
149 | string += f"\n> {key:20} {len(self.files[key]):4d}"
150 | if hasattr(self, "noise_adapt_weight_thr"):
151 | len_ = len([
152 | fn for fn in self.files[key]
153 | if self.noise_adapt_weight[f"{key}/{os.path.basename(fn).split('.')[0]}"] >
154 | self.noise_adapt_weight_thr[key]
155 | ])
156 | string += f" (above threshold {self.noise_adapt_weight_thr[key]:.3f}: {len_:4d})"
157 | return string
158 |
159 | @staticmethod
160 | def augment_fc(video_frames):
161 | seq = iaa.Sequential(
162 | [
163 | # apply the following augmenters to most images
164 | iaa.Fliplr(0.5), # horizontally flip 50% of all images
165 | iaa.Flipud(0.1), # vertically flip 10% of all images
166 | # crop images by -5% to 10% of their height/width
167 | iaa.Sometimes(0.5, iaa.CropAndPad(
168 | percent=(-0.05, 0.1),
169 | pad_mode=ia.ALL,
170 | pad_cval=(0, 255)
171 | )),
172 | iaa.Sometimes(0.5, iaa.Affine(
173 | scale={"x": (0.8, 1.2), "y": (0.8, 1.2)},
174 | # scale images to 80-120% of their size, individually per axis
175 | translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)},
176 | # translate by -20 to +20 percent (per axis)
177 | rotate=(-15, 15), # rotate by -15 to +15 degrees
178 | shear=(-16, 16), # shear by -16 to +16 degrees
179 | order=[0, 1], # use nearest neighbour or bilinear interpolation (fast)
180 | cval=(0, 255), # if mode is constant, use a cval between 0 and 255
181 | mode=ia.ALL # use any of scikit-image's warping modes (see 2nd image from the top for examples)
182 | )),
183 | # execute 0 to 4 of the following (less important) augmenters per image
184 | # don't execute all of them, as that would often be way too strong
185 | iaa.SomeOf((0, 4), [
186 | iaa.OneOf([
187 | iaa.GaussianBlur((0, 3.0)), # blur images with a sigma between 0 and 3.0
188 | iaa.AverageBlur(k=(2, 3)),
189 | # blur image using local means with kernel sizes between 2 and 7
190 | iaa.MedianBlur(k=(3, 5)),
191 | # blur image using local medians with kernel sizes between 3 and 5
192 | ]),
193 | iaa.Sharpen(alpha=(0, 1.0), lightness=(0.75, 1.5)), # sharpen images
194 | iaa.Add((-10, 10), per_channel=0.5),
195 | # change brightness of images (by -10 to 10 of original value)
196 | iaa.AddToHueAndSaturation((-20, 20)), # change hue and saturation
197 | ], random_order=True)
198 | ],
199 | random_order=True
200 | )
201 | seq_det = seq.to_deterministic()
202 |
203 | video_frames_augmented = np.empty_like(video_frames)
204 | for i in range(len(video_frames)):
205 | video_frames_augmented[i] = seq_det.augment_image(video_frames[i])
206 | return video_frames_augmented
207 |
208 |
209 | def identity_collate(items):
210 | return items
211 |
212 |
213 | T_co = TypeVar('T_co', covariant=True)
214 |
215 |
216 | class DistributedDropFreeSampler:
217 |
218 | def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
219 | rank: Optional[int] = None, shuffle: bool = True,
220 | seed: int = 0) -> None:
221 | if num_replicas is None:
222 | if not dist.is_available():
223 | raise RuntimeError("Requires distributed package to be available")
224 | num_replicas = dist.get_world_size()
225 | if rank is None:
226 | if not dist.is_available():
227 | raise RuntimeError("Requires distributed package to be available")
228 | rank = dist.get_rank()
229 | if rank >= num_replicas or rank < 0:
230 | raise ValueError(
231 | "Invalid rank {}, rank should be in the interval"
232 | " [0, {}]".format(rank, num_replicas - 1))
233 | self.dataset = dataset
234 | self.num_replicas = num_replicas
235 | self.rank = rank
236 | self.epoch = 0
237 |
238 | self.num_samples = len(self.dataset) // self.num_replicas # type: ignore[arg-type]
239 | if self.num_samples * self.num_replicas < len(self.dataset): # type: ignore[arg-type]
240 | if self.rank < len(self.dataset) % self.num_replicas: # type: ignore[arg-type]
241 | self.num_samples += 1
242 | self.total_size = len(self.dataset) # type: ignore[arg-type]
243 | self.shuffle = shuffle
244 | self.seed = seed
245 |
246 | def __iter__(self) -> Iterator[T_co]:
247 | if self.shuffle:
248 | # deterministically shuffle based on epoch and seed
249 | g = torch.Generator()
250 | g.manual_seed(self.seed + self.epoch)
251 | indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
252 | else:
253 | indices = list(range(len(self.dataset))) # type: ignore[arg-type]
254 | assert len(indices) == self.total_size
255 |
256 | # subsample
257 | indices = indices[self.rank:self.total_size:self.num_replicas]
258 | assert len(indices) == self.num_samples
259 |
260 | return iter(indices)
261 |
262 | def __len__(self) -> int:
263 | return self.num_samples
264 |
265 | def set_epoch(self, epoch: int) -> None:
266 | self.epoch = epoch
267 |
--------------------------------------------------------------------------------
/method/model.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torch
3 | sys.path.append("./CLIP")
4 | from clip.model import VisionTransformer
5 |
6 |
7 | class ClassifierHeads(torch.nn.Module):
8 |
9 | def __init__(self, layers, n_classes=1):
10 | super(ClassifierHeads, self).__init__()
11 |
12 | self.state = torch.nn.ModuleDict({
13 | f"l{i:d}": torch.nn.Linear(layers[i], layers[i + 1])
14 | for i in range(len(layers) - 1)
15 | })
16 | self.action = torch.nn.ModuleDict({
17 | f"l{i:d}": torch.nn.Linear(layers[i], layers[i + 1])
18 | for i in range(len(layers) - 1)
19 | })
20 |
21 | self.state_layer = torch.nn.Linear(layers[-1], 2 * n_classes + 1, bias=True)
22 | self.action_layer = torch.nn.Linear(layers[-1], 1 * n_classes + 1, bias=True)
23 |
24 | def forward(self, inputs):
25 | x = inputs
26 | for i in range(len(self.state)):
27 | x = self.state[f"l{i:d}"](x)
28 | x = torch.relu(x)
29 | state = self.state_layer(x)
30 |
31 | x = inputs
32 | for i in range(len(self.action)):
33 | x = self.action[f"l{i:d}"](x)
34 | x = torch.relu(x)
35 | action = self.action_layer(x)
36 |
37 | return {"state": state, "action": action}
38 |
39 |
40 | class ClipClassifier(torch.nn.Module):
41 |
42 | def __init__(self, hidden_mlp_layers, params, n_classes=1, train_backbone=True):
43 | super(ClipClassifier, self).__init__()
44 |
45 | self.args = params
46 | self.n_classes = n_classes
47 | self.hidden_mlp_layers = hidden_mlp_layers
48 |
49 | if "visual.conv1.weight" in params:
50 | vision_width = params["visual.conv1.weight"].shape[0]
51 | vision_heads = vision_width // 64
52 | vision_layers = len([k for k in params.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
53 | vision_patch_size = params["visual.conv1.weight"].shape[-1]
54 | grid_size = round((params["visual.positional_embedding"].shape[0] - 1) ** 0.5)
55 | image_resolution = vision_patch_size * grid_size
56 | self.args = dict(
57 | input_resolution=image_resolution,
58 | patch_size=vision_patch_size,
59 | width=vision_width,
60 | layers=vision_layers,
61 | heads=vision_heads,
62 | output_dim=None
63 | )
64 |
65 | self.backbone = VisionTransformer(**self.args)
66 |
67 | if "visual.conv1.weight" in params:
68 | self.backbone.load_state_dict({
69 | k[len("visual."):]: v for k, v in params.items()
70 | if k.startswith("visual.") and k != "visual.proj"
71 | })
72 |
73 | if not train_backbone:
74 | for param in self.backbone.parameters():
75 | param.requires_grad_(False)
76 |
77 | self.heads = ClassifierHeads([self.args["width"]] + hidden_mlp_layers, n_classes)
78 |
79 | @staticmethod
80 | def preprocess(imgs):
81 | # CLIP image preprocessing
82 | imgs = imgs.permute((0, 3, 1, 2)).float().div_(255)
83 | mean = torch.as_tensor((0.48145466, 0.4578275, 0.40821073),
84 | dtype=torch.float32, device=imgs.device).view(1, -1, 1, 1)
85 | std = torch.as_tensor((0.26862954, 0.26130258, 0.27577711),
86 | dtype=torch.float32, device=imgs.device).view(1, -1, 1, 1)
87 | return imgs.sub_(mean).div_(std)
88 |
89 | def forward(self, inputs):
90 | x = self.preprocess(inputs)
91 | x = self.backbone(x)
92 | x = self.heads(x)
93 | return x
94 |
--------------------------------------------------------------------------------
/method/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import math
4 | import torch
5 | import pickle
6 | import numpy as np
7 |
8 |
9 | def select_correct_classes(predictions, classes, n_classes=1):
10 | if n_classes == 1:
11 | return predictions
12 | # predictions: B, L, n_classes * third_dim
13 | B, L, dim = predictions.shape
14 | third_dim = dim // n_classes
15 |
16 | x = torch.arange(0, B).view(-1, 1, 1).repeat(1, L, third_dim)
17 | y = torch.arange(0, L).view(1, -1, 1).repeat(B, 1, third_dim)
18 | z = classes.view(-1, 1, 1).repeat(1, L, third_dim) * third_dim + \
19 | torch.arange(0, third_dim, device=classes.device).view(1, 1, third_dim)
20 |
21 | return predictions[x, y, z] # B, L, third_dim
22 |
23 |
24 | def constrained_argmax(pred_action, pred_state, also_only_state=False):
25 | max_val_, best_idx_ = -1, (0, 0, 0)
26 | for i in range(len(pred_state)):
27 | for j in range(i + 2, len(pred_state)):
28 | val_ = pred_state[i, 0] * pred_state[j, 1]
29 | k = np.argmax(pred_action[i + 1:j])
30 | val_ *= pred_action[i + 1 + k]
31 | if val_ > max_val_:
32 | best_idx_ = i, j, i + 1 + k
33 | max_val_ = val_
34 | if not also_only_state:
35 | return best_idx_
36 |
37 | max_val_state_, best_idx_state_ = -1, (0, 0)
38 | for i in range(len(pred_state)):
39 | for j in range(i + 1, len(pred_state)):
40 | val_ = pred_state[i, 0] * pred_state[j, 1]
41 | if val_ > max_val_state_:
42 | best_idx_state_ = i, j
43 | max_val_state_ = val_
44 | return best_idx_, best_idx_state_
45 |
46 |
47 | def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1):
48 |
49 | def lr_lambda(current_step):
50 | if current_step < num_warmup_steps:
51 | return float(current_step) / float(max(1, num_warmup_steps))
52 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
53 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
54 |
55 | return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)
56 |
57 |
58 | class AverageMeter:
59 |
60 | def __init__(self):
61 | self.reset()
62 |
63 | def reset(self):
64 | self._sum = 0
65 | self._count = 0
66 |
67 | def update(self, val, n=1):
68 | self._sum += val * n
69 | self._count += n
70 |
71 | @property
72 | def value(self):
73 | if self._count == 0:
74 | return 0
75 | return self._sum / self._count
76 |
77 |
78 | class JointMeter:
79 | def __init__(self, n_classes):
80 | self._dict = {
81 | "sp": [[] for _ in range(n_classes)],
82 | "ap": [[] for _ in range(n_classes)],
83 | "jsp": [[] for _ in range(n_classes)],
84 | "jap": [[] for _ in range(n_classes)],
85 | "acc": [[] for _ in range(n_classes)]
86 | }
87 |
88 | def log(self, pred_action, pred_state, annotations, category):
89 | assert len(annotations) == len(pred_action) == len(pred_state)
90 |
91 | # state accuracy
92 | pred_state_idx = np.argmax(pred_state, axis=-1)
93 | n_gt_states = np.logical_or(annotations == 1, annotations == 3).sum()
94 | state_acc = ((pred_state_idx[annotations == 1] == 0).sum() +
95 | (pred_state_idx[annotations == 3] == 1).sum()) / n_gt_states if n_gt_states > 0 else 0.
96 | if n_gt_states > 0:
97 | self._dict["acc"][category].append(state_acc)
98 |
99 | # action precision
100 | self._dict["ap"][category].append(1. if annotations[np.argmax(pred_action)] == 2 else 0.)
101 |
102 | # state and joint precision
103 | joint, state_only = constrained_argmax(pred_action, pred_state, also_only_state=True)
104 | self._dict["sp"][category].append((0.5 if annotations[state_only[0]] == 1 else 0.0) + \
105 | (0.5 if annotations[state_only[1]] == 3 else 0.0))
106 | self._dict["jap"][category].append(1. if annotations[joint[2]] == 2 else 0.)
107 | self._dict["jsp"][category].append((0.5 if annotations[joint[0]] == 1 else 0.0) + \
108 | (0.5 if annotations[joint[1]] == 3 else 0.0))
109 |
110 | def __getattr__(self, item):
111 | if item in self._dict:
112 | return np.mean([np.mean(x) for x in self._dict[item]]) * 100
113 | raise NotImplementedError()
114 |
115 | def __getitem__(self, item):
116 | return [np.mean(self._dict[k][item]) * 100 for k in ["acc", "sp", "jsp", "ap", "jap"]]
117 |
118 | def dump(self, path, global_id):
119 | with open(f"{path}.{global_id}.pickle", "wb") as f:
120 | pickle.dump(self._dict, f)
121 |
122 | def load(self, path):
123 | n_classes = len(self._dict["sp"])
124 | self._dict = {
125 | "sp": [[] for _ in range(n_classes)],
126 | "ap": [[] for _ in range(n_classes)],
127 | "jsp": [[] for _ in range(n_classes)],
128 | "jap": [[] for _ in range(n_classes)],
129 | "acc": [[] for _ in range(n_classes)]
130 | }
131 | for fn in glob.glob(f"{path}.*.pickle"):
132 | dict_ = pickle.load(open(fn, "rb"))
133 | for k in self._dict.keys():
134 | for i in range(n_classes):
135 | self._dict[k][i] += dict_[k][i]
136 |
137 | for fn in glob.glob(f"{path}.*.pickle"):
138 | os.remove(fn)
139 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import time
4 | import tqdm
5 | import torch
6 | import random
7 | import socket
8 | import argparse
9 | import torch.nn.functional as F
10 | import torch.distributed as dist
11 | import torch.multiprocessing as mp
12 |
13 | import lookforthechange
14 | from method.dataset import ChangeItVideoDataset, identity_collate, DistributedDropFreeSampler
15 | from method.model import ClipClassifier
16 | from method.utils import get_cosine_schedule_with_warmup, AverageMeter, select_correct_classes, JointMeter
17 |
18 | cv2.setNumThreads(1) # do not spawn multiple threads for augmentation (ffmpeg then raises an error)
19 |
20 |
21 | def main(args):
22 | ngpus_per_node = torch.cuda.device_count()
23 | node_count = int(os.environ.get("SLURM_NPROCS", "1"))
24 | node_rank = int(os.environ.get("SLURM_PROCID", "0"))
25 | job_id = os.environ.get("SLURM_JOBID", "0")
26 |
27 | if node_count == 1: # for PBS/PMI clusters
28 | node_count = int(os.environ.get("PMI_SIZE", "1"))
29 | node_rank = int(os.environ.get("PMI_RANK", "0"))
30 | job_id = os.environ.get("PBS_JOBID", "".join([str(random.randint(0, 9)) for _ in range(5)]))
31 |
32 | dist_url = "file://{}.{}".format(os.path.realpath("distfile"), job_id)
33 | print(f"Hi from node {socket.gethostname()} ({node_rank}/{node_count} with {ngpus_per_node} GPUs)!", flush=True)
34 |
35 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=({
36 | "ngpus_per_node": ngpus_per_node,
37 | "node_count": node_count,
38 | "node_rank": node_rank,
39 | "dist_url": dist_url,
40 | "job_id": job_id
41 | }, args))
42 |
43 |
44 | def main_worker(local_rank, cluster_args, args):
45 | world_size = cluster_args["node_count"] * cluster_args["ngpus_per_node"]
46 | global_rank = cluster_args["node_rank"] * cluster_args["ngpus_per_node"] + local_rank
47 | dist.init_process_group(
48 | backend="nccl",
49 | init_method=cluster_args["dist_url"],
50 | world_size=world_size,
51 | rank=global_rank,
52 | )
53 |
54 | if global_rank == 0:
55 | for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
56 | print(f"# {k}: {v}")
57 | print(f"# effective_batch_size: {world_size * args.local_batch_size}", flush=True)
58 |
59 | ###############
60 | # DATASET
61 | ###############
62 | train_ds = ChangeItVideoDataset(
63 | video_roots=args.video_roots, annotation_root=os.path.join(args.dataset_root, "annotations"),
64 | file_mode="unannotated", noise_adapt_weight_root=None if args.ignore_video_weight else os.path.join(args.dataset_root, "videos"),
65 | noise_adapt_weight_threshold_file=None if args.ignore_video_weight else os.path.join(args.dataset_root, "categories.csv"), augment=args.augment
66 | )
67 | test_ds = ChangeItVideoDataset(
68 | video_roots=args.video_roots, annotation_root=os.path.join(args.dataset_root, "annotations"),
69 | file_mode="annotated", noise_adapt_weight_root=None if args.ignore_video_weight else os.path.join(args.dataset_root, "videos"),
70 | noise_adapt_weight_threshold_file=None if args.ignore_video_weight else os.path.join(args.dataset_root, "categories.csv"), augment=False
71 | )
72 |
73 | if global_rank == 0:
74 | print(train_ds, test_ds, sep="\n", flush=True)
75 |
76 | train_sampler = torch.utils.data.distributed.DistributedSampler(
77 | train_ds, shuffle=True, drop_last=True) if world_size > 1 else None
78 | train_ds_iter = torch.utils.data.DataLoader(
79 | train_ds, batch_size=args.local_batch_size, shuffle=world_size == 1, drop_last=True, num_workers=2,
80 | pin_memory=True, sampler=train_sampler, collate_fn=identity_collate)
81 |
82 | test_sampler = DistributedDropFreeSampler(test_ds, shuffle=False) if world_size > 1 else None
83 | test_ds_iter = torch.utils.data.DataLoader(
84 | test_ds, batch_size=1, shuffle=False, drop_last=False, num_workers=2,
85 | pin_memory=True, sampler=test_sampler, collate_fn=identity_collate)
86 |
87 | ###############
88 | # MODEL
89 | ###############
90 | weights = torch.jit.load(args.clip_weights, map_location="cpu").state_dict()
91 | model = ClipClassifier(hidden_mlp_layers=[4096],
92 | params=weights,
93 | n_classes=train_ds.n_classes,
94 | train_backbone=args.train_backbone)
95 | assert model.backbone.input_resolution == 224
96 |
97 | torch.cuda.set_device(local_rank)
98 | model.cuda(local_rank)
99 | model_parallel = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
100 |
101 | ###############
102 | # OPTIMIZER
103 | ###############
104 | head_params = model_parallel.module.heads.parameters()
105 | optim_head = torch.optim.SGD(head_params, lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
106 | scheduler_head = get_cosine_schedule_with_warmup(optim_head, 5 * len(train_ds_iter), len(train_ds_iter) * args.n_epochs)
107 |
108 | if args.train_backbone:
109 | backbone_params = model_parallel.module.backbone.parameters()
110 | optim_backbone = torch.optim.AdamW(backbone_params, lr=args.lr_backbone, weight_decay=args.weight_decay_backbone)
111 | scheduler_backbone = get_cosine_schedule_with_warmup(optim_backbone, 5 * len(train_ds_iter), len(train_ds_iter) * args.n_epochs)
112 |
113 | ###############
114 | # TRAINING
115 | ###############
116 | n_frames_per_gt = args.n_frames_per_gt
117 | kappa_dist = 60
118 |
119 | loss_metric = AverageMeter()
120 | loss_norm_metric = AverageMeter()
121 | unsup_state_loss_metric = AverageMeter()
122 | unsup_action_loss_metric = AverageMeter()
123 |
124 | for epoch in range(1, args.n_epochs + 1):
125 | if world_size > 1: train_sampler.set_epoch(epoch)
126 | loss_metric.reset()
127 | loss_norm_metric.reset()
128 | unsup_state_loss_metric.reset()
129 | unsup_action_loss_metric.reset()
130 |
131 | iterator = tqdm.tqdm(train_ds_iter) if global_rank == 0 else train_ds_iter
132 | for batch in iterator: # id, class, video, annotation/None, weight
133 |
134 | optim_head.zero_grad()
135 | if args.train_backbone: optim_backbone.zero_grad()
136 |
137 | # COMPUTE GT FOR ALL VIDEOS IN BATCH
138 | batch_for_training = []
139 | for _, class_, inputs, _, weight in batch:
140 | classes = torch.LongTensor([class_])
141 |
142 | # PREDICT
143 | with torch.no_grad():
144 | predictions = []
145 | for i in range(0, len(inputs), 256):
146 | predictions += [model(inputs[i:i + 256].cuda(local_rank))]
147 | predictions = {
148 | "state": torch.cat([p["state"] for p in predictions], dim=0),
149 | "action": torch.cat([p["action"] for p in predictions], dim=0)
150 | }
151 |
152 | st_probs = select_correct_classes(
153 | torch.softmax(predictions["state"].unsqueeze(0), -1), classes, n_classes=train_ds.n_classes)
154 | ac_probs = select_correct_classes(
155 | torch.softmax(predictions["action"].unsqueeze(0), -1), classes, n_classes=train_ds.n_classes + 1)
156 |
157 | # COMPUTE GROUND TRUTH
158 | indices = lookforthechange.optimal_state_change_indices(
159 | st_probs, ac_probs, lens=torch.tensor([st_probs.shape[1]], dtype=torch.int32, device=st_probs.device))
160 | indices = indices.view(1, 3).cpu() # [S1idx, S2idx, ACidx]
161 |
162 | positives = indices.repeat(n_frames_per_gt, 1) + \
163 | torch.arange(-(n_frames_per_gt // 2), (n_frames_per_gt // 2) + 1, 1, device=indices.device).unsqueeze_(1)
164 | indices_extended = torch.cat([
165 | positives.transpose(1, 0).reshape(-1), positives[:, 2] - kappa_dist, positives[:, 2] + kappa_dist
166 | ], 0).clamp_(0, len(inputs) - 1)
167 | # [ S1idx - 1, S1idx, S1idx + 1,
168 | # S2idx - 1, S2idx, S2idx + 1,
169 | # ACidx - 1, ACidx, ACidx + 1,
170 | # ACidx - 61, ACidx - 60, ACidx - 59,
171 | # ACidx + 59, ACidx + 60, ACidx + 61]
172 |
173 | bg_class_index = train_ds.n_classes
174 | action_targets = torch.LongTensor([bg_class_index] * n_frames_per_gt * 2 +
175 | [class_] * n_frames_per_gt +
176 | [bg_class_index] * n_frames_per_gt * 2)
177 | # [ BG, BG, BG,
178 | # BG, BG, BG,
179 | # CLS, CLS, CLS,
180 | # BG, BG, BG,
181 | # BG, BG, BG]
182 |
183 | bg_class_index = train_ds.n_classes * 2
184 | state_targets = torch.LongTensor([class_ * 2 + 0] * n_frames_per_gt +
185 | [class_ * 2 + 1] * n_frames_per_gt +
186 | [bg_class_index] * n_frames_per_gt +
187 | [bg_class_index] * n_frames_per_gt * 2)
188 | state_target_mask = torch.FloatTensor([1.] * n_frames_per_gt * 3 + [0.] * n_frames_per_gt * 2)
189 | # [ CS1, CS1, CS1,
190 | # CS2, CS2, CS2,
191 | # BG, BG, BG,
192 | # *, *, *,
193 | # *, *, *]
194 |
195 | batch_for_training.append((
196 | inputs[indices_extended], action_targets, state_targets, state_target_mask, weight))
197 |
198 | # FORWARD + BACKWARD PASS
199 | predictions = model_parallel(torch.cat([x[0] for x in batch_for_training], 0).cuda(local_rank))
200 |
201 | if batch_for_training[0][4] is None:
202 | video_loss_weight = torch.FloatTensor([1. for _ in batch_for_training]).view(-1, 1).cuda(local_rank)
203 | else:
204 | video_loss_weight = torch.FloatTensor([x[4] for x in batch_for_training]) * (-1 / 0.001)
205 | video_loss_weight = 1 / (1 + torch.exp(video_loss_weight))
206 | video_loss_weight = video_loss_weight.view(-1, 1).cuda(local_rank)
207 |
208 | state_gt = torch.cat([x[2] for x in batch_for_training], 0).cuda(local_rank)
209 | state_gt_mask = torch.cat([x[3] for x in batch_for_training], 0).cuda(local_rank)
210 | action_gt = torch.cat([x[1] for x in batch_for_training], 0).cuda(local_rank)
211 |
212 | state_loss = F.cross_entropy(predictions["state"], state_gt, reduction="none") * state_gt_mask
213 | state_loss = state_loss.view(-1, n_frames_per_gt * 5) * video_loss_weight
214 | action_loss = F.cross_entropy(predictions["action"], action_gt, reduction="none")
215 | action_loss = action_loss.view(-1, n_frames_per_gt * 5) * video_loss_weight
216 |
217 | state_loss = torch.sum(state_loss)
218 | action_loss = 0.2 * torch.sum(action_loss)
219 | loss = state_loss + action_loss
220 |
221 | # DistributedDataParallel does gradient averaging, i.e. loss is x-times smaller than in Look for the Change.
222 | # When training with frozen backbone, make it somewhat equivalent to the Look for the Change setup.
223 | if not args.train_backbone:
224 | loss = loss * world_size
225 | loss.backward()
226 |
227 | optim_head.step()
228 | scheduler_head.step()
229 | if args.train_backbone:
230 | optim_backbone.step()
231 | scheduler_backbone.step()
232 |
233 | loss_metric.update(loss.item(), len(batch_for_training))
234 | unsup_state_loss_metric.update(state_loss.item(), len(batch_for_training))
235 | unsup_action_loss_metric.update(action_loss.item(), len(batch_for_training))
236 |
237 | ###############
238 | # VALIDATION
239 | ###############
240 | joint_meter = JointMeter(train_ds.n_classes)
241 | for batch in test_ds_iter:
242 | _, class_, inputs, annot, _ = batch[0]
243 | classes = torch.LongTensor([class_])
244 |
245 | with torch.no_grad():
246 | predictions = model(inputs.cuda(local_rank))
247 | st_probs = select_correct_classes(
248 | torch.softmax(predictions["state"].unsqueeze(0), -1), classes, n_classes=train_ds.n_classes)
249 | ac_probs = select_correct_classes(
250 | torch.softmax(predictions["action"].unsqueeze(0), -1), classes, n_classes=train_ds.n_classes + 1)
251 |
252 | joint_meter.log(ac_probs[0, :, 0].cpu().numpy(), st_probs[0].cpu().numpy(), annot, category=class_)
253 |
254 | vallog_fn = "{}.{}".format(os.path.realpath("vallog"), cluster_args["job_id"])
255 | joint_meter.dump(vallog_fn, global_rank)
256 | dist.barrier()
257 |
258 | if global_rank == 0:
259 | time.sleep(10)
260 | joint_meter.load(vallog_fn)
261 | dir_ = f'logs/{cluster_args["job_id"]}'
262 | os.makedirs(dir_, exist_ok=True)
263 | torch.save({"state_dict": model_parallel.state_dict(), "args": model.args, "n_classes": model.n_classes,
264 | "hidden_mlp_layers": model.hidden_mlp_layers}, f"{dir_}/epoch{epoch:03d}.pth")
265 |
266 | print(f"Epoch {epoch} ("
267 | f"T loss: {loss_metric.value:.3f}, "
268 | f"T lr: {scheduler_head.get_last_lr()[0]:.6f}, "
269 | f"T grad norm: {loss_norm_metric.value:.1f}, "
270 | f"T unsup state loss: {unsup_state_loss_metric.value:.3f}, "
271 | f"T unsup action loss: {unsup_action_loss_metric.value:.3f}, "
272 | f"V state acc: {joint_meter.acc:.1f}%, "
273 | f"V state prec: {joint_meter.sp:.1f}%, "
274 | f"V state joint prec: {joint_meter.jsp:.1f}%, "
275 | f"V action prec: {joint_meter.ap:.1f}%, "
276 | f"V action joint prec: {joint_meter.jap:.1f}%)", flush=True)
277 |
278 | print("> {:20} {:>6} {:>6} {:>6} {:>6} {:>6}".format("CATEGORY", "SAcc", "SP", "JtSP", "AP", "JtAP"))
279 | print("\n".join([
280 | "> {:20}{:6.1f}%{:6.1f}%{:6.1f}%{:6.1f}%{:6.1f}%".format(cls_name, *joint_meter[train_ds.classes[cls_name]])
281 | for cls_name in sorted(train_ds.classes.keys())
282 | ]), flush=True)
283 |
284 |
285 | if __name__ == '__main__':
286 | parser = argparse.ArgumentParser()
287 | parser.add_argument("--video_roots", type=str, nargs="+", default=["./videos"])
288 | parser.add_argument("--dataset_root", type=str, default="./ChangeIt")
289 | parser.add_argument("--lr", default=0.0001, type=float)
290 | parser.add_argument("--lr_backbone", default=0.00001, type=float)
291 | parser.add_argument("--weight_decay", type=float, default=0.001)
292 | parser.add_argument("--weight_decay_backbone", type=float, default=0.)
293 | parser.add_argument("--train_backbone", action="store_true")
294 | parser.add_argument("--n_frames_per_gt", type=int, default=3)
295 | parser.add_argument("--local_batch_size", type=int, default=2)
296 | parser.add_argument("--clip_weights", type=str, default="./weights/ViT-L-14.pt")
297 | parser.add_argument("--n_epochs", type=int, default=20)
298 | parser.add_argument("--augment", action="store_true")
299 | parser.add_argument("--ignore_video_weight", action="store_true")
300 | main(parser.parse_args())
301 |
--------------------------------------------------------------------------------