├── README.md
├── WPT_Filters
│   ├── bior22_wpt_3.mat
│   ├── db2_wpt_3.mat
│   └── haar_wpt_3.mat
├── csv_files
│   └── README.md
├── data_download.sh
├── data_unzip.sh
├── demo_feat.py
├── demo_score.py
├── model_io.py
├── models
│   ├── LIVE_ETRI.save
│   ├── LIVE_YT_HFR.save
│   ├── README.md
│   └── YouTube_UGC.save
├── modules
│   ├── CONTRIQUE_model.py
│   ├── GRUModel.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── CONTRIQUE_model.cpython-37.pyc
│   │   ├── GRUModel.cpython-37.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── configure_optimizers.cpython-37.pyc
│   │   ├── dataset_loader.cpython-37.pyc
│   │   ├── gather.cpython-37.pyc
│   │   ├── network.cpython-37.pyc
│   │   ├── nt_xent_multiclass.cpython-37.pyc
│   │   └── scheduler.cpython-37.pyc
│   ├── configure_optimizers.py
│   ├── dataset_loader.py
│   ├── gather.py
│   ├── network.py
│   ├── nt_xent_multiclass.py
│   ├── scheduler.py
│   └── sync_batchnorm
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-36.pyc
│       │   ├── __init__.cpython-37.pyc
│       │   ├── __init__.cpython-39.pyc
│       │   ├── batchnorm.cpython-36.pyc
│       │   ├── batchnorm.cpython-37.pyc
│       │   ├── batchnorm.cpython-39.pyc
│       │   ├── comm.cpython-36.pyc
│       │   ├── comm.cpython-37.pyc
│       │   ├── comm.cpython-39.pyc
│       │   ├── replicate.cpython-36.pyc
│       │   ├── replicate.cpython-37.pyc
│       │   └── replicate.cpython-39.pyc
│       ├── batchnorm.py
│       ├── batchnorm_reimpl.py
│       ├── comm.py
│       ├── replicate.py
│       └── unittest.py
├── requirements.txt
├── sample_videos
│   ├── 30.mp4
│   └── Flips_crf_48_30fps.webm
├── spatial_feature_extract_syn.py
├── spatial_feature_extract_ugc.py
├── train.py
└── train_regressor.py

/README.md:
--------------------------------------------------------------------------------
1 | # CONVIQT: Contrastive Video Quality Estimator
2 | 
3 | **Pavan C. Madhusudana**, Neil Birkbeck, Yilin Wang, Balu Adsumilli and Alan C. Bovik
4 | 
5 | This is the official repository of the paper [CONVIQT: Contrastive Video Quality Estimator](https://arxiv.org/abs/2110.13266).
6 | 
7 | ## Usage
8 | The code has been tested on Linux systems with Python 3.7. Please refer to [requirements.txt](requirements.txt) for installing the required packages.
9 | 
10 | ### Running CONVIQT
11 | To obtain quality scores, the trained checkpoints need to be downloaded first. The following commands can be used to download them.
12 | ```
13 | wget -L https://utexas.box.com/shared/static/rhpa8nkcfzpvdguo97n2d5dbn4qb03z8.tar -O models/CONTRIQUE_checkpoint25.tar -q --show-progress
14 | wget -L https://utexas.box.com/shared/static/7s8348b0imqe27qkgq8lojfc2od1631a.tar -O models/CONVIQT_checkpoint10.tar -q --show-progress
15 | ```
16 | Alternatively, the checkpoints can be downloaded manually using these links: [link1](https://utexas.box.com/s/rhpa8nkcfzpvdguo97n2d5dbn4qb03z8) and [link2](https://utexas.box.com/s/7s8348b0imqe27qkgq8lojfc2od1631a).
17 | 
18 | Google Drive links for the checkpoints: [link1](https://drive.google.com/file/d/1pmaomNVFhDgPSREgHBzZSu-SuGzNJyEt/view?usp=drive_web) and [link2](https://drive.google.com/file/d/1f3h8gha8YbuLTngzAkmf7MB79rYcywin/view?usp=sharing)
19 | 
20 | ### Obtaining Quality Scores
21 | We provide trained regressor models in the [models](models) directory, which can be used to predict video quality from features obtained with the CONVIQT model. For demonstration purposes, some sample videos are provided in the [sample_videos](sample_videos) folder.
22 | 
23 | For blind quality prediction, the following commands can be used.
24 | ```
25 | python3 demo_score.py --video_path sample_videos/30.mp4 --spatial_model_path models/CONTRIQUE_checkpoint25.tar --temporal_model_path models/CONVIQT_checkpoint10.tar --linear_regressor_path models/YouTube_UGC.save
26 | python3 demo_score.py --video_path sample_videos/Flips_crf_48_30fps.webm --spatial_model_path models/CONTRIQUE_checkpoint25.tar --temporal_model_path models/CONVIQT_checkpoint10.tar --linear_regressor_path models/LIVE_YT_HFR.save
27 | ```
28 | 
29 | ### Obtaining CONVIQT Features
30 | To calculate CONVIQT features, the following commands can be used. The features are saved in '.npy' format.
31 | ```
32 | python3 demo_feat.py --video_path sample_videos/30.mp4 --spatial_model_path models/CONTRIQUE_checkpoint25.tar --temporal_model_path models/CONVIQT_checkpoint10.tar --feature_save_path features.npy
33 | python3 demo_feat.py --video_path sample_videos/Flips_crf_48_30fps.webm --spatial_model_path models/CONTRIQUE_checkpoint25.tar --temporal_model_path models/CONVIQT_checkpoint10.tar --feature_save_path features.npy
34 | ```
35 | 
36 | ## Training CONVIQT
37 | ### Download Training Data
38 | Create a directory to store the videos used for training CONVIQT: ```mkdir training_data```. Run the following commands to download and unzip the training data containing videos with synthetic distortions.
39 | ```
40 | bash data_download.sh
41 | bash data_unzip.sh
42 | ```
43 | 
44 | For UGC videos, download the Kinetics dataset [link](https://www.deepmind.com/open-source/kinetics) and unzip the data. For training CONVIQT, only the directories parts_0-parts_10 of the Kinetics-400 dataset are needed.
45 | 
46 | ### Training Model
47 | Download the csv files containing the paths to the videos and their corresponding distortion classes.
48 | ```
49 | 
50 | wget -L https://utexas.box.com/shared/static/63pvroz3287j1kj7ja0gv81vw2txnwlw.csv -O csv_files/file_names_ugc.csv -q --show-progress
51 | wget -L https://utexas.box.com/shared/static/migniec2yb07vc8kz002kub658840dun.csv -O csv_files/file_names_syn.csv -q --show-progress
52 | ```
53 | The above files can also be downloaded manually using these links: [link1](https://utexas.box.com/s/63pvroz3287j1kj7ja0gv81vw2txnwlw), [link2](https://utexas.box.com/s/migniec2yb07vc8kz002kub658840dun).
54 | Google Drive links: [link1](https://drive.google.com/file/d/1N7EGdS-mobbcWmJUOvblCyQ_xWQ4hwjc/view?usp=sharing), [link2](https://drive.google.com/file/d/109w9c6t8EAEP_yrLKJtzsYAB1-QbLXoy/view?usp=sharing)
55 | 
56 | ### Spatial Feature Extraction
57 | Spatial features are extracted with the [CONTRIQUE](https://github.com/pavancm/CONTRIQUE) model using the following commands.
58 | ```
59 | python3 spatial_feature_extract_syn.py
60 | python3 spatial_feature_extract_ugc.py
61 | ```
62 | The extracted spatial features are saved in the training_data directory.
63 | 
64 | For training CONVIQT with a single GPU, the following command can be used.
65 | ```
66 | python3 train.py --batch_size 256 --lr 0.6 --epochs 25
67 | ```
68 | 
69 | ### Training Linear Regressor
70 | After CONVIQT model training is complete, a linear regressor is trained on CONVIQT features and the corresponding ground-truth quality scores with the following command (one way to assemble `feat.npy` and `scores.npy` is sketched below, after the Contact section).
71 | 
72 | ```
73 | python3 train_regressor.py --feat_path feat.npy --ground_truth_path scores.npy --alpha 0.1
74 | ```
75 | 
76 | ## Contact
77 | Please contact Pavan (pavan.madhusudana@gmail.com) if you have any questions, suggestions or corrections regarding the above implementation.
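
As noted in the Training Linear Regressor section above, `train_regressor.py` expects a feature array (`--feat_path`) and a matching array of ground-truth quality scores (`--ground_truth_path`). Below is a minimal sketch of one way to assemble these two files from per-video CONVIQT features; the file names and the per-video aggregation (mean over clips) are illustrative assumptions, not the exact procedure used in the paper.
```
import numpy as np

# Hypothetical list of (CONVIQT feature file, subjective quality score) pairs,
# where each feature file was produced by demo_feat.py for one video.
videos = [
    ('features/video_0001.npy', 3.42),
    ('features/video_0002.npy', 4.10),
]

feats, scores = [], []
for feat_path, mos in videos:
    clip_feat = np.load(feat_path)        # shape: (num_clips, feature_dim)
    feats.append(clip_feat.mean(axis=0))  # assumption: mean-pool over clips
    scores.append(mos)

# Inputs for train_regressor.py (--feat_path and --ground_truth_path).
np.save('feat.npy', np.vstack(feats))
np.save('scores.npy', np.array(scores))
```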
78 | 79 | ## Acknowledgement 80 | This repository is built on the [CONTRIQUE](https://github.com/pavancm/CONTRIQUE) repository. 81 | -------------------------------------------------------------------------------- /WPT_Filters/bior22_wpt_3.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pavancm/CONVIQT/4204f04c1e188c616e8a812a3aa916b1e573b87f/WPT_Filters/bior22_wpt_3.mat -------------------------------------------------------------------------------- /WPT_Filters/db2_wpt_3.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pavancm/CONVIQT/4204f04c1e188c616e8a812a3aa916b1e573b87f/WPT_Filters/db2_wpt_3.mat -------------------------------------------------------------------------------- /WPT_Filters/haar_wpt_3.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pavancm/CONVIQT/4204f04c1e188c616e8a812a3aa916b1e573b87f/WPT_Filters/haar_wpt_3.mat -------------------------------------------------------------------------------- /csv_files/README.md: -------------------------------------------------------------------------------- 1 | Directory containing csv files having video paths and corresponding distortion class labels 2 | -------------------------------------------------------------------------------- /data_download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #Download Waterloo1k directory 4 | wget -L https://utexas.box.com/shared/static/k1eerw4kfv8v8uzvxhkb3mpcsvi66y97 -O training_data/Waterloo1k_aa.zip -q --show-progress 5 | wget -L https://utexas.box.com/shared/static/qugg2e8rh9rcxlwk2oml3dxaellsebd7 -O training_data/Waterloo1k_ab.zip -q --show-progress 6 | wget -L https://utexas.box.com/shared/static/9fquu19a5ysx5m6zf2luhapz48vwa0jj -O training_data/Waterloo1k_ac.zip -q --show-progress 7 | wget -L https://utexas.box.com/shared/static/wrt8dvlk6lknkqvco8ocv7bmh34zgo75 -O training_data/Waterloo1k_ad.zip -q --show-progress 8 | wget -L https://utexas.box.com/shared/static/xna355b0lpr3qlzeqtv0u4gqo7mw5aws -O training_data/Waterloo1k_ae.zip -q --show-progress 9 | wget -L https://utexas.box.com/shared/static/bjadyqd2rudu5afiscn08vm8md0y1cfv -O training_data/Waterloo1k_af.zip -q --show-progress 10 | wget -L https://utexas.box.com/shared/static/vovv6u7rcp8um52bl2aq1kjxnves21ep -O training_data/Waterloo1k_ag.zip -q --show-progress 11 | wget -L https://utexas.box.com/shared/static/gz9twwsvn3r62ilhrqtsi0ew0qe1te2e -O training_data/Waterloo1k_ah.zip -q --show-progress 12 | wget -L https://utexas.box.com/shared/static/x1aom92stivg7iexs62dhnx5eljul1ez -O training_data/Waterloo1k_ai.zip -q --show-progress 13 | cat training_data/Waterloo1k_* > training_data/Waterloo1k.zip 14 | rm -f training_data/Waterloo1k_* 15 | 16 | #Download REDs directory 17 | wget -L https://utexas.box.com/shared/static/z2c2oufm8qdy6qey5hc8s64bjnyba6hz -O training_data/REDS_aa.zip -q --show-progress 18 | wget -L https://utexas.box.com/shared/static/b425i9fs9hk5o8rkg6i63wy7rmzy8xk0 -O training_data/REDS_ab.zip -q --show-progress 19 | wget -L https://utexas.box.com/shared/static/txbaf6fv3pszdk4oymxo8b6hq3ai7ny0 -O training_data/REDS_ac.zip -q --show-progress 20 | wget -L https://utexas.box.com/shared/static/603b5g1v16l07gxcegkvoyqe8r2ya2xb -O training_data/REDS_ad.zip -q --show-progress 21 | wget -L 
https://utexas.box.com/shared/static/u25pldqucozb8vyh0g6v3e1wqjlxnqyq -O training_data/REDS_ae.zip -q --show-progress 22 | wget -L https://utexas.box.com/shared/static/soedh5w6te5alx04u0fzn96lbq4xce1h -O training_data/REDS_af.zip -q --show-progress 23 | wget -L https://utexas.box.com/shared/static/78ma8l45akp0bzeyo40foc822frveyo5 -O training_data/REDS_ag.zip -q --show-progress 24 | wget -L https://utexas.box.com/shared/static/ptabanwvqyi7cbheh5fcs2ldejkieehd -O training_data/REDS_ah.zip -q --show-progress 25 | wget -L https://utexas.box.com/shared/static/kw569pa3umhnoohy0b8scjbb5uic6y6u -O training_data/REDS_ai.zip -q --show-progress 26 | wget -L https://utexas.box.com/shared/static/7v7sgtclulpr87elzu5mqpip3jjo5ash -O training_data/REDS_aj.zip -q --show-progress 27 | wget -L https://utexas.box.com/shared/static/tf5mfe9lgtqht6m0fqapbkmmg4wb682i -O training_data/REDS_ak.zip -q --show-progress 28 | wget -L https://utexas.box.com/shared/static/lgicp39q12zqo4apb7321761me4bpiz4 -O training_data/REDS_al.zip -q --show-progress 29 | wget -L https://utexas.box.com/shared/static/687abgp51h54wzf0xy8r06ksfwb5yln5 -O training_data/REDS_am.zip -q --show-progress 30 | wget -L https://utexas.box.com/shared/static/j9exrex72r8ili6u5h2ewsl9r7u009os -O training_data/REDS_an.zip -q --show-progress 31 | cat training_data/REDS_* > training_data/REDS.zip 32 | rm -f training_data/REDS_* 33 | 34 | #Download UVG directory 35 | wget -L https://utexas.box.com/shared/static/hg8w5w6kb6m8exzydm2w8rnve74dh70l.zip -O training_data/UVG.zip -q --show-progress 36 | 37 | #Download MCML directory 38 | wget -L https://utexas.box.com/shared/static/qs3rte5a0eq342qy82us06428t79igfj.zip -O training_data/MCML.zip -q --show-progress 39 | 40 | #Download dareful directory 41 | wget -L https://utexas.box.com/shared/static/xkhm1scirbin1dmz7f45t2tikgoslke8 -O training_data/dareful_aa.zip -q --show-progress 42 | wget -L https://utexas.box.com/shared/static/aiwzplg1p5h0axxxiw1a5zdkydfprm1i -O training_data/dareful_ab.zip -q --show-progress 43 | wget -L https://utexas.box.com/shared/static/fqqnx21bsumlsib15cjr5ifg71rzvz17 -O training_data/dareful_ac.zip -q --show-progress 44 | wget -L https://utexas.box.com/shared/static/u1upzahh8or4edclhl8tder4di789fgw -O training_data/dareful_ad.zip -q --show-progress 45 | wget -L https://utexas.box.com/shared/static/zvuwx8e6ptb4y55fhrv3t3ljhgmmx8ga -O training_data/dareful_ae.zip -q --show-progress 46 | wget -L https://utexas.box.com/shared/static/89gy3gicdbllxduznqg9osekc3msq184 -O training_data/dareful_af.zip -q --show-progress 47 | wget -L https://utexas.box.com/shared/static/13oxybjwcm5igoi9k7abqc167d7cz42m -O training_data/dareful_ag.zip -q --show-progress 48 | wget -L https://utexas.box.com/shared/static/ojpvuvj3ru8ot3ftxlp46x6dcl4u773b -O training_data/dareful_ah.zip -q --show-progress 49 | wget -L https://utexas.box.com/shared/static/tn77ve7n423dcnesii47o9s2bimj6g2d -O training_data/dareful_ai.zip -q --show-progress 50 | wget -L https://utexas.box.com/shared/static/ao7ooqvnnw8tihhwgkqbyr67uunygd32 -O training_data/dareful_aj.zip -q --show-progress 51 | wget -L https://utexas.box.com/shared/static/faq3nz4ov78cawwcwo3ecr506k3e24bz -O training_data/dareful_ak.zip -q --show-progress 52 | wget -L https://utexas.box.com/shared/static/q0ex7acbgyoantj8lk1tn3ihbxkeetus -O training_data/dareful_al.zip -q --show-progress 53 | cat training_data/dareful_* > training_data/dareful.zip 54 | rm -f training_data/dareful_* 55 | -------------------------------------------------------------------------------- 
/data_unzip.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | unzip training_data/Waterloo1k.zip 3 | rm -f training_data/Waterloo1k.zip 4 | unzip training_data/REDS.zip 5 | rm -f training_data/REDS.zip 6 | unzip training_data/dareful.zip 7 | rm -f training_data/dareful.zip 8 | unzip training_data/UVG.zip 9 | rm -f training_data/UVG.zip 10 | unzip training_data/MCML.zip 11 | rm -f training_data/MCML.zip 12 | -------------------------------------------------------------------------------- /demo_feat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from modules.network import get_network 3 | from modules.CONTRIQUE_model import CONTRIQUE_model 4 | from modules.GRUModel import GRUModel 5 | from torchvision import transforms 6 | import numpy as np 7 | 8 | import os 9 | import argparse 10 | import pickle 11 | import skvideo.io 12 | from PIL import Image 13 | 14 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 15 | 16 | class torch_transform: 17 | def __init__(self, size): 18 | self.transform1 = transforms.Compose( 19 | [ 20 | transforms.Resize((size[0],size[1])), 21 | transforms.ToTensor(), 22 | ] 23 | ) 24 | 25 | self.transform2 = transforms.Compose( 26 | [ 27 | transforms.Resize((size[0] // 2, size[1] // 2)), 28 | transforms.ToTensor(), 29 | ] 30 | ) 31 | 32 | def __call__(self, x): 33 | return self.transform1(x), self.transform2(x) 34 | 35 | def create_data_loader(image, image_2, batch_size): 36 | train = torch.utils.data.TensorDataset(image, image_2) 37 | loader = torch.utils.data.DataLoader( 38 | train, 39 | batch_size=batch_size, 40 | drop_last=True, 41 | num_workers=12, 42 | sampler=None, 43 | shuffle=False 44 | ) 45 | return loader 46 | 47 | def extract_features(args, model, loader): 48 | feat = [] 49 | 50 | model.eval() 51 | for step, (batch_im, batch_im_2) in enumerate(loader): 52 | batch_im = batch_im.type(torch.float32) 53 | batch_im_2 = batch_im_2.type(torch.float32) 54 | 55 | batch_im = batch_im.cuda(non_blocking=True) 56 | batch_im_2 = batch_im_2.cuda(non_blocking=True) 57 | 58 | with torch.no_grad(): 59 | _,_, _, _, model_feat, model_feat_2, _, _ = model(batch_im, batch_im_2) 60 | 61 | feat_ = np.hstack((model_feat.detach().cpu().numpy(),\ 62 | model_feat_2.detach().cpu().numpy())) 63 | feat.extend(feat_) 64 | return np.array(feat) 65 | 66 | def extract_features_temporal(args, model, loader): 67 | feat = [] 68 | 69 | model.eval() 70 | for step, (batch_im, batch_im_2) in enumerate(loader): 71 | batch_im = batch_im.type(torch.float32) 72 | batch_im_2 = batch_im_2.type(torch.float32) 73 | 74 | batch_im = batch_im.cuda(non_blocking=True).unsqueeze(0) 75 | batch_im_2 = batch_im_2.cuda(non_blocking=True).unsqueeze(0) 76 | 77 | with torch.no_grad(): 78 | _, _, model_feat, model_feat_2 = model(batch_im, batch_im_2) 79 | 80 | feat_ = np.hstack((model_feat.detach().cpu().numpy(),\ 81 | model_feat_2.detach().cpu().numpy())) 82 | feat.extend(feat_) 83 | return np.array(feat) 84 | 85 | def main(args): 86 | # load video 87 | video = skvideo.io.FFmpegReader(args.video_path) 88 | T, height, width, C = video.getShape() 89 | 90 | #define torch transform for 2 spatial scales 91 | transform = torch_transform((height, width)) 92 | 93 | #define arrays to store frames 94 | frames = torch.zeros((T,3,height,width), dtype=torch.float16) 95 | frames_2 = torch.zeros((T,3,height// 2,width// 2), dtype=torch.float16) 96 | 97 | # read every video frame 98 | for frame_ind in range(T): 99 | inp_frame = 
Image.fromarray(next(video)) 100 | inp_frame, inp_frame_2 = transform(inp_frame) 101 | frames[frame_ind],frames_2[frame_ind] = \ 102 | inp_frame.type(torch.float16), inp_frame_2.type(torch.float16) 103 | 104 | # convert to torch tensors 105 | loader = create_data_loader(frames, frames_2, args.num_frames) 106 | 107 | # load CONTRIQUE Model 108 | encoder = get_network('resnet50', pretrained=False) 109 | model = CONTRIQUE_model(args, encoder, 2048) 110 | model.load_state_dict(torch.load(args.spatial_model_path, map_location=args.device.type)) 111 | model = model.to(args.device) 112 | 113 | # extract CONTRIQUE features 114 | video_feat = extract_features(args, model, loader) 115 | 116 | #load CONVIQT model 117 | temporal_model = GRUModel(c_in = 2048, hidden_size = 1024, \ 118 | projection_dim = 128, normalize = True,\ 119 | num_layers = 1) 120 | temporal_model.load_state_dict(torch.load(args.temporal_model_path, \ 121 | map_location=args.device.type)) 122 | temporal_model = temporal_model.to(args.device) 123 | 124 | #extract CONVIQT features 125 | feat_frames = torch.from_numpy(video_feat[:,:2048]) 126 | feat_frames_2 = torch.from_numpy(video_feat[:,2048:]) 127 | loader = create_data_loader(feat_frames, feat_frames_2, \ 128 | args.num_frames) 129 | video_feat = extract_features_temporal(args, temporal_model, loader) 130 | 131 | # save features 132 | np.save(args.feature_save_path, video_feat) 133 | print('Done') 134 | 135 | def parse_args(): 136 | parser = argparse.ArgumentParser() 137 | 138 | parser.add_argument('--video_path', type=str, \ 139 | default='sample_videos/30.mp4', \ 140 | help='Path to video', metavar='') 141 | parser.add_argument('--spatial_model_path', type=str, \ 142 | default='models/CONTRIQUE_checkpoint25.tar', \ 143 | help='Path to trained CONTRIQUE model', metavar='') 144 | parser.add_argument('--temporal_model_path', type=str, \ 145 | default='models/CONVIQT_checkpoint10.tar', \ 146 | help='Path to trained CONVIQT model', metavar='') 147 | parser.add_argument('--linear_regressor_path', type=str, \ 148 | default='models/YouTube_UGC.save', \ 149 | help='Path to trained linear regressor', metavar='') 150 | parser.add_argument('--num_frames', type=int, \ 151 | default=16, \ 152 | help='number of frames fed to GRU', metavar='') 153 | parser.add_argument('--feature_save_path', type=str, \ 154 | default='CONVIQT_feat.npy', \ 155 | help='path to save features', metavar='') 156 | args = parser.parse_args() 157 | args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 158 | return args 159 | 160 | if __name__ == '__main__': 161 | args = parse_args() 162 | main(args) -------------------------------------------------------------------------------- /demo_score.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from modules.network import get_network 3 | from modules.CONTRIQUE_model import CONTRIQUE_model 4 | from modules.GRUModel import GRUModel 5 | from torchvision import transforms 6 | import numpy as np 7 | 8 | import os 9 | import argparse 10 | import pickle 11 | import skvideo.io 12 | from PIL import Image 13 | 14 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 15 | 16 | class torch_transform: 17 | def __init__(self, size): 18 | self.transform1 = transforms.Compose( 19 | [ 20 | transforms.Resize((size[0],size[1])), 21 | transforms.ToTensor(), 22 | ] 23 | ) 24 | 25 | self.transform2 = transforms.Compose( 26 | [ 27 | transforms.Resize((size[0] // 2, size[1] // 2)), 28 | transforms.ToTensor(), 29 | ] 30 | ) 31 | 
32 | def __call__(self, x): 33 | return self.transform1(x), self.transform2(x) 34 | 35 | def create_data_loader(image, image_2, batch_size): 36 | train = torch.utils.data.TensorDataset(image, image_2) 37 | loader = torch.utils.data.DataLoader( 38 | train, 39 | batch_size=batch_size, 40 | drop_last=True, 41 | num_workers=12, 42 | sampler=None, 43 | shuffle=False 44 | ) 45 | return loader 46 | 47 | def extract_features(args, model, loader): 48 | feat = [] 49 | 50 | model.eval() 51 | for step, (batch_im, batch_im_2) in enumerate(loader): 52 | batch_im = batch_im.type(torch.float32) 53 | batch_im_2 = batch_im_2.type(torch.float32) 54 | 55 | batch_im = batch_im.cuda(non_blocking=True) 56 | batch_im_2 = batch_im_2.cuda(non_blocking=True) 57 | 58 | with torch.no_grad(): 59 | _,_, _, _, model_feat, model_feat_2, _, _ = model(batch_im, batch_im_2) 60 | 61 | feat_ = np.hstack((model_feat.detach().cpu().numpy(),\ 62 | model_feat_2.detach().cpu().numpy())) 63 | feat.extend(feat_) 64 | return np.array(feat) 65 | 66 | def extract_features_temporal(args, model, loader): 67 | feat = [] 68 | 69 | model.eval() 70 | for step, (batch_im, batch_im_2) in enumerate(loader): 71 | batch_im = batch_im.type(torch.float32) 72 | batch_im_2 = batch_im_2.type(torch.float32) 73 | 74 | batch_im = batch_im.cuda(non_blocking=True).unsqueeze(0) 75 | batch_im_2 = batch_im_2.cuda(non_blocking=True).unsqueeze(0) 76 | 77 | with torch.no_grad(): 78 | _, _, model_feat, model_feat_2 = model(batch_im, batch_im_2) 79 | 80 | feat_ = np.hstack((model_feat.detach().cpu().numpy(),\ 81 | model_feat_2.detach().cpu().numpy())) 82 | feat.extend(feat_) 83 | return np.array(feat) 84 | 85 | def main(args): 86 | # load video 87 | video = skvideo.io.FFmpegReader(args.video_path) 88 | T, height, width, C = video.getShape() 89 | 90 | #define torch transform for 2 spatial scales 91 | transform = torch_transform((height, width)) 92 | 93 | #define arrays to store frames 94 | frames = torch.zeros((T,3,height,width), dtype=torch.float16) 95 | frames_2 = torch.zeros((T,3,height// 2,width// 2), dtype=torch.float16) 96 | 97 | # read every video frame 98 | for frame_ind in range(T): 99 | inp_frame = Image.fromarray(next(video)) 100 | inp_frame, inp_frame_2 = transform(inp_frame) 101 | frames[frame_ind],frames_2[frame_ind] = \ 102 | inp_frame.type(torch.float16), inp_frame_2.type(torch.float16) 103 | 104 | # convert to torch tensors 105 | loader = create_data_loader(frames, frames_2, args.num_frames) 106 | 107 | # load CONTRIQUE Model 108 | encoder = get_network('resnet50', pretrained=False) 109 | model = CONTRIQUE_model(args, encoder, 2048) 110 | model.load_state_dict(torch.load(args.spatial_model_path, map_location=args.device.type)) 111 | model = model.to(args.device) 112 | 113 | # extract CONTRIQUE features 114 | video_feat = extract_features(args, model, loader) 115 | 116 | #load CONVIQT model 117 | temporal_model = GRUModel(c_in = 2048, hidden_size = 1024, \ 118 | projection_dim = 128, normalize = True,\ 119 | num_layers = 1) 120 | temporal_model.load_state_dict(torch.load(args.temporal_model_path, \ 121 | map_location=args.device.type)) 122 | temporal_model = temporal_model.to(args.device) 123 | 124 | #extract CONVIQT features 125 | feat_frames = torch.from_numpy(video_feat[:,:2048]) 126 | feat_frames_2 = torch.from_numpy(video_feat[:,2048:]) 127 | loader = create_data_loader(feat_frames, feat_frames_2, \ 128 | args.num_frames) 129 | video_feat = extract_features_temporal(args, temporal_model, loader) 130 | 131 | # load regressor model 132 | regressor 
= pickle.load(open(args.linear_regressor_path, 'rb')) 133 | score = regressor.predict(video_feat)[0] 134 | print(score) 135 | 136 | def parse_args(): 137 | parser = argparse.ArgumentParser() 138 | 139 | parser.add_argument('--video_path', type=str, \ 140 | default='sample_videos/30.mp4', \ 141 | help='Path to video', metavar='') 142 | parser.add_argument('--spatial_model_path', type=str, \ 143 | default='models/CONTRIQUE_checkpoint25.tar', \ 144 | help='Path to trained CONTRIQUE model', metavar='') 145 | parser.add_argument('--temporal_model_path', type=str, \ 146 | default='models/CONVIQT_checkpoint10.tar', \ 147 | help='Path to trained CONVIQT model', metavar='') 148 | parser.add_argument('--linear_regressor_path', type=str, \ 149 | default='models/YouTube_UGC.save', \ 150 | help='Path to trained linear regressor', metavar='') 151 | parser.add_argument('--num_frames', type=int, \ 152 | default=16, \ 153 | help='number of frames fed to GRU', metavar='') 154 | args = parser.parse_args() 155 | args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 156 | return args 157 | 158 | if __name__ == '__main__': 159 | args = parse_args() 160 | main(args) -------------------------------------------------------------------------------- /model_io.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | def save_model(args, model, optimizer): 5 | out = os.path.join(args.model_path, "checkpoint_{}.tar".format(args.current_epoch)) 6 | 7 | # To save a DataParallel model generically, save the model.module.state_dict(). 8 | # This way, you have the flexibility to load the model any way you want to any device you want. 9 | if args.nodes > 1: 10 | torch.save(model.module.state_dict(), out) 11 | else: 12 | if isinstance(model, torch.nn.DataParallel): 13 | torch.save(model.module.state_dict(), out) 14 | else: 15 | torch.save(model.state_dict(), out) -------------------------------------------------------------------------------- /models/LIVE_ETRI.save: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pavancm/CONVIQT/4204f04c1e188c616e8a812a3aa916b1e573b87f/models/LIVE_ETRI.save -------------------------------------------------------------------------------- /models/LIVE_YT_HFR.save: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pavancm/CONVIQT/4204f04c1e188c616e8a812a3aa916b1e573b87f/models/LIVE_YT_HFR.save -------------------------------------------------------------------------------- /models/README.md: -------------------------------------------------------------------------------- 1 | ## Model Nomenclatures 2 | 3 | CONTRIQUE_checkpoint25.tar : CONTRIQUE model trained for 25 epochs. 4 | 5 | CONVIQT_checkpoint10.tar : CONVIQT model trained for 10 epochs. 6 | 7 | LIVE_ETRI.save : linear regressor trained with CONVIQT features on LIVE-ETRI dataset. 8 | 9 | LIVE_YT_HFR.save : linear regressor trained with CONVIQT features on LIVE-YT-HFR dataset. 10 | 11 | YouTube_UGC.save : linear regressor trained with CONVIQT features on YouTube_UGC dataset. 
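
The `.save` files are pickled regressor objects; `demo_score.py` loads them with `pickle` and calls `predict` on CONVIQT features. A minimal sketch of applying one of them outside of `demo_score.py` (the feature file name is illustrative):
```
import pickle
import numpy as np

# Load a trained linear regressor (a pickled object exposing a predict() method).
with open('models/YouTube_UGC.save', 'rb') as f:
    regressor = pickle.load(f)

# CONVIQT features previously saved by demo_feat.py (one row per clip of the video).
video_feat = np.load('CONVIQT_feat.npy')

# Following demo_score.py, report the prediction for the first clip.
score = regressor.predict(video_feat)[0]
print(score)
```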
12 | -------------------------------------------------------------------------------- /models/YouTube_UGC.save: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pavancm/CONVIQT/4204f04c1e188c616e8a812a3aa916b1e573b87f/models/YouTube_UGC.save -------------------------------------------------------------------------------- /modules/CONTRIQUE_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | class CONTRIQUE_model(nn.Module): 5 | # resnet50 architecture with projector 6 | def __init__(self, args, encoder, n_features, \ 7 | patch_dim = (2,2), normalize = True, projection_dim = 128): 8 | super(CONTRIQUE_model, self).__init__() 9 | 10 | self.normalize = normalize 11 | self.encoder = nn.Sequential(*list(encoder.children())[:-2]) 12 | self.n_features = n_features 13 | self.patch_dim = patch_dim 14 | 15 | self.avgpool = nn.AdaptiveAvgPool2d((1,1)) 16 | self.avgpool_patch = nn.AdaptiveAvgPool2d(patch_dim) 17 | 18 | # MLP for projector 19 | self.projector = nn.Sequential( 20 | nn.Linear(self.n_features, self.n_features, bias=False), 21 | nn.BatchNorm1d(self.n_features), 22 | nn.ReLU(), 23 | nn.Linear(self.n_features, projection_dim, bias=False), 24 | nn.BatchNorm1d(projection_dim), 25 | ) 26 | 27 | def forward(self, x_i, x_j): 28 | # global features 29 | h_i = self.encoder(x_i) 30 | h_j = self.encoder(x_j) 31 | 32 | # local features 33 | h_i_patch = self.avgpool_patch(h_i) 34 | h_j_patch = self.avgpool_patch(h_j) 35 | 36 | h_i_patch = h_i_patch.reshape(-1,self.n_features,\ 37 | self.patch_dim[0]*self.patch_dim[1]) 38 | 39 | h_j_patch = h_j_patch.reshape(-1,self.n_features,\ 40 | self.patch_dim[0]*self.patch_dim[1]) 41 | 42 | h_i_patch = torch.transpose(h_i_patch,2,1) 43 | h_i_patch = h_i_patch.reshape(-1, self.n_features) 44 | 45 | h_j_patch = torch.transpose(h_j_patch,2,1) 46 | h_j_patch = h_j_patch.reshape(-1, self.n_features) 47 | 48 | h_i = self.avgpool(h_i) 49 | h_j = self.avgpool(h_j) 50 | 51 | h_i = h_i.view(-1, self.n_features) 52 | h_j = h_j.view(-1, self.n_features) 53 | 54 | if self.normalize: 55 | h_i = nn.functional.normalize(h_i, dim=1) 56 | h_j = nn.functional.normalize(h_j, dim=1) 57 | 58 | h_i_patch = nn.functional.normalize(h_i_patch, dim=1) 59 | h_j_patch = nn.functional.normalize(h_j_patch, dim=1) 60 | 61 | # global projections 62 | z_i = self.projector(h_i) 63 | z_j = self.projector(h_j) 64 | 65 | # local projections 66 | z_i_patch = self.projector(h_i_patch) 67 | z_j_patch = self.projector(h_j_patch) 68 | 69 | return z_i, z_j, z_i_patch, z_j_patch, h_i, h_j, h_i_patch, h_j_patch -------------------------------------------------------------------------------- /modules/GRUModel.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class GRUModel(nn.Module): 4 | """ 5 | c_out inception_time output 6 | n_out model output 7 | """ 8 | def __init__(self, c_in, hidden_size, projection_dim, normalize,\ 9 | num_layers = 1, **kwargs): 10 | super().__init__(**kwargs) 11 | self.gru = nn.GRU(c_in, hidden_size, batch_first = True, \ 12 | num_layers = num_layers) 13 | self.gap = nn.AdaptiveAvgPool1d(1) 14 | self.projector = nn.Sequential( 15 | nn.Linear(hidden_size, hidden_size, bias=False), 16 | nn.BatchNorm1d(hidden_size), 17 | nn.ReLU(), 18 | nn.Linear(hidden_size, projection_dim, bias=False), 19 | nn.BatchNorm1d(projection_dim), 20 | ) 21 | self.normalize = normalize 22 | 
self.c_out = hidden_size 23 | 24 | def forward(self, x_i, x_j): 25 | h_i, h_in = self.gru(x_i) 26 | h_j, h_jn = self.gru(x_j) 27 | 28 | h_i = self.gap(h_i.transpose(1,2)) 29 | h_j = self.gap(h_j.transpose(1,2)) 30 | 31 | h_i = h_i.view(-1, self.c_out) 32 | h_j = h_j.view(-1, self.c_out) 33 | 34 | if self.normalize: 35 | h_i = nn.functional.normalize(h_i, dim=1) 36 | h_j = nn.functional.normalize(h_j, dim=1) 37 | 38 | z_i = self.projector(h_i) 39 | z_j = self.projector(h_j) 40 | return z_i, z_j, h_i, h_j -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .network import get_network 2 | -------------------------------------------------------------------------------- /modules/__pycache__/CONTRIQUE_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pavancm/CONVIQT/4204f04c1e188c616e8a812a3aa916b1e573b87f/modules/__pycache__/CONTRIQUE_model.cpython-37.pyc -------------------------------------------------------------------------------- /modules/__pycache__/GRUModel.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pavancm/CONVIQT/4204f04c1e188c616e8a812a3aa916b1e573b87f/modules/__pycache__/GRUModel.cpython-37.pyc -------------------------------------------------------------------------------- /modules/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pavancm/CONVIQT/4204f04c1e188c616e8a812a3aa916b1e573b87f/modules/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /modules/__pycache__/configure_optimizers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pavancm/CONVIQT/4204f04c1e188c616e8a812a3aa916b1e573b87f/modules/__pycache__/configure_optimizers.cpython-37.pyc -------------------------------------------------------------------------------- /modules/__pycache__/dataset_loader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pavancm/CONVIQT/4204f04c1e188c616e8a812a3aa916b1e573b87f/modules/__pycache__/dataset_loader.cpython-37.pyc -------------------------------------------------------------------------------- /modules/__pycache__/gather.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pavancm/CONVIQT/4204f04c1e188c616e8a812a3aa916b1e573b87f/modules/__pycache__/gather.cpython-37.pyc -------------------------------------------------------------------------------- /modules/__pycache__/network.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pavancm/CONVIQT/4204f04c1e188c616e8a812a3aa916b1e573b87f/modules/__pycache__/network.cpython-37.pyc -------------------------------------------------------------------------------- /modules/__pycache__/nt_xent_multiclass.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pavancm/CONVIQT/4204f04c1e188c616e8a812a3aa916b1e573b87f/modules/__pycache__/nt_xent_multiclass.cpython-37.pyc 
-------------------------------------------------------------------------------- /modules/__pycache__/scheduler.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pavancm/CONVIQT/4204f04c1e188c616e8a812a3aa916b1e573b87f/modules/__pycache__/scheduler.cpython-37.pyc -------------------------------------------------------------------------------- /modules/configure_optimizers.py: -------------------------------------------------------------------------------- 1 | from modules import scheduler as sch 2 | import torch 3 | 4 | def configure_optimizers(args, model, cur_iter=-1): 5 | iters = args.iters 6 | 7 | def exclude_from_wd_and_adaptation(name): 8 | if 'bn' in name: 9 | return True 10 | 11 | param_groups = [ 12 | { 13 | 'params': [p for name, p in model.named_parameters() if not exclude_from_wd_and_adaptation(name)], 14 | 'weight_decay': args.weight_decay, 15 | 'layer_adaptation': True, 16 | }, 17 | { 18 | 'params': [p for name, p in model.named_parameters() if exclude_from_wd_and_adaptation(name)], 19 | 'weight_decay': 0., 20 | 'layer_adaptation': False, 21 | }, 22 | ] 23 | 24 | LR = args.lr 25 | 26 | if args.opt == 'sgd': 27 | optimizer = torch.optim.SGD( 28 | param_groups, 29 | lr=LR, 30 | momentum=0.9, 31 | ) 32 | elif args.opt == 'adam': 33 | optimizer = torch.optim.Adam( 34 | param_groups, 35 | lr=LR, 36 | ) 37 | else: 38 | raise NotImplementedError 39 | 40 | if args.reload: 41 | fl = torch.load(args.model_path + 'optimizer.tar') 42 | optimizer.load_state_dict(fl['optimizer']) 43 | cur_iter = fl['scheduler']['last_epoch'] - 1 44 | 45 | if args.lr_schedule == 'warmup-anneal': 46 | scheduler = sch.LinearWarmupAndCosineAnneal( 47 | optimizer, 48 | args.warmup, 49 | iters, 50 | last_epoch=cur_iter, 51 | ) 52 | elif args.lr_schedule == 'linear': 53 | scheduler = sch.LinearLR(optimizer, iters, last_epoch=cur_iter) 54 | elif args.lr_schedule == 'const': 55 | scheduler = sch.LinearWarmupAndConstant( 56 | optimizer, 57 | args.warmup, 58 | iters, 59 | last_epoch=cur_iter, 60 | ) 61 | else: 62 | raise NotImplementedError 63 | 64 | if args.reload: 65 | scheduler.load_state_dict(fl['scheduler']) 66 | 67 | return optimizer, scheduler -------------------------------------------------------------------------------- /modules/dataset_loader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import torch 3 | from torchvision import transforms 4 | import numpy as np 5 | import pandas as pd 6 | 7 | class video_data_feat(Dataset): 8 | def __init__(self, file_path, temporal_len = 16, transform = True): 9 | self.fls = pd.read_csv(file_path) 10 | self.temporal_len = temporal_len 11 | self.tranform_toT = transforms.Compose([ 12 | transforms.ToTensor(), 13 | ]) 14 | 15 | def __len__(self): 16 | return len(self.fls) 17 | 18 | def __getitem__(self, idx): 19 | if torch.is_tensor(idx): 20 | idx = idx.tolist() 21 | 22 | vid_path = self.fls.iloc[idx]['File_names'] 23 | 24 | if 'Kinetics-400' in vid_path: 25 | vid_path = vid_path.replace('Kinetics-400','/mnt/LIVELAB_NAS/Pavan/kinetics_feat') 26 | name = vid_path.split('.m')[0] 27 | elif 'REDS' in vid_path: 28 | vid_path = vid_path.replace('REDS','') 29 | vid_path = vid_path.replace('dist','/mnt/LIVELAB_NAS/Pavan/dist_feat') 30 | name = vid_path.split('.w')[0] 31 | else: 32 | vid_path = vid_path.replace('dist','/mnt/LIVELAB_NAS/Pavan/dist_feat') 33 | name = vid_path.split('.w')[0] 34 | 35 | # determine first video 
characteristics 36 | div_factor1 = np.random.choice([1,2],1)[0] 37 | colorspace_choice1 = 4 38 | temporal_choice1 = 3 39 | 40 | feat_path1 = name + '_color' + str(colorspace_choice1) +\ 41 | '_temp' + str(temporal_choice1) + '.npy' 42 | feat1 = np.load(feat_path1, allow_pickle = True) 43 | T,D = feat1.shape 44 | 45 | #randomly sample a clip of length temporal_len 46 | T1 = np.random.randint(1, T - 1 - self.temporal_len) 47 | feat1 = feat1[T1:T1+self.temporal_len,:] 48 | 49 | #choose the scale 50 | if div_factor1 == 1: 51 | feat1 = feat1[:,:D // 2] 52 | else: 53 | feat1 = feat1[:,D//2:] 54 | 55 | feat1 = self.tranform_toT(feat1) 56 | 57 | # determine second video characteristics 58 | div_factor2 = 3 - div_factor1 59 | colorspace_choice2 = 4 60 | temporal_choice2 = 3 61 | 62 | feat_path2 = name + '_color' + str(colorspace_choice2) +\ 63 | '_temp' + str(temporal_choice2) + '.npy' 64 | feat2 = np.load(feat_path2, allow_pickle=True) 65 | T,D = feat2.shape 66 | 67 | #randomly sample a clip of length temporal_len 68 | T2 = np.random.randint(1, T - 1 - self.temporal_len) 69 | feat2 = feat2[T2:T2+self.temporal_len,:] 70 | 71 | #choose the scale 72 | if div_factor2 == 1: 73 | feat2 = feat2[:,:D // 2] 74 | else: 75 | feat2 = feat2[:,D//2:] 76 | 77 | feat2 = self.tranform_toT(feat2) 78 | 79 | label = self.fls.iloc[idx]['labels'] 80 | label = label[1:-1].split(' ') 81 | label = np.array([t.replace(',','') for t in label]).astype(np.float32) 82 | 83 | return feat1, feat2, label -------------------------------------------------------------------------------- /modules/gather.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | 4 | 5 | class GatherLayer(torch.autograd.Function): 6 | '''Gather tensors from all process, supporting backward propagation. 
7 | ''' 8 | 9 | @staticmethod 10 | def forward(ctx, input): 11 | ctx.save_for_backward(input) 12 | output = [torch.zeros_like(input) \ 13 | for _ in range(dist.get_world_size())] 14 | dist.all_gather(output, input) 15 | return tuple(output) 16 | 17 | @staticmethod 18 | def backward(ctx, *grads): 19 | input, = ctx.saved_tensors 20 | grad_out = torch.zeros_like(input) 21 | grad_out[:] = grads[dist.get_rank()] 22 | return grad_out -------------------------------------------------------------------------------- /modules/network.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | 3 | def get_network(name, pretrained=False): 4 | network = { 5 | "VGG16": torchvision.models.vgg16(pretrained=pretrained), 6 | "VGG16_bn": torchvision.models.vgg16_bn(pretrained=pretrained), 7 | "resnet18": torchvision.models.resnet18(pretrained=pretrained), 8 | "resnet34": torchvision.models.resnet34(pretrained=pretrained), 9 | "resnet50": torchvision.models.resnet50(pretrained=pretrained), 10 | "resnet101": torchvision.models.resnet101(pretrained=pretrained), 11 | "resnet152": torchvision.models.resnet152(pretrained=pretrained), 12 | } 13 | if name not in network.keys(): 14 | raise KeyError(f"{name} is not a valid network architecture") 15 | return network[name] 16 | -------------------------------------------------------------------------------- /modules/nt_xent_multiclass.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .gather import GatherLayer 4 | 5 | class NT_Xent(nn.Module): 6 | def __init__(self, batch_size, temperature, device, world_size): 7 | super(NT_Xent, self).__init__() 8 | self.batch_size = batch_size 9 | self.temperature = temperature 10 | self.device = device 11 | self.world_size = world_size 12 | 13 | def forward(self, z_i, z_j, dist_labels): 14 | 15 | N = 2 * z_i.shape[0] * self.world_size 16 | 17 | z = torch.cat((z_i, z_j), dim=0) 18 | dist_labels = torch.cat((dist_labels, dist_labels),dim=0) 19 | 20 | if self.world_size > 1: 21 | z = torch.cat(GatherLayer.apply(z), dim=0) 22 | dist_labels = torch.cat(GatherLayer.apply(dist_labels), dim=0) 23 | 24 | # calculate similarity and divide by temperature parameter 25 | z = nn.functional.normalize(z, p=2, dim=1) 26 | sim = torch.mm(z, z.T) / self.temperature 27 | dist_labels = dist_labels.cpu() 28 | 29 | positive_mask = torch.mm(dist_labels.to_sparse(), dist_labels.T) 30 | positive_mask = positive_mask.fill_diagonal_(0).to(sim.device) 31 | zero_diag = torch.ones((N, N)).fill_diagonal_(0).to(sim.device) 32 | 33 | # calculate normalized cross entropy value 34 | positive_sum = torch.sum(positive_mask, dim=1) 35 | denominator = torch.sum(torch.exp(sim)*zero_diag,dim=1) 36 | loss = torch.mean(torch.log(denominator) - \ 37 | (torch.sum(sim * positive_mask, dim=1)/positive_sum)) 38 | 39 | return loss 40 | -------------------------------------------------------------------------------- /modules/scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import warnings 4 | 5 | class LinearLR(torch.optim.lr_scheduler._LRScheduler): 6 | def __init__(self, optimizer, num_epochs, last_epoch=-1): 7 | self.num_epochs = max(num_epochs, 1) 8 | super().__init__(optimizer, last_epoch) 9 | 10 | def get_lr(self): 11 | res = [] 12 | for lr in self.base_lrs: 13 | res.append(np.maximum(lr * np.minimum(-self.last_epoch * 1. 
/ self.num_epochs + 1., 1.), 0.)) 14 | return res 15 | 16 | class LinearWarmupAndConstant(torch.optim.lr_scheduler._LRScheduler): 17 | def __init__(self, optimizer, warm_up, T_max, last_epoch=-1): 18 | self.warm_up = int(warm_up * T_max) 19 | self.T_max = T_max - self.warm_up 20 | super().__init__(optimizer, last_epoch=last_epoch) 21 | 22 | def get_lr(self): 23 | if not self._get_lr_called_within_step: 24 | warnings.warn("To get the last learning rate computed by the scheduler, " 25 | "please use `get_last_lr()`.") 26 | 27 | if self.last_epoch == 0: 28 | return [lr / (self.warm_up + 1) for lr in self.base_lrs] 29 | elif self.last_epoch <= self.warm_up: 30 | c = (self.last_epoch + 1) / self.last_epoch 31 | return [group['lr'] * c for group in self.optimizer.param_groups] 32 | else: 33 | # ref: https://github.com/pytorch/pytorch/blob/2de4f245c6b1e1c294a8b2a9d7f916d43380af4b/torch/optim/lr_scheduler.py#L493 34 | le = self.last_epoch - self.warm_up 35 | return [group['lr'] for group in self.optimizer.param_groups] 36 | 37 | class LinearWarmupAndCosineAnneal(torch.optim.lr_scheduler._LRScheduler): 38 | def __init__(self, optimizer, warm_up, T_max, last_epoch=-1): 39 | self.warm_up = int(warm_up * T_max) 40 | self.T_max = T_max - self.warm_up 41 | super().__init__(optimizer, last_epoch=last_epoch) 42 | 43 | def get_lr(self): 44 | if not self._get_lr_called_within_step: 45 | warnings.warn("To get the last learning rate computed by the scheduler, " 46 | "please use `get_last_lr()`.") 47 | 48 | if self.last_epoch == 0: 49 | return [lr / (self.warm_up + 1) for lr in self.base_lrs] 50 | elif self.last_epoch <= self.warm_up: 51 | c = (self.last_epoch + 1) / self.last_epoch 52 | return [group['lr'] * c for group in self.optimizer.param_groups] 53 | else: 54 | # ref: https://github.com/pytorch/pytorch/blob/2de4f245c6b1e1c294a8b2a9d7f916d43380af4b/torch/optim/lr_scheduler.py#L493 55 | le = self.last_epoch - self.warm_up 56 | return [(1 + np.cos(np.pi * le / self.T_max)) / 57 | (1 + np.cos(np.pi * (le - 1) / self.T_max)) * 58 | group['lr'] 59 | for group in self.optimizer.param_groups] 60 | 61 | 62 | class BaseLR(torch.optim.lr_scheduler._LRScheduler): 63 | def get_lr(self): 64 | return [group['lr'] for group in self.optimizer.param_groups] 65 | -------------------------------------------------------------------------------- /modules/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .batchnorm import patch_sync_batchnorm, convert_model 13 | from .replicate import DataParallelWithCallback, patch_replication_callback 14 | -------------------------------------------------------------------------------- /modules/sync_batchnorm/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pavancm/CONVIQT/4204f04c1e188c616e8a812a3aa916b1e573b87f/modules/sync_batchnorm/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /modules/sync_batchnorm/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pavancm/CONVIQT/4204f04c1e188c616e8a812a3aa916b1e573b87f/modules/sync_batchnorm/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /modules/sync_batchnorm/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pavancm/CONVIQT/4204f04c1e188c616e8a812a3aa916b1e573b87f/modules/sync_batchnorm/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /modules/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pavancm/CONVIQT/4204f04c1e188c616e8a812a3aa916b1e573b87f/modules/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc -------------------------------------------------------------------------------- /modules/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pavancm/CONVIQT/4204f04c1e188c616e8a812a3aa916b1e573b87f/modules/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc -------------------------------------------------------------------------------- /modules/sync_batchnorm/__pycache__/batchnorm.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pavancm/CONVIQT/4204f04c1e188c616e8a812a3aa916b1e573b87f/modules/sync_batchnorm/__pycache__/batchnorm.cpython-39.pyc -------------------------------------------------------------------------------- /modules/sync_batchnorm/__pycache__/comm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pavancm/CONVIQT/4204f04c1e188c616e8a812a3aa916b1e573b87f/modules/sync_batchnorm/__pycache__/comm.cpython-36.pyc -------------------------------------------------------------------------------- /modules/sync_batchnorm/__pycache__/comm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pavancm/CONVIQT/4204f04c1e188c616e8a812a3aa916b1e573b87f/modules/sync_batchnorm/__pycache__/comm.cpython-37.pyc -------------------------------------------------------------------------------- /modules/sync_batchnorm/__pycache__/comm.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pavancm/CONVIQT/4204f04c1e188c616e8a812a3aa916b1e573b87f/modules/sync_batchnorm/__pycache__/comm.cpython-39.pyc -------------------------------------------------------------------------------- /modules/sync_batchnorm/__pycache__/replicate.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pavancm/CONVIQT/4204f04c1e188c616e8a812a3aa916b1e573b87f/modules/sync_batchnorm/__pycache__/replicate.cpython-36.pyc -------------------------------------------------------------------------------- /modules/sync_batchnorm/__pycache__/replicate.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pavancm/CONVIQT/4204f04c1e188c616e8a812a3aa916b1e573b87f/modules/sync_batchnorm/__pycache__/replicate.cpython-37.pyc -------------------------------------------------------------------------------- /modules/sync_batchnorm/__pycache__/replicate.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pavancm/CONVIQT/4204f04c1e188c616e8a812a3aa916b1e573b87f/modules/sync_batchnorm/__pycache__/replicate.cpython-39.pyc -------------------------------------------------------------------------------- /modules/sync_batchnorm/batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import collections 12 | import contextlib 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | 17 | from torch.nn.modules.batchnorm import _BatchNorm 18 | 19 | try: 20 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 21 | except ImportError: 22 | ReduceAddCoalesced = Broadcast = None 23 | 24 | try: 25 | from jactorch.parallel.comm import SyncMaster 26 | from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback 27 | except ImportError: 28 | from .comm import SyncMaster 29 | from .replicate import DataParallelWithCallback 30 | 31 | __all__ = [ 32 | 'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d', 33 | 'patch_sync_batchnorm', 'convert_model' 34 | ] 35 | 36 | 37 | def _sum_ft(tensor): 38 | """sum over the first and last dimention""" 39 | return tensor.sum(dim=0).sum(dim=-1) 40 | 41 | 42 | def _unsqueeze_ft(tensor): 43 | """add new dimensions at the front and the tail""" 44 | return tensor.unsqueeze(0).unsqueeze(-1) 45 | 46 | 47 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) 48 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) 49 | 50 | 51 | class _SynchronizedBatchNorm(_BatchNorm): 52 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True): 53 | assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.' 
54 | 55 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, 56 | track_running_stats=track_running_stats) 57 | 58 | if not self.track_running_stats: 59 | import warnings 60 | warnings.warn('track_running_stats=False is not supported by the SynchronizedBatchNorm.') 61 | 62 | self._sync_master = SyncMaster(self._data_parallel_master) 63 | 64 | self._is_parallel = False 65 | self._parallel_id = None 66 | self._slave_pipe = None 67 | 68 | def forward(self, input): 69 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. 70 | if not (self._is_parallel and self.training): 71 | return F.batch_norm( 72 | input, self.running_mean, self.running_var, self.weight, self.bias, 73 | self.training, self.momentum, self.eps) 74 | 75 | # Resize the input to (B, C, -1). 76 | input_shape = input.size() 77 | input = input.view(input.size(0), self.num_features, -1) 78 | 79 | # Compute the sum and square-sum. 80 | sum_size = input.size(0) * input.size(2) 81 | input_sum = _sum_ft(input) 82 | input_ssum = _sum_ft(input ** 2) 83 | 84 | # Reduce-and-broadcast the statistics. 85 | if self._parallel_id == 0: 86 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) 87 | else: 88 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) 89 | 90 | # Compute the output. 91 | if self.affine: 92 | # MJY:: Fuse the multiplication for speed. 93 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) 94 | else: 95 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) 96 | 97 | # Reshape it. 98 | return output.view(input_shape) 99 | 100 | def __data_parallel_replicate__(self, ctx, copy_id): 101 | self._is_parallel = True 102 | self._parallel_id = copy_id 103 | 104 | # parallel_id == 0 means master device. 105 | if self._parallel_id == 0: 106 | ctx.sync_master = self._sync_master 107 | else: 108 | self._slave_pipe = ctx.sync_master.register_slave(copy_id) 109 | 110 | def _data_parallel_master(self, intermediates): 111 | """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" 112 | 113 | # Always using same "device order" makes the ReduceAdd operation faster. 114 | # Thanks to:: Tete Xiao (http://tetexiao.com/) 115 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) 116 | 117 | to_reduce = [i[1][:2] for i in intermediates] 118 | to_reduce = [j for i in to_reduce for j in i] # flatten 119 | target_gpus = [i[1].sum.get_device() for i in intermediates] 120 | 121 | sum_size = sum([i[1].sum_size for i in intermediates]) 122 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) 123 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) 124 | 125 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std) 126 | 127 | outputs = [] 128 | for i, rec in enumerate(intermediates): 129 | outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) 130 | 131 | return outputs 132 | 133 | def _compute_mean_std(self, sum_, ssum, size): 134 | """Compute the mean and standard-deviation with sum and square-sum. This method 135 | also maintains the moving average on the master device.""" 136 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 
137 | mean = sum_ / size 138 | sumvar = ssum - sum_ * mean 139 | unbias_var = sumvar / (size - 1) 140 | bias_var = sumvar / size 141 | 142 | if hasattr(torch, 'no_grad'): 143 | with torch.no_grad(): 144 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data 145 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data 146 | else: 147 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data 148 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data 149 | 150 | return mean, bias_var.clamp(self.eps) ** -0.5 151 | 152 | 153 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): 154 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a 155 | mini-batch. 156 | 157 | .. math:: 158 | 159 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 160 | 161 | This module differs from the built-in PyTorch BatchNorm1d as the mean and 162 | standard-deviation are reduced across all devices during training. 163 | 164 | For example, when one uses `nn.DataParallel` to wrap the network during 165 | training, PyTorch's implementation normalize the tensor on each device using 166 | the statistics only on that device, which accelerated the computation and 167 | is also easy to implement, but the statistics might be inaccurate. 168 | Instead, in this synchronized version, the statistics will be computed 169 | over all training samples distributed on multiple devices. 170 | 171 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 172 | as the built-in PyTorch implementation. 173 | 174 | The mean and standard-deviation are calculated per-dimension over 175 | the mini-batches and gamma and beta are learnable parameter vectors 176 | of size C (where C is the input size). 177 | 178 | During training, this layer keeps a running estimate of its computed mean 179 | and variance. The running sum is kept with a default momentum of 0.1. 180 | 181 | During evaluation, this running mean/variance is used for normalization. 182 | 183 | Because the BatchNorm is done over the `C` dimension, computing statistics 184 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm 185 | 186 | Args: 187 | num_features: num_features from an expected input of size 188 | `batch_size x num_features [x width]` 189 | eps: a value added to the denominator for numerical stability. 190 | Default: 1e-5 191 | momentum: the value used for the running_mean and running_var 192 | computation. Default: 0.1 193 | affine: a boolean value that when set to ``True``, gives the layer learnable 194 | affine parameters. 
Default: ``True`` 195 | 196 | Shape:: 197 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 198 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 199 | 200 | Examples: 201 | >>> # With Learnable Parameters 202 | >>> m = SynchronizedBatchNorm1d(100) 203 | >>> # Without Learnable Parameters 204 | >>> m = SynchronizedBatchNorm1d(100, affine=False) 205 | >>> input = torch.autograd.Variable(torch.randn(20, 100)) 206 | >>> output = m(input) 207 | """ 208 | 209 | def _check_input_dim(self, input): 210 | if input.dim() != 2 and input.dim() != 3: 211 | raise ValueError('expected 2D or 3D input (got {}D input)' 212 | .format(input.dim())) 213 | 214 | 215 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): 216 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch 217 | of 3d inputs 218 | 219 | .. math:: 220 | 221 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 222 | 223 | This module differs from the built-in PyTorch BatchNorm2d as the mean and 224 | standard-deviation are reduced across all devices during training. 225 | 226 | For example, when one uses `nn.DataParallel` to wrap the network during 227 | training, PyTorch's implementation normalize the tensor on each device using 228 | the statistics only on that device, which accelerated the computation and 229 | is also easy to implement, but the statistics might be inaccurate. 230 | Instead, in this synchronized version, the statistics will be computed 231 | over all training samples distributed on multiple devices. 232 | 233 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 234 | as the built-in PyTorch implementation. 235 | 236 | The mean and standard-deviation are calculated per-dimension over 237 | the mini-batches and gamma and beta are learnable parameter vectors 238 | of size C (where C is the input size). 239 | 240 | During training, this layer keeps a running estimate of its computed mean 241 | and variance. The running sum is kept with a default momentum of 0.1. 242 | 243 | During evaluation, this running mean/variance is used for normalization. 244 | 245 | Because the BatchNorm is done over the `C` dimension, computing statistics 246 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm 247 | 248 | Args: 249 | num_features: num_features from an expected input of 250 | size batch_size x num_features x height x width 251 | eps: a value added to the denominator for numerical stability. 252 | Default: 1e-5 253 | momentum: the value used for the running_mean and running_var 254 | computation. Default: 0.1 255 | affine: a boolean value that when set to ``True``, gives the layer learnable 256 | affine parameters. Default: ``True`` 257 | 258 | Shape:: 259 | - Input: :math:`(N, C, H, W)` 260 | - Output: :math:`(N, C, H, W)` (same shape as input) 261 | 262 | Examples: 263 | >>> # With Learnable Parameters 264 | >>> m = SynchronizedBatchNorm2d(100) 265 | >>> # Without Learnable Parameters 266 | >>> m = SynchronizedBatchNorm2d(100, affine=False) 267 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) 268 | >>> output = m(input) 269 | """ 270 | 271 | def _check_input_dim(self, input): 272 | if input.dim() != 4: 273 | raise ValueError('expected 4D input (got {}D input)' 274 | .format(input.dim())) 275 | 276 | 277 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): 278 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch 279 | of 4d inputs 280 | 281 | .. 
math:: 282 | 283 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 284 | 285 | This module differs from the built-in PyTorch BatchNorm3d as the mean and 286 | standard-deviation are reduced across all devices during training. 287 | 288 | For example, when one uses `nn.DataParallel` to wrap the network during 289 | training, PyTorch's implementation normalize the tensor on each device using 290 | the statistics only on that device, which accelerated the computation and 291 | is also easy to implement, but the statistics might be inaccurate. 292 | Instead, in this synchronized version, the statistics will be computed 293 | over all training samples distributed on multiple devices. 294 | 295 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 296 | as the built-in PyTorch implementation. 297 | 298 | The mean and standard-deviation are calculated per-dimension over 299 | the mini-batches and gamma and beta are learnable parameter vectors 300 | of size C (where C is the input size). 301 | 302 | During training, this layer keeps a running estimate of its computed mean 303 | and variance. The running sum is kept with a default momentum of 0.1. 304 | 305 | During evaluation, this running mean/variance is used for normalization. 306 | 307 | Because the BatchNorm is done over the `C` dimension, computing statistics 308 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 309 | or Spatio-temporal BatchNorm 310 | 311 | Args: 312 | num_features: num_features from an expected input of 313 | size batch_size x num_features x depth x height x width 314 | eps: a value added to the denominator for numerical stability. 315 | Default: 1e-5 316 | momentum: the value used for the running_mean and running_var 317 | computation. Default: 0.1 318 | affine: a boolean value that when set to ``True``, gives the layer learnable 319 | affine parameters. 
Default: ``True`` 320 | 321 | Shape:: 322 | - Input: :math:`(N, C, D, H, W)` 323 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 324 | 325 | Examples: 326 | >>> # With Learnable Parameters 327 | >>> m = SynchronizedBatchNorm3d(100) 328 | >>> # Without Learnable Parameters 329 | >>> m = SynchronizedBatchNorm3d(100, affine=False) 330 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) 331 | >>> output = m(input) 332 | """ 333 | 334 | def _check_input_dim(self, input): 335 | if input.dim() != 5: 336 | raise ValueError('expected 5D input (got {}D input)' 337 | .format(input.dim())) 338 | 339 | 340 | @contextlib.contextmanager 341 | def patch_sync_batchnorm(): 342 | import torch.nn as nn 343 | 344 | backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d 345 | 346 | nn.BatchNorm1d = SynchronizedBatchNorm1d 347 | nn.BatchNorm2d = SynchronizedBatchNorm2d 348 | nn.BatchNorm3d = SynchronizedBatchNorm3d 349 | 350 | yield 351 | 352 | nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup 353 | 354 | 355 | def convert_model(module): 356 | """Traverse the input module and its children recursively 357 | and replace all instances of torch.nn.modules.batchnorm.BatchNorm*N*d 358 | with SynchronizedBatchNorm*N*d 359 | 360 | Args: 361 | module: the input module to be converted to a SyncBN model 362 | 363 | Examples: 364 | >>> import torch.nn as nn 365 | >>> import torchvision 366 | >>> # m is a standard pytorch model 367 | >>> m = torchvision.models.resnet18(True) 368 | >>> m = nn.DataParallel(m) 369 | >>> # after convert, m is using SyncBN 370 | >>> m = convert_model(m) 371 | """ 372 | if isinstance(module, torch.nn.DataParallel): 373 | mod = module.module 374 | mod = convert_model(mod) 375 | mod = DataParallelWithCallback(mod, device_ids=module.device_ids) 376 | return mod 377 | 378 | mod = module 379 | for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d, 380 | torch.nn.modules.batchnorm.BatchNorm2d, 381 | torch.nn.modules.batchnorm.BatchNorm3d], 382 | [SynchronizedBatchNorm1d, 383 | SynchronizedBatchNorm2d, 384 | SynchronizedBatchNorm3d]): 385 | if isinstance(module, pth_module): 386 | mod = sync_module(module.num_features, module.eps, module.momentum, module.affine) 387 | mod.running_mean = module.running_mean 388 | mod.running_var = module.running_var 389 | if module.affine: 390 | mod.weight.data = module.weight.data.clone().detach() 391 | mod.bias.data = module.bias.data.clone().detach() 392 | 393 | for name, child in module.named_children(): 394 | mod.add_module(name, convert_model(child)) 395 | 396 | return mod 397 | -------------------------------------------------------------------------------- /modules/sync_batchnorm/batchnorm_reimpl.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : batchnorm_reimpl.py 4 | # Author : acgtyrant 5 | # Date : 11/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | 15 | __all__ = ['BatchNorm2dReimpl'] 16 | 17 | 18 | class BatchNorm2dReimpl(nn.Module): 19 | """ 20 | A re-implementation of batch normalization, used for testing the numerical 21 | stability.
22 | 23 | Author: acgtyrant 24 | See also: 25 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 26 | """ 27 | def __init__(self, num_features, eps=1e-5, momentum=0.1): 28 | super().__init__() 29 | 30 | self.num_features = num_features 31 | self.eps = eps 32 | self.momentum = momentum 33 | self.weight = nn.Parameter(torch.empty(num_features)) 34 | self.bias = nn.Parameter(torch.empty(num_features)) 35 | self.register_buffer('running_mean', torch.zeros(num_features)) 36 | self.register_buffer('running_var', torch.ones(num_features)) 37 | self.reset_parameters() 38 | 39 | def reset_running_stats(self): 40 | self.running_mean.zero_() 41 | self.running_var.fill_(1) 42 | 43 | def reset_parameters(self): 44 | self.reset_running_stats() 45 | init.uniform_(self.weight) 46 | init.zeros_(self.bias) 47 | 48 | def forward(self, input_): 49 | batchsize, channels, height, width = input_.size() 50 | numel = batchsize * height * width 51 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) 52 | sum_ = input_.sum(1) 53 | sum_of_square = input_.pow(2).sum(1) 54 | mean = sum_ / numel 55 | sumvar = sum_of_square - sum_ * mean 56 | 57 | self.running_mean = ( 58 | (1 - self.momentum) * self.running_mean 59 | + self.momentum * mean.detach() 60 | ) 61 | unbias_var = sumvar / (numel - 1) 62 | self.running_var = ( 63 | (1 - self.momentum) * self.running_var 64 | + self.momentum * unbias_var.detach() 65 | ) 66 | 67 | bias_var = sumvar / numel 68 | inv_std = 1 / (bias_var + self.eps).pow(0.5) 69 | output = ( 70 | (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * 71 | self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) 72 | 73 | return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() 74 | 75 | -------------------------------------------------------------------------------- /modules/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 
58 | 59 | - During the replication, as the data parallel will trigger a callback of each module, all slave devices should 60 | call `register(id)` and obtain a `SlavePipe` to communicate with the master. 61 | - During the forward pass, the master device invokes `run_master`; all messages from the slave devices will be collected 62 | and passed to a registered callback. 63 | - After receiving the messages, the master device should gather the information and determine the message to be passed 64 | back to each slave device. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def __getstate__(self): 79 | return {'master_callback': self._master_callback} 80 | 81 | def __setstate__(self, state): 82 | self.__init__(state['master_callback']) 83 | 84 | def register_slave(self, identifier): 85 | """ 86 | Register a slave device. 87 | 88 | Args: 89 | identifier: an identifier, usually the device id. 90 | 91 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 92 | 93 | """ 94 | if self._activated: 95 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 96 | self._activated = False 97 | self._registry.clear() 98 | future = FutureResult() 99 | self._registry[identifier] = _MasterRegistry(future) 100 | return SlavePipe(identifier, self._queue, future) 101 | 102 | def run_master(self, master_msg): 103 | """ 104 | Main entry for the master device in each forward pass. 105 | The messages are first collected from each device (including the master device), and then 106 | a callback will be invoked to compute the message to be sent back to each device 107 | (including the master device). 108 | 109 | Args: 110 | master_msg: the message that the master wants to send to itself. This will be placed as the first 111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 112 | 113 | Returns: the message to be sent back to the master device. 114 | 115 | """ 116 | self._activated = True 117 | 118 | intermediates = [(0, master_msg)] 119 | for i in range(self.nr_slaves): 120 | intermediates.append(self._queue.get()) 121 | 122 | results = self._master_callback(intermediates) 123 | assert results[0][0] == 0, 'The first result should belong to the master.' 124 | 125 | for i, res in results: 126 | if i == 0: 127 | continue 128 | self._registry[i].result.put(res) 129 | 130 | for i in range(self.nr_slaves): 131 | assert self._queue.get() is True 132 | 133 | return results[0][1] 134 | 135 | @property 136 | def nr_slaves(self): 137 | return len(self._registry) 138 | -------------------------------------------------------------------------------- /modules/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License.
10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all module copies are isomorphic, we assign each sub-module a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of the callbacks 38 | of any slave copies. 39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | the original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have a customized `DataParallel` implementation. 74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /modules/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | import torch 13 | 14 | 15 | class TorchTestCase(unittest.TestCase): 16 | def assertTensorClose(self, x, y): 17 | adiff = float((x - y).abs().max()) 18 | if (y == 0).all(): 19 | rdiff = 'NaN' 20 | else: 21 | rdiff = float((adiff / y).abs().max()) 22 | 23 | message = ( 24 | 'Tensor close check failed\n' 25 | 'adiff={}\n' 26 | 'rdiff={}\n' 27 | ).format(adiff, rdiff) 28 | self.assertTrue(torch.allclose(x, y), message) 29 | 30 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.7.1+cu101 2 | torchvision==0.8.2+cu101 3 | numpy==1.19.2 4 | Pillow==8.3.1 5 | scikit-video==1.1.11 6 | scikit-learn==1.0.2 7 | pandas==1.3.5 8 | scipy==1.6.2 9 | -------------------------------------------------------------------------------- /sample_videos/30.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pavancm/CONVIQT/4204f04c1e188c616e8a812a3aa916b1e573b87f/sample_videos/30.mp4 -------------------------------------------------------------------------------- /sample_videos/Flips_crf_48_30fps.webm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pavancm/CONVIQT/4204f04c1e188c616e8a812a3aa916b1e573b87f/sample_videos/Flips_crf_48_30fps.webm -------------------------------------------------------------------------------- /spatial_feature_extract_syn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from modules.network import get_network 3 | from modules.CONTRIQUE_model import CONTRIQUE_model 4 | from torchvision import transforms 5 | import numpy as np 6 | 7 | import os 8 | import argparse 9 | from PIL import Image 10 | import skvideo.io 11 | import pandas as pd 12 | import scipy.io 13 | 14 | class torch_transform: 15 | def __init__(self, size): 16 | self.transform1 = transforms.Compose( 17 | [ 18 | transforms.Resize((size[0],size[1])), 19 | transforms.ToTensor(), 20 | ] 21 | ) 22 | 23 | self.transform2 = transforms.Compose( 24 | [ 25 | transforms.Resize((size[0] // 2, size[1] // 2)), 26 | transforms.ToTensor(), 27 | ] 28 | ) 29 | 30 | def __call__(self, x): 31 | return self.transform1(x), self.transform2(x) 32 | 33 | def rescale_video(vid, min_val, max_val): 34 | vid = (vid - min_val)/(max_val - min_val + 1e-3) 35 | return vid 36 | 37 | def temporal_filt(vid, filt_choice): 38 | # choose temporal filter (haar, db2, bior22) 39 | if filt_choice == 3: 40 | return vid 41 | 42 | filters = ['haar','db2','bior22'] 43 | num_levels = 3 44 | filt_choice = np.random.choice([0,1,2],1)[0] 45 | 46 | #load filter 47 | filt_path = 'WPT_Filters/' + filters[filt_choice] + '_wpt_' \ 48 | + str(num_levels) + '.mat' 49 | wfun = scipy.io.loadmat(filt_path) 50 | wfun = wfun['wfun'] 51 | 52 | #choose subband 53 | subband_choice = np.random.choice(np.arange(len(wfun)),1)[0] 54 | 55 | #Temporal Filtering 56 | frame_data = vid.numpy() 57 | dpt_filt = np.zeros_like(frame_data) 58 | 59 | for ch in range(3): 60 | inp = frame_data[:,ch,:,:].astype(np.float32) 61 | out = scipy.ndimage.filters.convolve1d(inp,\ 62 | wfun[subband_choice,:],axis=0,mode='constant') 63 | dpt_filt[:,ch,:,:] = out.astype(np.float16) 64 | 65 | min_val, max_val = np.min(dpt_filt), np.max(dpt_filt) 66 | dpt_filt = 
torch.from_numpy(dpt_filt) 67 | dpt_filt = rescale_video(dpt_filt, min_val, max_val) 68 | return dpt_filt 69 | 70 | def create_data_loader(image, image_2, batch_size): 71 | train = torch.utils.data.TensorDataset(image, image_2) 72 | loader = torch.utils.data.DataLoader( 73 | train, 74 | batch_size=batch_size, 75 | drop_last=True, 76 | num_workers=4, 77 | sampler=None, 78 | shuffle=False 79 | ) 80 | return loader 81 | 82 | def extract_features(model, loader): 83 | feat = [] 84 | 85 | model.eval() 86 | for step, (batch_im, batch_im_2) in enumerate(loader): 87 | batch_im = batch_im.type(torch.float32) 88 | batch_im_2 = batch_im_2.type(torch.float32) 89 | 90 | batch_im = batch_im.cuda(non_blocking=True) 91 | batch_im_2 = batch_im_2.cuda(non_blocking=True) 92 | with torch.no_grad(): 93 | _,_, _, _, model_feat, model_feat_2, _, _ = model(batch_im, batch_im_2) 94 | 95 | feat_ = np.hstack((model_feat.detach().cpu().numpy(),\ 96 | model_feat_2.detach().cpu().numpy())) 97 | feat.extend(feat_) 98 | return np.array(feat) 99 | 100 | def main(args): 101 | # load CONTRIQUE Model 102 | encoder = get_network('resnet50', pretrained=False) 103 | model = CONTRIQUE_model(args, encoder, 2048) 104 | model.load_state_dict(torch.load(args.model_path, map_location=args.device.type)) 105 | model = model.to(args.device) 106 | 107 | write_fdr = args.write_fdr 108 | fls = pd.read_csv(args.csv_file) 109 | fls = fls.loc[:,'File_names'].tolist() 110 | 111 | for step,vid_path in enumerate(fls): 112 | # extract features for all temporal transforms 113 | for temporal in range(4): 114 | name = vid_path.split('/')[-1] 115 | sz = int(name.split('sz')[-1].split('_')[0]) 116 | 117 | content = vid_path.split('/')[-2] 118 | 119 | if not os.path.isdir(write_fdr + content): 120 | os.system('mkdir -p ' + write_fdr + content) 121 | os.system('chmod -R 777 ' + write_fdr + content) 122 | 123 | write_name = write_fdr + content + '/' + \ 124 | name.split('.')[0] + '_temp' + str(temporal) + '.npy' 125 | 126 | print(write_name) 127 | 128 | if os.path.exists(write_name): 129 | continue 130 | 131 | path = 'training_data/' + vid_path 132 | vid_raw = skvideo.io.FFmpegReader(path) 133 | T, H, W, C = vid_raw.getShape() 134 | 135 | dist_frames = torch.zeros((T,3,H*sz,W*sz)) 136 | dist_frames = dist_frames.type(torch.float16) 137 | 138 | dist_frames_2 = torch.zeros((T,3,H*sz//2,W*sz//2)) 139 | dist_frames_2 = dist_frames_2.type(torch.float16) 140 | 141 | transform = torch_transform((H*sz,W*sz)) 142 | for frame_ind in range(T): 143 | frame = Image.fromarray(next(vid_raw)) 144 | # resize to source spatial resolution 145 | frame = frame.resize((W*sz, H*sz), Image.LANCZOS) 146 | frame, frame_2 = transform(frame) 147 | dist_frames[frame_ind],dist_frames_2[frame_ind] = \ 148 | frame.type(torch.float16), frame_2.type(torch.float16) 149 | 150 | # temporal transforms 151 | dist_frames = temporal_filt(dist_frames, temporal) 152 | dist_frames_2 = temporal_filt(dist_frames_2, temporal) 153 | 154 | loader = create_data_loader(dist_frames, dist_frames_2, args.batch_size) 155 | video_feat = extract_features(model, loader) 156 | 157 | video_feat = video_feat.astype(np.float16) 158 | np.save(write_name, video_feat) 159 | vid_raw.close() 160 | 161 | def parse_args(): 162 | parser = argparse.ArgumentParser() 163 | 164 | parser.add_argument('--model_path', type=str, \ 165 | default='models/CONTRIQUE_checkpoint25.tar', \ 166 | help='Path to trained CONTRIQUE model', metavar='') 167 | parser.add_argument('--batch_size', type=int, \ 168 | default=16, \ 169 | help='batch 
size', metavar='') 170 | parser.add_argument('--csv_file', type=str, \ 171 | default='csv_files/file_names_syn.csv', \ 172 | help='path for csv file with filenames', metavar='') 173 | parser.add_argument('--write_fdr', type=str, \ 174 | default='training_data/dist_feat/', \ 175 | help='write folder', metavar='') 176 | args = parser.parse_args() 177 | args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 178 | return args 179 | 180 | if __name__ == '__main__': 181 | args = parse_args() 182 | main(args) -------------------------------------------------------------------------------- /spatial_feature_extract_ugc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from modules.network import get_network 3 | from modules.CONTRIQUE_model import CONTRIQUE_model 4 | from torchvision import transforms 5 | import numpy as np 6 | 7 | import os 8 | import argparse 9 | from PIL import Image 10 | import skvideo.io 11 | import pandas as pd 12 | import scipy.io 13 | 14 | class torch_transform: 15 | def __init__(self, size): 16 | self.transform1 = transforms.Compose( 17 | [ 18 | transforms.Resize((size[0],size[1])), 19 | transforms.ToTensor(), 20 | ] 21 | ) 22 | 23 | self.transform2 = transforms.Compose( 24 | [ 25 | transforms.Resize((size[0] // 2, size[1] // 2)), 26 | transforms.ToTensor(), 27 | ] 28 | ) 29 | 30 | def __call__(self, x): 31 | return self.transform1(x), self.transform2(x) 32 | 33 | def rescale_video(vid, min_val, max_val): 34 | vid = (vid - min_val)/(max_val - min_val + 1e-3) 35 | return vid 36 | 37 | def temporal_filt(vid, filt_choice): 38 | # choose temporal filter (haar, db2, bior22) 39 | if filt_choice == 3: 40 | return vid 41 | 42 | filters = ['haar','db2','bior22'] 43 | num_levels = 3 44 | filt_choice = np.random.choice([0,1,2],1)[0] 45 | 46 | #load filter 47 | filt_path = 'WPT_Filters/' + filters[filt_choice] + '_wpt_' \ 48 | + str(num_levels) + '.mat' 49 | wfun = scipy.io.loadmat(filt_path) 50 | wfun = wfun['wfun'] 51 | 52 | #choose subband 53 | subband_choice = np.random.choice(np.arange(len(wfun)),1)[0] 54 | 55 | #Temporal Filtering 56 | frame_data = vid.numpy() 57 | dpt_filt = np.zeros_like(frame_data) 58 | 59 | for ch in range(3): 60 | inp = frame_data[:,ch,:,:].astype(np.float32) 61 | out = scipy.ndimage.filters.convolve1d(inp,\ 62 | wfun[subband_choice,:],axis=0,mode='constant') 63 | dpt_filt[:,ch,:,:] = out.astype(np.float16) 64 | 65 | min_val, max_val = np.min(dpt_filt), np.max(dpt_filt) 66 | dpt_filt = torch.from_numpy(dpt_filt) 67 | dpt_filt = rescale_video(dpt_filt, min_val, max_val) 68 | return dpt_filt 69 | 70 | def create_data_loader(image, image_2, batch_size): 71 | train = torch.utils.data.TensorDataset(image, image_2) 72 | loader = torch.utils.data.DataLoader( 73 | train, 74 | batch_size=batch_size, 75 | drop_last=True, 76 | num_workers=4, 77 | sampler=None, 78 | shuffle=False 79 | ) 80 | return loader 81 | 82 | def extract_features(model, loader): 83 | feat = [] 84 | 85 | model.eval() 86 | for step, (batch_im, batch_im_2) in enumerate(loader): 87 | batch_im = batch_im.type(torch.float32) 88 | batch_im_2 = batch_im_2.type(torch.float32) 89 | 90 | batch_im = batch_im.cuda(non_blocking=True) 91 | batch_im_2 = batch_im_2.cuda(non_blocking=True) 92 | with torch.no_grad(): 93 | _,_, _, _, model_feat, model_feat_2, _, _ = model(batch_im, batch_im_2) 94 | 95 | feat_ = np.hstack((model_feat.detach().cpu().numpy(),\ 96 | model_feat_2.detach().cpu().numpy())) 97 | feat.extend(feat_) 98 | return 
np.array(feat) 99 | 100 | def main(args): 101 | # load CONTRIQUE Model 102 | encoder = get_network('resnet50', pretrained=False) 103 | model = CONTRIQUE_model(args, encoder, 2048) 104 | model.load_state_dict(torch.load(args.model_path, map_location=args.device.type)) 105 | model = model.to(args.device) 106 | 107 | write_fdr = args.write_fdr 108 | fls = pd.read_csv(args.csv_file) 109 | fls = fls.loc[:,'File_names'].tolist() 110 | 111 | for step,vid_path in enumerate(fls): 112 | # extract features for all temporal transforms 113 | for temporal in range(4): 114 | name = vid_path.split('/')[-1] 115 | 116 | content = vid_path.split('/')[-2] 117 | 118 | if not os.path.isdir(write_fdr + content): 119 | os.system('mkdir -p ' + write_fdr + content) 120 | os.system('chmod -R 777 ' + write_fdr + content) 121 | 122 | write_name = write_fdr + content + '/' + \ 123 | name.split('.')[0] + '_temp' + str(temporal) + '.npy' 124 | 125 | print(write_name) 126 | 127 | if os.path.exists(write_name): 128 | continue 129 | 130 | path = 'training_data/' + vid_path 131 | vid_raw = skvideo.io.FFmpegReader(path) 132 | T, H, W, C = vid_raw.getShape() 133 | 134 | dist_frames = torch.zeros((T,3,H,W)) 135 | dist_frames = dist_frames.type(torch.float16) 136 | 137 | dist_frames_2 = torch.zeros((T,3,H//2,W//2)) 138 | dist_frames_2 = dist_frames_2.type(torch.float16) 139 | 140 | transform = torch_transform((H,W)) 141 | for frame_ind in range(T): 142 | frame = Image.fromarray(next(vid_raw)) 143 | # resize to source spatial resolution 144 | frame = frame.resize((W, H), Image.LANCZOS) 145 | frame, frame_2 = transform(frame) 146 | dist_frames[frame_ind],dist_frames_2[frame_ind] = \ 147 | frame.type(torch.float16), frame_2.type(torch.float16) 148 | 149 | # temporal transforms 150 | dist_frames = temporal_filt(dist_frames, temporal) 151 | dist_frames_2 = temporal_filt(dist_frames_2, temporal) 152 | 153 | loader = create_data_loader(dist_frames, dist_frames_2, args.batch_size) 154 | video_feat = extract_features(model, loader) 155 | 156 | video_feat = video_feat.astype(np.float16) 157 | np.save(write_name, video_feat) 158 | vid_raw.close() 159 | 160 | def parse_args(): 161 | parser = argparse.ArgumentParser() 162 | 163 | parser.add_argument('--model_path', type=str, \ 164 | default='models/CONTRIQUE_checkpoint25.tar', \ 165 | help='Path to trained CONTRIQUE model', metavar='') 166 | parser.add_argument('--batch_size', type=int, \ 167 | default=16, \ 168 | help='batch size', metavar='') 169 | parser.add_argument('--csv_file', type=str, \ 170 | default='csv_files/file_names_ugc.csv', \ 171 | help='path for csv file with filenames', metavar='') 172 | parser.add_argument('--write_fdr', type=str, \ 173 | default='training_data/kinetics_feat/', \ 174 | help='write folder', metavar='') 175 | args = parser.parse_args() 176 | args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 177 | return args 178 | 179 | if __name__ == '__main__': 180 | args = parse_args() 181 | main(args) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import argparse 4 | import os 5 | 6 | # distributed training 7 | import torch.distributed as dist 8 | import torch.multiprocessing as mp 9 | from torch.nn.parallel import DataParallel 10 | from torch.nn.parallel import DistributedDataParallel as DDP 11 | 12 | from modules.dataset_loader import video_data_feat 13 | 14 | from 
modules.GRUModel import GRUModel 15 | from modules.nt_xent_multiclass import NT_Xent 16 | from modules.configure_optimizers import configure_optimizers 17 | 18 | from model_io import save_model 19 | from modules.sync_batchnorm import convert_model 20 | import time 21 | import datetime 22 | from PIL import ImageFile 23 | 24 | ImageFile.LOAD_TRUNCATED_IMAGES = True 25 | torch.multiprocessing.set_sharing_strategy('file_system') 26 | 27 | def train(args, train_loader_syn, train_loader_ugc, \ 28 | model, criterion, optimizer, scaler, scheduler = None): 29 | loss_epoch = 0 30 | model.train() 31 | 32 | for step,((syn_i1, syn_i2, dist_label_syn),(ugc_i1, ugc_i2, _)) in \ 33 | enumerate(zip(train_loader_syn, train_loader_ugc)): 34 | 35 | #video 1 36 | syn_i1 = syn_i1.cuda(non_blocking=True) 37 | ugc_i1 = ugc_i1.cuda(non_blocking=True) 38 | x_i1 = torch.cat((syn_i1,ugc_i1),dim=0) 39 | 40 | #video 2 41 | syn_i2 = syn_i2.cuda(non_blocking=True) 42 | ugc_i2 = ugc_i2.cuda(non_blocking=True) 43 | x_i2 = torch.cat((syn_i2,ugc_i2),dim=0) 44 | 45 | x_i1 = x_i1.squeeze(1) 46 | x_i2 = x_i2.squeeze(1) 47 | 48 | # distortion classes 49 | # synthetic distortion classes 50 | dist_label = torch.zeros((2*args.batch_size, \ 51 | args.clusters+(args.batch_size*args.nodes))) 52 | dist_label[:args.batch_size,:args.clusters] = dist_label_syn.clone() 53 | 54 | # UGC data - each video is unique class 55 | dist_label[args.batch_size:,args.clusters + (args.nr*args.batch_size) : \ 56 | args.clusters + ((args.nr+1)*args.batch_size)] = \ 57 | torch.eye(args.batch_size) 58 | 59 | dist_label = dist_label.cuda(non_blocking=True) 60 | 61 | with torch.cuda.amp.autocast(enabled=True): 62 | z_i1, z_i2, h_i1, h_i2 = model(x_i1,x_i2) 63 | loss = criterion(z_i1, z_i2, dist_label) 64 | 65 | # update model weights 66 | optimizer.zero_grad() 67 | scaler.scale(loss).backward() 68 | scaler.step(optimizer) 69 | scaler.update() 70 | 71 | if scheduler: 72 | scheduler.step() 73 | 74 | if dist.is_available() and dist.is_initialized(): 75 | loss = loss.data.clone() 76 | dist.all_reduce(loss.div_(dist.get_world_size())) 77 | 78 | if args.nr == 0 and step % 5 == 0: 79 | lr = optimizer.param_groups[0]["lr"] 80 | print(f"Step [{step}/{args.steps}]\t Loss: {loss.item()}\t LR: {round(lr, 5)}") 81 | 82 | if args.nr == 0: 83 | args.global_step += 1 84 | 85 | loss_epoch += loss.item() 86 | 87 | return loss_epoch 88 | 89 | def main(gpu, args): 90 | rank = args.nr * args.gpus + gpu 91 | 92 | if args.nodes > 1: 93 | cur_dir = 'file://' + os.getcwd() + '/sharedfile' 94 | dist.init_process_group("nccl", init_method=cur_dir,\ 95 | rank=rank, timeout = datetime.timedelta(seconds=3600),\ 96 | world_size=args.world_size) 97 | torch.cuda.set_device(gpu) 98 | 99 | torch.manual_seed(args.seed) 100 | np.random.seed(args.seed) 101 | 102 | # loader for synthetic distortions data 103 | train_dataset_syn = video_data_feat(file_path=args.csv_file_syn,\ 104 | temporal_len = args.num_frames) 105 | 106 | if args.nodes > 1: 107 | train_sampler_syn = torch.utils.data.distributed.DistributedSampler( 108 | train_dataset_syn, num_replicas=args.world_size, rank=rank, shuffle=True 109 | ) 110 | else: 111 | train_sampler_syn = None 112 | 113 | train_loader_syn = torch.utils.data.DataLoader( 114 | train_dataset_syn, 115 | batch_size=args.batch_size, 116 | shuffle=(train_sampler_syn is None), 117 | drop_last=True, 118 | num_workers=args.workers, 119 | sampler=train_sampler_syn, 120 | ) 121 | 122 | # loader for authetically distorted data 123 | train_dataset_ugc = 
video_data_feat(file_path=args.csv_file_ugc,\ 124 | temporal_len = args.num_frames) 125 | 126 | if args.nodes > 1: 127 | train_sampler_ugc = torch.utils.data.distributed.DistributedSampler( 128 | train_dataset_ugc, num_replicas=args.world_size, rank=rank, shuffle=True 129 | ) 130 | else: 131 | train_sampler_ugc = None 132 | 133 | train_loader_ugc = torch.utils.data.DataLoader( 134 | train_dataset_ugc, 135 | batch_size=args.batch_size, 136 | shuffle=(train_sampler_ugc is None), 137 | drop_last=True, 138 | num_workers=args.workers, 139 | sampler=train_sampler_ugc, 140 | ) 141 | 142 | # initialize GRU-based temporal model (operates on 2048-dim CONTRIQUE features) 143 | args.n_features = 2048 144 | temporal_model = GRUModel(args.n_features, args.n_features // 2, \ 145 | args.projection_dim, args.normalize,\ 146 | num_layers = 1) 147 | # optionally reload model weights from a saved checkpoint 148 | if args.reload: 149 | model_fp = os.path.join( 150 | args.model_path, "checkpoint_{}.tar".format(args.epoch_num) 151 | ) 152 | temporal_model.load_state_dict(torch.load(model_fp, map_location=args.device.type)) 153 | temporal_model = temporal_model.to(args.device) 154 | 155 | # SGD optimizer and learning rate schedule 156 | args.steps = min(len(train_loader_syn),len(train_loader_ugc)) 157 | args.lr_schedule = 'warmup-anneal' 158 | args.warmup = 0.1 159 | args.weight_decay = 1e-4 160 | args.iters = args.steps*args.end_num 161 | optimizer, scheduler = configure_optimizers(args, temporal_model, cur_iter=-1) 162 | 163 | criterion = NT_Xent(args.batch_size, args.temperature, args.device, args.world_size) 164 | 165 | # DDP / DP 166 | if args.dataparallel: 167 | temporal_model = convert_model(temporal_model) 168 | temporal_model = DataParallel(temporal_model) 169 | 170 | else: 171 | if args.nodes > 1: 172 | temporal_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(temporal_model) 173 | temporal_model = DDP(temporal_model, device_ids=[gpu]);print(rank);dist.barrier() 174 | 175 | temporal_model = temporal_model.to(args.device) 176 | 177 | scaler = torch.cuda.amp.GradScaler(enabled=True) 178 | 179 | # writer = None 180 | if args.nr == 0: 181 | print('Training Started') 182 | 183 | if not os.path.isdir(args.model_path): 184 | os.mkdir(args.model_path) 185 | 186 | epoch_losses = [] 187 | args.global_step = 0 188 | args.current_epoch = args.start_epoch 189 | for epoch in range(args.start_epoch, args.epochs): 190 | start = time.time() 191 | 192 | loss_epoch = train(args, train_loader_syn, train_loader_ugc, \ 193 | temporal_model, criterion, optimizer, scaler, scheduler) 194 | 195 | end = time.time() 196 | print(np.round(end - start,4)) 197 | 198 | if args.nr == 0 and epoch % 1 == 0: 199 | save_model(args, temporal_model, optimizer) 200 | torch.save({'optimizer' : optimizer.state_dict(), 201 | 'scheduler' : scheduler.state_dict()},\ 202 | args.model_path + 'optimizer.tar') 203 | 204 | if args.nr == 0: 205 | print( 206 | f"Epoch [{epoch}/{args.epochs}]\t Loss: {loss_epoch / args.steps}" 207 | ) 208 | args.current_epoch += 1 209 | epoch_losses.append(loss_epoch / args.steps) 210 | np.save(args.model_path + 'losses.npy',epoch_losses) 211 | 212 | ## end training 213 | save_model(args, temporal_model, optimizer) 214 | 215 | def parse_args(): 216 | parser = argparse.ArgumentParser(description="CONVIQT") 217 | parser.add_argument('--nodes', type=int, default = 1, help = 'number of nodes', metavar='') 218 | parser.add_argument('--nr', type=int, default = 0, help = 'rank', metavar='') 219 | parser.add_argument('--csv_file_syn', type = str, \ 220 | default = 'csv_files/file_names_syn.csv',\ 221 | help = 'list of filenames of videos with
synthetic distortions') 222 | parser.add_argument('--csv_file_ugc', type = str, \ 223 | default = 'csv_files/file_names_ugc.csv',\ 224 | help = 'list of filenames of UGC videos') 225 | parser.add_argument('--num_frames', type=tuple, default=16,\ 226 | help = 'number of frames in video') 227 | parser.add_argument('--batch_size', type=int, default = 512, \ 228 | help = 'number of videos in a batch') 229 | parser.add_argument('--workers', type = int, default = 4, \ 230 | help = 'number of workers') 231 | parser.add_argument('--opt', type = str, default = 'sgd',\ 232 | help = 'optimizer type') 233 | parser.add_argument('--lr', type = float, default = 0.6*2,\ 234 | help = 'learning rate') 235 | parser.add_argument('--model_path', type = str, default = 'checkpoints/',\ 236 | help = 'folder to save trained models') 237 | parser.add_argument('--temperature', type = float, default = 0.1,\ 238 | help = 'temperature parameter') 239 | parser.add_argument('--clusters', type = int, default = 120,\ 240 | help = 'number of synthetic distortion classes') 241 | parser.add_argument('--reload', type = bool, default = False,\ 242 | help = 'reload trained model') 243 | parser.add_argument('--normalize', type = bool, default = True,\ 244 | help = 'normalize encoder output') 245 | parser.add_argument('--projection_dim', type = int, default = 128,\ 246 | help = 'dimensions of the output feature from projector') 247 | parser.add_argument('--dataparallel', type = bool, default = False,\ 248 | help = 'use dataparallel module of PyTorch') 249 | parser.add_argument('--start_epoch', type = int, default = 0,\ 250 | help = 'starting epoch number') 251 | parser.add_argument('--end_num', type = int, default = 25,\ 252 | help = 'number to calculate learning rate decay') 253 | parser.add_argument('--epochs', type = int, default = 10,\ 254 | help = 'total number of epochs') 255 | parser.add_argument('--seed', type = int, default = 10,\ 256 | help = 'random seed') 257 | args = parser.parse_args() 258 | 259 | args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 260 | args.num_gpus = torch.cuda.device_count() 261 | args.gpus = 1 262 | args.world_size = args.gpus * args.nodes 263 | return args 264 | 265 | if __name__ == "__main__": 266 | args = parse_args() 267 | 268 | if args.nodes > 1: 269 | print( 270 | f"Training with {args.nodes} nodes, waiting until all nodes join before starting training" 271 | ) 272 | mp.spawn(main, args=(args,), nprocs=args.gpus, join=True) 273 | else: 274 | main(0, args) -------------------------------------------------------------------------------- /train_regressor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | from sklearn.linear_model import Ridge 4 | import pickle 5 | 6 | def main(args): 7 | 8 | feat = np.load(args.feat_path) 9 | scores = np.load(args.ground_truth_path) 10 | 11 | #train regression 12 | reg = Ridge(alpha=args.alpha).fit(feat, scores) 13 | pickle.dump(reg, open('lin_regressor.save','wb')) 14 | 15 | def parse_args(): 16 | parser = argparse.ArgumentParser(description="linear regressor") 17 | parser.add_argument('--feat_path', type=str, help = 'path to features file') 18 | parser.add_argument('--ground_truth_path', type=str, \ 19 | help = 'path to ground truth scores') 20 | parser.add_argument('--alpha', type = float, default = 0.1, \ 21 | help = 'regularization coefficient') 22 | args = parser.parse_args() 23 | return args 24 | 25 | if __name__ == "__main__": 26 | args = 
parse_args() 27 | main(args) --------------------------------------------------------------------------------
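Once train_regressor.py has saved a fitted Ridge model to lin_regressor.save, the regressor can be reloaded with pickle and applied to saved CONVIQT features. The snippet below is a minimal illustrative sketch, not part of the repository; the feature file name and the single-video reshape are assumptions.
```python
import pickle
import numpy as np

# Load the Ridge regressor written by train_regressor.py (default output path assumed).
with open('lin_regressor.save', 'rb') as f:
    reg = pickle.load(f)

# Load previously extracted features; 'features.npy' is a placeholder file name.
feat = np.load('features.npy').astype(np.float64)
if feat.ndim == 1:
    feat = feat.reshape(1, -1)  # scikit-learn expects (n_samples, n_features)

print(reg.predict(feat))
```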