├── Dataset ├── DFDC │ ├── audio_feat │ │ ├── fcdlkuwhgc.npy │ │ ├── ghyqfwpdsy.npy │ │ ├── htuxozplwj.npy │ │ ├── junzqfxzxj.npy │ │ ├── lchmxzjjkj.npy │ │ ├── mutoprdgcj.npy │ │ ├── rfwixigmps.npy │ │ ├── ucteqgplkr.npy │ │ ├── wvduthysvb.npy │ │ └── xdldryztjq.npy │ ├── fake │ │ ├── junzqfxzxj_2.jpg │ │ ├── rfwixigmps_6.jpg │ │ ├── ucteqgplkr_8.jpg │ │ ├── wvduthysvb_6.jpg │ │ └── xdldryztjq_8.jpg │ └── real │ │ ├── fcdlkuwhgc_5.jpg │ │ ├── ghyqfwpdsy_4.jpg │ │ ├── htuxozplwj_6.jpg │ │ ├── lchmxzjjkj_9.jpg │ │ └── mutoprdgcj_7.jpg └── Voxceleb2 │ ├── face │ └── id00015 │ │ ├── 3X9uaIs66A0 │ │ ├── 00022.jpg │ │ └── 00023.jpg │ │ ├── JF-4trZP6fE │ │ ├── 00182.jpg │ │ └── 00183.jpg │ │ └── xUtyyVLxex0 │ │ ├── 00463.jpg │ │ ├── 00464.jpg │ │ ├── 00466.jpg │ │ ├── 00467.jpg │ │ ├── 00468.jpg │ │ └── 00473.jpg │ └── voice │ ├── id00015_3X9uaIs66A0_00022.npy │ ├── id00015_3X9uaIs66A0_00023.npy │ ├── id00015_JF-4trZP6fE_00182.npy │ ├── id00015_JF-4trZP6fE_00183.npy │ ├── id00015_xUtyyVLxex0_00463.npy │ ├── id00015_xUtyyVLxex0_00464.npy │ ├── id00015_xUtyyVLxex0_00466.npy │ ├── id00015_xUtyyVLxex0_00467.npy │ ├── id00015_xUtyyVLxex0_00468.npy │ └── id00015_xUtyyVLxex0_00473.npy ├── README.md ├── data ├── DFDC_dataset.py ├── Vox_image_dataset.py ├── __init__.py ├── base_dataset.py └── image_folder.py ├── models ├── DFD_model.py ├── __init__.py ├── base_model.py └── networks.py ├── options ├── base_options.py ├── test_options.py └── train_options.py ├── test_DF.py ├── train_DF.py └── util └── util.py /Dataset/DFDC/audio_feat/fcdlkuwhgc.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/DFDC/audio_feat/fcdlkuwhgc.npy -------------------------------------------------------------------------------- /Dataset/DFDC/audio_feat/ghyqfwpdsy.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/DFDC/audio_feat/ghyqfwpdsy.npy -------------------------------------------------------------------------------- /Dataset/DFDC/audio_feat/htuxozplwj.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/DFDC/audio_feat/htuxozplwj.npy -------------------------------------------------------------------------------- /Dataset/DFDC/audio_feat/junzqfxzxj.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/DFDC/audio_feat/junzqfxzxj.npy -------------------------------------------------------------------------------- /Dataset/DFDC/audio_feat/lchmxzjjkj.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/DFDC/audio_feat/lchmxzjjkj.npy -------------------------------------------------------------------------------- /Dataset/DFDC/audio_feat/mutoprdgcj.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/DFDC/audio_feat/mutoprdgcj.npy -------------------------------------------------------------------------------- /Dataset/DFDC/audio_feat/rfwixigmps.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/DFDC/audio_feat/rfwixigmps.npy -------------------------------------------------------------------------------- /Dataset/DFDC/audio_feat/ucteqgplkr.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/DFDC/audio_feat/ucteqgplkr.npy -------------------------------------------------------------------------------- /Dataset/DFDC/audio_feat/wvduthysvb.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/DFDC/audio_feat/wvduthysvb.npy -------------------------------------------------------------------------------- /Dataset/DFDC/audio_feat/xdldryztjq.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/DFDC/audio_feat/xdldryztjq.npy -------------------------------------------------------------------------------- /Dataset/DFDC/fake/junzqfxzxj_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/DFDC/fake/junzqfxzxj_2.jpg -------------------------------------------------------------------------------- /Dataset/DFDC/fake/rfwixigmps_6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/DFDC/fake/rfwixigmps_6.jpg -------------------------------------------------------------------------------- /Dataset/DFDC/fake/ucteqgplkr_8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/DFDC/fake/ucteqgplkr_8.jpg -------------------------------------------------------------------------------- /Dataset/DFDC/fake/wvduthysvb_6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/DFDC/fake/wvduthysvb_6.jpg -------------------------------------------------------------------------------- /Dataset/DFDC/fake/xdldryztjq_8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/DFDC/fake/xdldryztjq_8.jpg -------------------------------------------------------------------------------- /Dataset/DFDC/real/fcdlkuwhgc_5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/DFDC/real/fcdlkuwhgc_5.jpg -------------------------------------------------------------------------------- /Dataset/DFDC/real/ghyqfwpdsy_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/DFDC/real/ghyqfwpdsy_4.jpg -------------------------------------------------------------------------------- /Dataset/DFDC/real/htuxozplwj_6.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/DFDC/real/htuxozplwj_6.jpg -------------------------------------------------------------------------------- /Dataset/DFDC/real/lchmxzjjkj_9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/DFDC/real/lchmxzjjkj_9.jpg -------------------------------------------------------------------------------- /Dataset/DFDC/real/mutoprdgcj_7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/DFDC/real/mutoprdgcj_7.jpg -------------------------------------------------------------------------------- /Dataset/Voxceleb2/face/id00015/3X9uaIs66A0/00022.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/Voxceleb2/face/id00015/3X9uaIs66A0/00022.jpg -------------------------------------------------------------------------------- /Dataset/Voxceleb2/face/id00015/3X9uaIs66A0/00023.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/Voxceleb2/face/id00015/3X9uaIs66A0/00023.jpg -------------------------------------------------------------------------------- /Dataset/Voxceleb2/face/id00015/JF-4trZP6fE/00182.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/Voxceleb2/face/id00015/JF-4trZP6fE/00182.jpg -------------------------------------------------------------------------------- /Dataset/Voxceleb2/face/id00015/JF-4trZP6fE/00183.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/Voxceleb2/face/id00015/JF-4trZP6fE/00183.jpg -------------------------------------------------------------------------------- /Dataset/Voxceleb2/face/id00015/xUtyyVLxex0/00463.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/Voxceleb2/face/id00015/xUtyyVLxex0/00463.jpg -------------------------------------------------------------------------------- /Dataset/Voxceleb2/face/id00015/xUtyyVLxex0/00464.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/Voxceleb2/face/id00015/xUtyyVLxex0/00464.jpg -------------------------------------------------------------------------------- /Dataset/Voxceleb2/face/id00015/xUtyyVLxex0/00466.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/Voxceleb2/face/id00015/xUtyyVLxex0/00466.jpg -------------------------------------------------------------------------------- /Dataset/Voxceleb2/face/id00015/xUtyyVLxex0/00467.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/Voxceleb2/face/id00015/xUtyyVLxex0/00467.jpg -------------------------------------------------------------------------------- /Dataset/Voxceleb2/face/id00015/xUtyyVLxex0/00468.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/Voxceleb2/face/id00015/xUtyyVLxex0/00468.jpg -------------------------------------------------------------------------------- /Dataset/Voxceleb2/face/id00015/xUtyyVLxex0/00473.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/Voxceleb2/face/id00015/xUtyyVLxex0/00473.jpg -------------------------------------------------------------------------------- /Dataset/Voxceleb2/voice/id00015_3X9uaIs66A0_00022.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/Voxceleb2/voice/id00015_3X9uaIs66A0_00022.npy -------------------------------------------------------------------------------- /Dataset/Voxceleb2/voice/id00015_3X9uaIs66A0_00023.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/Voxceleb2/voice/id00015_3X9uaIs66A0_00023.npy -------------------------------------------------------------------------------- /Dataset/Voxceleb2/voice/id00015_JF-4trZP6fE_00182.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/Voxceleb2/voice/id00015_JF-4trZP6fE_00182.npy -------------------------------------------------------------------------------- /Dataset/Voxceleb2/voice/id00015_JF-4trZP6fE_00183.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/Voxceleb2/voice/id00015_JF-4trZP6fE_00183.npy -------------------------------------------------------------------------------- /Dataset/Voxceleb2/voice/id00015_xUtyyVLxex0_00463.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/Voxceleb2/voice/id00015_xUtyyVLxex0_00463.npy -------------------------------------------------------------------------------- /Dataset/Voxceleb2/voice/id00015_xUtyyVLxex0_00464.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/Voxceleb2/voice/id00015_xUtyyVLxex0_00464.npy -------------------------------------------------------------------------------- /Dataset/Voxceleb2/voice/id00015_xUtyyVLxex0_00466.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/Voxceleb2/voice/id00015_xUtyyVLxex0_00466.npy -------------------------------------------------------------------------------- /Dataset/Voxceleb2/voice/id00015_xUtyyVLxex0_00467.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/Voxceleb2/voice/id00015_xUtyyVLxex0_00467.npy -------------------------------------------------------------------------------- /Dataset/Voxceleb2/voice/id00015_xUtyyVLxex0_00468.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/Voxceleb2/voice/id00015_xUtyyVLxex0_00468.npy -------------------------------------------------------------------------------- /Dataset/Voxceleb2/voice/id00015_xUtyyVLxex0_00473.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-IP/VFD/d8ee33d6573faa23430de20f15bcfc28ef13554b/Dataset/Voxceleb2/voice/id00015_xUtyyVLxex0_00473.npy -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VFD 2 | This is the release code for CVPR2022 paper ["Voice-Face Homogeneity Tells Deepfake"](https://arxiv.org/abs/2203.02195). 3 | 4 | Part of the framework is borrowed from 5 | https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix 6 | 7 | **Notes:** We only give a small batch of training and testing data, so some numerical modifications have been made to the dataset processing function to fit the small data. We will release the full data in a future official version. 8 | 9 | Train: 10 | 11 | ``` 12 | python train_DF.py --dataroot ./Dataset/Voxceleb2 --dataset_mode Vox_image --model DFD --no_flip --name experiment_name --serial_batches 13 | ``` 14 | 15 | Test (on DFDC): 16 | 17 | ``` 18 | python test_DF.py --dataroot ./Dataset/DFDC --dataset_mode DFDC --model DFD --no_flip --name experiment_name 19 | -------------------------------------------------------------------------------- /data/DFDC_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data.base_dataset import BaseDataset, get_transform 3 | from data.image_folder import make_dataset 4 | from PIL import Image 5 | import random 6 | from util import util 7 | import torch 8 | import json 9 | import cv2 10 | import numpy as np 11 | from tqdm import tqdm 12 | import torchaudio 13 | import warnings 14 | import librosa 15 | warnings.filterwarnings('ignore') 16 | 17 | 18 | class DFDCDataset(BaseDataset): 19 | 20 | 21 | def __init__(self, opt): 22 | 23 | BaseDataset.__init__(self, opt) 24 | self.opt = opt 25 | self.stage = opt.mode 26 | video_dataset_path_real = os.path.join(self.opt.dataroot, 'real') 27 | video_dataset_path_fake = os.path.join(self.opt.dataroot, 'fake') 28 | audio_dataset_path = os.path.join(self.opt.dataroot, 'audio_feat') 29 | self.video_real, self.video_fake, self.audio_real, self.audio_fake, self.audio_name = \ 30 | self.get_video_list(video_dataset_path_real, video_dataset_path_fake, audio_dataset_path, self.stage) 31 | 32 | self.transform = get_transform(self.opt) 33 | 34 | 35 | def __getitem__(self, index): 36 | 37 | video_real = self.video_real 38 | video_fake = self.video_fake 39 | audio_real = self.audio_real 40 | audio_fake = self.audio_fake 41 | audio_name = self.audio_name 42 | 43 | img_real = [] 44 | aud_real = [] 45 | img_fake = [] 46 | aud_fake = [] 47 | 48 | image_input = Image.open(video_real[index]).convert('RGB') 49 | img_d = 
self.transform(image_input) 50 | audio = np.load(audio_real[index]) 51 | 52 | audio_d = librosa.util.normalize(audio) 53 | img_real.append(img_d) 54 | aud_real.append(audio_d) 55 | 56 | audio_id = audio_name[index] 57 | 58 | 59 | image_input = Image.open(video_fake[index]).convert('RGB') 60 | img_d = self.transform(image_input) 61 | img_fake.append(img_d) 62 | 63 | audio = np.load(audio_fake[index]) 64 | audio_d = librosa.util.normalize(audio) 65 | aud_fake.append(audio_d) 66 | 67 | aud_real = np.stack(aud_real, axis=0) 68 | aud_fake = np.stack(aud_fake, axis=0) 69 | img_real = np.stack(img_real, axis=0) 70 | img_fake = np.stack(img_fake, axis=0) 71 | aud_real = np.expand_dims(aud_real, 1) 72 | aud_fake = np.expand_dims(aud_fake, 1) 73 | return { 74 | 'id': audio_id, 75 | 'img_real': img_real, 76 | 'img_fake': img_fake, 77 | 'aud_real': aud_real, 78 | 'aud_fake': aud_fake, 79 | } 80 | 81 | def __len__(self): 82 | 83 | return len(self.video_real) 84 | 85 | def get_video_list(self, dataset_path_real, dataset_path_fake, audio_dataset_path, mode='test'): 86 | video_feat_path_real = dataset_path_real 87 | video_feat_path_fake = dataset_path_fake 88 | audio_feat_path = audio_dataset_path 89 | 90 | video_real_path = [] 91 | video_fake_path = [] 92 | audio_real_path = [] 93 | audio_fake_path = [] 94 | audio_name = [] 95 | for i in tqdm(os.listdir(video_feat_path_real)): 96 | video_real_path.append(os.path.join(video_feat_path_real, i)) 97 | video_name = i.split('.jpg')[0][:-2] 98 | audio_real_path.append(os.path.join(audio_feat_path, video_name+'.npy')) 99 | audio_name.append(video_name) 100 | 101 | for i in tqdm(os.listdir(video_feat_path_fake)): 102 | video_fake_path.append(os.path.join(video_feat_path_fake, i)) 103 | video_name = i.split('.jpg')[0][:-2] 104 | audio_fake_path.append(os.path.join(audio_feat_path, video_name+'.npy')) 105 | audio_name.append(video_name) 106 | return video_real_path, video_fake_path, audio_real_path, audio_fake_path, audio_name 107 | -------------------------------------------------------------------------------- /data/Vox_image_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data.base_dataset import BaseDataset, get_transform 3 | from data.image_folder import make_dataset 4 | from PIL import Image 5 | import random 6 | from util import util 7 | import torch 8 | import json 9 | import cv2 10 | import numpy as np 11 | from tqdm import tqdm 12 | import torchaudio 13 | import warnings 14 | import librosa 15 | warnings.filterwarnings('ignore') 16 | 17 | 18 | class VoxImageDataset(BaseDataset): 19 | 20 | def __init__(self, opt): 21 | 22 | BaseDataset.__init__(self, opt) 23 | self.opt = opt 24 | self.stage = opt.mode 25 | if self.stage == 'train': 26 | self.video_dataset_path = os.path.join(self.opt.dataroot, 'face') 27 | self.audio_dataset_path = os.path.join(self.opt.dataroot, 'voice') 28 | elif self.stage == 'val': 29 | self.video_dataset_path = os.path.join(self.opt.dataroot, 'face') 30 | self.audio_dataset_path = os.path.join(self.opt.dataroot, 'voice') 31 | else: 32 | self.video_dataset_path = os.path.join(self.opt.dataroot, 'face') 33 | self.audio_dataset_path = os.path.join(self.opt.dataroot, 'voice') 34 | self.video_list, self.audio_list = self.get_video_list(self.video_dataset_path, self.audio_dataset_path, self.stage) 35 | random.seed(3) 36 | self.transform = get_transform(self.opt) 37 | 38 | def __getitem__(self, index): 39 | 40 | skip = self.__len__() // 2 41 | sample_exa = 3 42 | video_path 
= self.video_list[index] 43 | audio_path = self.audio_list[index] 44 | 45 | img_real = [] 46 | aud_real = [] 47 | img_fake = [] 48 | aud_fake = [] 49 | 50 | image_input = Image.open(video_path).convert('RGB') 51 | img_d = self.transform(image_input) 52 | audio = np.load(audio_path) 53 | audio_d = librosa.util.normalize(audio) 54 | img_real.append(img_d) 55 | aud_real.append(audio_d) 56 | 57 | min = index + skip 58 | max = (index + 2*skip)%self.__len__() 59 | if min > max: 60 | min = 0 61 | max = min + skip 62 | sample = random.sample(range(min, max), sample_exa) 63 | for ind in sample: 64 | image_input = Image.open(self.video_list[ind]).convert('RGB') 65 | img_d = self.transform(image_input) 66 | img_fake.append(img_d) 67 | 68 | audio = np.load(self.audio_list[ind]) 69 | audio_d = librosa.util.normalize(audio) 70 | aud_fake.append(audio_d) 71 | 72 | aud_real = np.stack(aud_real, axis=0) 73 | aud_fake = np.stack(aud_fake, axis=0) 74 | img_real = np.stack(img_real, axis=0) 75 | img_fake = np.stack(img_fake, axis=0) 76 | aud_real = np.expand_dims(aud_real, 1) 77 | aud_fake = np.expand_dims(aud_fake, 1) 78 | return { 79 | 'img_real': img_real, 80 | 'img_fake': img_fake, 81 | 'aud_real': aud_real, 82 | 'aud_fake': aud_fake, 83 | } 84 | 85 | def __len__(self): 86 | return len(self.video_list) 87 | 88 | def get_video_list(self, dataset_path, audio_dataset_path, mode): 89 | video_feat_path = os.path.join(dataset_path) 90 | audio_feat_path = os.path.join(audio_dataset_path) 91 | id_list = [i for i in os.listdir(video_feat_path) if i.startswith('id')] 92 | video_path = [] 93 | audio_path = [] 94 | 95 | # Too little data, train and val set are the same in the demo code, 96 | # they will be distinguished in the official version. 97 | if mode == 'train': 98 | id_list_new = id_list[0:1] 99 | len_max = 100000 100 | elif mode == 'val': 101 | id_list_new = id_list[0:1] 102 | len_max = 15000 103 | else: 104 | id_list_new = id_list 105 | len_max = 15000 106 | cnt = 0 107 | id_list_new.sort() 108 | print(id_list_new) 109 | for id in tqdm(id_list_new): 110 | id_video_split_path = os.path.join(video_feat_path, id) 111 | video_list = os.listdir(id_video_split_path) 112 | cnt_video = 0 113 | for video in video_list: 114 | image_list_path = os.path.join(id_video_split_path, video) 115 | image_list = os.listdir(image_list_path) 116 | if cnt_video > 50: 117 | break 118 | for image in image_list: 119 | if cnt_video > 40: 120 | break 121 | audio_p = os.path.join(id, video, image.replace('.jpg','.npy')).replace('/', '_') 122 | if image.endswith('.jpg') and os.path.exists(os.path.join(audio_feat_path, audio_p)): 123 | video_path.append(os.path.join(image_list_path, image)) 124 | audio_path.append(os.path.join(audio_feat_path, audio_p)) 125 | cnt += 1 126 | cnt_video += 1 127 | if cnt > len_max: 128 | break 129 | return video_path, audio_path 130 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes all the modules related to data loading and preprocessing 2 | 3 | To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. 4 | You need to implement four functions: 5 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 6 | -- <__len__>: return the size of dataset. 7 | -- <__getitem__>: get a data point from data loader. 
8 | -- : (optionally) add dataset-specific options and set default options. 9 | 10 | Now you can use the dataset class by specifying flag '--dataset_mode dummy'. 11 | See our template dataset class 'template_dataset.py' for more details. 12 | """ 13 | import importlib 14 | import torch.utils.data 15 | from data.base_dataset import BaseDataset 16 | 17 | 18 | def find_dataset_using_name(dataset_name): 19 | """Import the module "data/[dataset_name]_dataset.py". 20 | 21 | In the file, the class called DatasetNameDataset() will 22 | be instantiated. It has to be a subclass of BaseDataset, 23 | and it is case-insensitive. 24 | """ 25 | dataset_filename = "data." + dataset_name + "_dataset" 26 | datasetlib = importlib.import_module(dataset_filename) 27 | 28 | dataset = None 29 | target_dataset_name = dataset_name.replace('_', '') + 'dataset' 30 | for name, cls in datasetlib.__dict__.items(): 31 | if name.lower() == target_dataset_name.lower() \ 32 | and issubclass(cls, BaseDataset): 33 | dataset = cls 34 | 35 | if dataset is None: 36 | raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) 37 | 38 | return dataset 39 | 40 | 41 | def get_option_setter(dataset_name): 42 | """Return the static method of the dataset class.""" 43 | dataset_class = find_dataset_using_name(dataset_name) 44 | return dataset_class.modify_commandline_options 45 | 46 | 47 | def create_dataset(opt): 48 | """Create a dataset given the option. 49 | 50 | This function wraps the class CustomDatasetDataLoader. 51 | This is the main interface between this package and 'train.py'/'test.py' 52 | 53 | Example: 54 | >>> from data import create_dataset 55 | >>> dataset = create_dataset(opt) 56 | """ 57 | data_loader = CustomDatasetDataLoader(opt) 58 | dataset = data_loader.load_data() 59 | return dataset 60 | 61 | 62 | class CustomDatasetDataLoader: 63 | """Wrapper class of Dataset class that performs multi-threaded data loading""" 64 | 65 | def __init__(self, opt): 66 | """Initialize this class 67 | 68 | Step 1: create a dataset instance given the name [dataset_mode] 69 | Step 2: create a multi-threaded data loader. 70 | """ 71 | self.opt = opt 72 | dataset_class = find_dataset_using_name(opt.dataset_mode) 73 | self.dataset = dataset_class(opt) 74 | print("dataset [%s] was created" % type(self.dataset).__name__) 75 | self.dataloader = torch.utils.data.DataLoader( 76 | self.dataset, 77 | batch_size=opt.batch_size, 78 | shuffle=not opt.serial_batches, 79 | num_workers=int(opt.num_threads)) 80 | 81 | def load_data(self): 82 | return self 83 | 84 | def __len__(self): 85 | """Return the number of data in the dataset""" 86 | return min(len(self.dataset), self.opt.max_dataset_size) 87 | 88 | def __iter__(self): 89 | """Return a batch of data""" 90 | for i, data in enumerate(self.dataloader): 91 | if i * self.opt.batch_size >= self.opt.max_dataset_size: 92 | break 93 | yield data 94 | -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets. 2 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. 
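Example -- a minimal usage sketch of get_transform; `opt` here is a stand-in argparse.Namespace
(the real project passes its own option object), and the field values shown are illustrative:
    >>> from argparse import Namespace
    >>> from data.base_dataset import get_transform
    >>> opt = Namespace(preprocess='resize_and_crop', load_size=256, crop_size=224, no_flip=True)
    >>> transform = get_transform(opt)  # maps a PIL.Image to a normalized torch.Tensor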
3 | """ 4 | import random 5 | import numpy as np 6 | import torch.utils.data as data 7 | from PIL import Image 8 | import torchvision.transforms as transforms 9 | from abc import ABC, abstractmethod 10 | 11 | 12 | class BaseDataset(data.Dataset, ABC): 13 | """This class is an abstract base class (ABC) for datasets. 14 | To create a subclass, you need to implement the following four functions: 15 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 16 | -- <__len__>: return the size of dataset. 17 | -- <__getitem__>: get a data point. 18 | -- : (optionally) add dataset-specific options and set default options. 19 | """ 20 | 21 | def __init__(self, opt): 22 | """Initialize the class; save the options in the class 23 | Parameters: 24 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 25 | """ 26 | self.opt = opt 27 | self.root = opt.dataroot 28 | 29 | @staticmethod 30 | def modify_commandline_options(parser, is_train): 31 | """Add new dataset-specific options, and rewrite default values for existing options. 32 | Parameters: 33 | parser -- original option parser 34 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 35 | Returns: 36 | the modified parser. 37 | """ 38 | return parser 39 | 40 | @abstractmethod 41 | def __len__(self): 42 | """Return the total number of images in the dataset.""" 43 | return 0 44 | 45 | @abstractmethod 46 | def __getitem__(self, index): 47 | """Return a data point and its metadata information. 48 | Parameters: 49 | index - - a random integer for data indexing 50 | Returns: 51 | a dictionary of data with their names. It ususally contains the data itself and its metadata information. 
52 | """ 53 | pass 54 | 55 | 56 | def get_params(opt, size): 57 | w, h = size 58 | new_h = h 59 | new_w = w 60 | if opt.preprocess == 'resize_and_crop': 61 | new_h = new_w = opt.load_size 62 | elif opt.preprocess == 'scale_width_and_crop': 63 | new_w = opt.load_size 64 | new_h = opt.load_size * h // w 65 | 66 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size)) 67 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size)) 68 | 69 | flip = random.random() > 0.5 70 | 71 | return {'crop_pos': (x, y), 'flip': flip} 72 | 73 | 74 | def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True): 75 | transform_list = [] 76 | if grayscale: 77 | transform_list.append(transforms.Grayscale(1)) 78 | if 'resize' in opt.preprocess: 79 | osize = [opt.load_size, opt.load_size] 80 | transform_list.append(transforms.Resize(osize)) 81 | elif 'scale_width' in opt.preprocess: 82 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method))) 83 | 84 | if 'crop' in opt.preprocess: 85 | if params is None: 86 | transform_list.append(transforms.RandomCrop(opt.crop_size)) 87 | else: 88 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size))) 89 | 90 | if opt.preprocess == 'none': 91 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method))) 92 | 93 | if not opt.no_flip: 94 | if params is None: 95 | transform_list.append(transforms.RandomHorizontalFlip()) 96 | elif params['flip']: 97 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 98 | 99 | if convert: 100 | transform_list += [transforms.ToTensor()] 101 | if grayscale: 102 | transform_list += [transforms.Normalize((0.5,), (0.5,))] 103 | else: 104 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 105 | return transforms.Compose(transform_list) 106 | 107 | 108 | def __make_power_2(img, base, method=Image.BICUBIC): 109 | ow, oh = img.size 110 | h = int(round(oh / base) * base) 111 | w = int(round(ow / base) * base) 112 | if h == oh and w == ow: 113 | return img 114 | 115 | __print_size_warning(ow, oh, w, h) 116 | return img.resize((w, h), method) 117 | 118 | 119 | def __scale_width(img, target_size, crop_size, method=Image.BICUBIC): 120 | ow, oh = img.size 121 | if ow == target_size and oh >= crop_size: 122 | return img 123 | w = target_size 124 | h = int(max(target_size * oh / ow, crop_size)) 125 | return img.resize((w, h), method) 126 | 127 | 128 | def __crop(img, pos, size): 129 | ow, oh = img.size 130 | x1, y1 = pos 131 | tw = th = size 132 | if (ow > tw or oh > th): 133 | return img.crop((x1, y1, x1 + tw, y1 + th)) 134 | return img 135 | 136 | 137 | def __flip(img, flip): 138 | if flip: 139 | return img.transpose(Image.FLIP_LEFT_RIGHT) 140 | return img 141 | 142 | 143 | def __print_size_warning(ow, oh, w, h): 144 | """Print warning information about image size(only print once)""" 145 | if not hasattr(__print_size_warning, 'has_printed'): 146 | print("The image size needs to be a multiple of 4. " 147 | "The loaded image size was (%d, %d), so it was adjusted to " 148 | "(%d, %d). 
This adjustment will be done to all images " 149 | "whose sizes are not multiples of 4" % (ow, oh, w, h)) 150 | __print_size_warning.has_printed = True -------------------------------------------------------------------------------- /data/image_folder.py: -------------------------------------------------------------------------------- 1 | """A modified image folder class 2 | We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) 3 | so that this class can load images from both current directory and its subdirectories. 4 | """ 5 | 6 | import torch.utils.data as data 7 | 8 | from PIL import Image 9 | import os 10 | 11 | IMG_EXTENSIONS = [ 12 | '.jpg', '.JPG', '.jpeg', '.JPEG', 13 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 14 | '.tif', '.TIF', '.tiff', '.TIFF', 15 | ] 16 | 17 | 18 | def is_image_file(filename): 19 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 20 | 21 | 22 | def make_dataset(dir, max_dataset_size=float("inf")): 23 | images = [] 24 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 25 | 26 | for root, _, fnames in sorted(os.walk(dir)): 27 | for fname in fnames: 28 | if is_image_file(fname): 29 | path = os.path.join(root, fname) 30 | images.append(path) 31 | return images[:min(max_dataset_size, len(images))] 32 | 33 | 34 | def default_loader(path): 35 | return Image.open(path).convert('RGB') 36 | 37 | 38 | class ImageFolder(data.Dataset): 39 | 40 | def __init__(self, root, transform=None, return_paths=False, 41 | loader=default_loader): 42 | imgs = make_dataset(root) 43 | if len(imgs) == 0: 44 | raise(RuntimeError("Found 0 images in: " + root + "\n" 45 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 46 | 47 | self.root = root 48 | self.imgs = imgs 49 | self.transform = transform 50 | self.return_paths = return_paths 51 | self.loader = loader 52 | 53 | def __getitem__(self, index): 54 | path = self.imgs[index] 55 | img = self.loader(path) 56 | if self.transform is not None: 57 | img = self.transform(img) 58 | if self.return_paths: 59 | return img, path 60 | else: 61 | return img 62 | 63 | def __len__(self): 64 | return len(self.imgs) -------------------------------------------------------------------------------- /models/DFD_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base_model import BaseModel 3 | from . 
import networks 4 | import numpy as np 5 | import torchvision 6 | from PIL import Image 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import os 10 | from util import util 11 | import torchvision.transforms as transforms 12 | 13 | class DFDModel(BaseModel): 14 | 15 | @staticmethod 16 | def modify_commandline_options(parser, is_train=True): 17 | 18 | parser.set_defaults(norm='batch', netG='unet_af', dataset_mode='aligned') 19 | if is_train: 20 | parser.set_defaults(pool_size=0, gan_mode='vanilla') 21 | parser.add_argument('--lambda_L1', type=float, default=10.0, help='weight for L1 loss') 22 | 23 | return parser 24 | 25 | def __init__(self, opt): 26 | BaseModel.__init__(self, opt) 27 | 28 | if self.isTrain: 29 | self.model_names = ['G_audio', 'G_video'] 30 | else: 31 | self.model_names = ['G_audio', 'G_video'] 32 | 33 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 34 | self.batch_size = opt.batch_size 35 | 36 | self.netG_video = networks.define_G(3, 3, opt.ngf, 'transformer_video', opt.norm, 37 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids).to(self.device) 38 | self.netG_audio = networks.define_G(3, 3, opt.ngf, 'transformer_audio', opt.norm, 39 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids).to(self.device) 40 | 41 | if self.isTrain: 42 | self.triplet_loss = nn.TripletMarginLoss(margin=100.0, p=2) 43 | self.pdist = nn.PairwiseDistance(p=2) 44 | self.optimizer_G = torch.optim.Adam(list(self.netG_audio.parameters()) 45 | + list(self.netG_video.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999), 46 | ) 47 | self.optimizers.append(self.optimizer_G) 48 | 49 | def set_input(self, input_data): 50 | self.img_real = input_data['img_real'] 51 | self.img_fake = input_data['img_fake'] 52 | self.aud_real = input_data['aud_real'] 53 | self.aud_fake = input_data['aud_fake'] 54 | 55 | def set_test_input(self, input_data): 56 | self.target = input_data['label'] 57 | self.img_real = input_data['img'] 58 | self.aud_real = input_data['aud'] 59 | 60 | def forward(self): 61 | if torch.cuda.is_available(): 62 | self.img_real = torch.autograd.Variable(self.img_real).cuda() 63 | self.img_fake = torch.autograd.Variable(self.img_fake).cuda() 64 | self.aud_real = torch.autograd.Variable(self.aud_real).cuda() 65 | self.aud_fake = torch.autograd.Variable(self.aud_fake).cuda() 66 | else: 67 | self.img_real = torch.autograd.Variable(self.img_real) 68 | self.img_fake = torch.autograd.Variable(self.img_fake) 69 | self.aud_real = torch.autograd.Variable(self.aud_real) 70 | self.aud_fake = torch.autograd.Variable(self.aud_fake) 71 | 72 | self.aud_real_feat = self.netG_audio(self.aud_real.squeeze(0)) 73 | self.aud_fake_feat = self.netG_audio(self.aud_fake.squeeze(0)) 74 | self.img_fake_feat = self.netG_video(self.img_fake.squeeze(0)) 75 | self.img_real_feat = self.netG_video(self.img_real.squeeze(0)) 76 | 77 | def forward_test(self): 78 | if torch.cuda.is_available(): 79 | self.img_real = torch.autograd.Variable(self.img_real).cuda() 80 | self.aud_real = torch.autograd.Variable(self.aud_real).cuda() 81 | else: 82 | self.img_real = torch.autograd.Variable(self.img_real) 83 | self.aud_real = torch.autograd.Variable(self.aud_real) 84 | 85 | self.aud_real_feat = self.netG_audio(self.aud_real.squeeze(0)) 86 | self.img_real_feat = self.netG_video(self.img_real.squeeze(0)) 87 | 88 | def backward_G(self): 89 | audio_real = self.aud_real_feat 90 | video_real = self.img_real_feat 91 | video_fake = self.img_fake_feat 92 | 93 | self.loss_A_V = None 94 | 
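        # Triplet objective: the real-audio embedding is the anchor, the matching real-face
        # embedding is the positive, and each mismatched/fake face embedding below is a
        # negative; the per-negative losses are summed.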
for i in video_fake: 95 | if self.loss_A_V is None: 96 | self.loss_A_V = self.triplet_loss(audio_real, video_real, i) 97 | else: 98 | self.loss_A_V += self.triplet_loss(audio_real, video_real, i) 99 | self.loss = self.loss_A_V 100 | self.loss.backward() 101 | return { 102 | 'loss_A_V': self.loss_A_V.detach().item(), 103 | } 104 | 105 | def optimize_parameters(self): 106 | self.forward() 107 | self.optimizer_G.zero_grad() 108 | loss_pack = self.backward_G() 109 | loss_G_AV = loss_pack['loss_A_V'] 110 | self.optimizer_G.step() 111 | return loss_G_AV 112 | 113 | def val(self): 114 | with torch.no_grad(): 115 | self.forward() 116 | audio_real = self.aud_real_feat 117 | audio_fake = self.aud_fake_feat 118 | video_real = self.img_real_feat 119 | video_fake = self.img_fake_feat 120 | self.sim_A_V = self.simi(audio_real, video_real) 121 | self.sim_V_A = [] 122 | for i, j in zip(video_fake, audio_fake): 123 | self.sim_V_A.append(self.simi(i.unsqueeze(0), j.unsqueeze(0))) 124 | return self.sim_A_V, self.sim_V_A 125 | 126 | def simi(self, anchor, pos): 127 | pdist = nn.PairwiseDistance(p=2) 128 | dist = pdist(anchor, pos) 129 | return dist -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | """This package contains modules related to objective functions, optimizations, and network architectures. 2 | 3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. 4 | You need to implement the following five functions: 5 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 6 | -- : unpack data from dataset and apply preprocessing. 7 | -- : produce intermediate results. 8 | -- : calculate loss, gradients, and update network weights. 9 | -- : (optionally) add model-specific options and set default options. 10 | 11 | In the function <__init__>, you need to define four lists: 12 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 13 | -- self.model_names (str list): define networks used in our training. 14 | -- self.visual_names (str list): specify the images that you want to display and save. 15 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage. 16 | 17 | Now you can use the model class by specifying flag '--model dummy'. 18 | See our template model class 'template_model.py' for more details. 19 | """ 20 | 21 | import importlib 22 | from models.base_model import BaseModel 23 | 24 | 25 | def find_model_using_name(model_name): 26 | """Import the module "models/[model_name]_model.py". 27 | 28 | In the file, the class called DatasetNameModel() will 29 | be instantiated. It has to be a subclass of BaseModel, 30 | and it is case-insensitive. 31 | """ 32 | model_filename = "models." + model_name + "_model" 33 | modellib = importlib.import_module(model_filename) 34 | model = None 35 | target_model_name = model_name.replace('_', '') + 'model' 36 | for name, cls in modellib.__dict__.items(): 37 | if name.lower() == target_model_name.lower() \ 38 | and issubclass(cls, BaseModel): 39 | model = cls 40 | 41 | if model is None: 42 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." 
% (model_filename, target_model_name)) 43 | exit(0) 44 | 45 | return model 46 | 47 | 48 | def get_option_setter(model_name): 49 | """Return the static method of the model class.""" 50 | model_class = find_model_using_name(model_name) 51 | return model_class.modify_commandline_options 52 | 53 | 54 | def create_model(opt): 55 | """Create a model given the option. 56 | 57 | This function warps the class CustomDatasetDataLoader. 58 | This is the main interface between this package and 'train.py'/'test.py' 59 | 60 | Example: 61 | >>> from models import create_model 62 | >>> model = create_model(opt) 63 | """ 64 | model = find_model_using_name(opt.model) 65 | instance = model(opt) 66 | print("model [%s] was created" % type(instance).__name__) 67 | return instance 68 | -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | from abc import ABC, abstractmethod 5 | from . import networks 6 | 7 | 8 | class BaseModel(ABC): 9 | """This class is an abstract base class (ABC) for models. 10 | To create a subclass, you need to implement the following five functions: 11 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 12 | -- : unpack data from dataset and apply preprocessing. 13 | -- : produce intermediate results. 14 | -- : calculate losses, gradients, and update network weights. 15 | -- : (optionally) add model-specific options and set default options. 16 | """ 17 | 18 | def __init__(self, opt): 19 | """Initialize the BaseModel class. 20 | 21 | Parameters: 22 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 23 | 24 | When creating your custom class, you need to implement your own initialization. 25 | In this function, you should first call 26 | Then, you need to define four lists: 27 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 28 | -- self.model_names (str list): define networks used in our training. 29 | -- self.visual_names (str list): specify the images that you want to display and save. 30 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. 31 | """ 32 | self.opt = opt 33 | self.gpu_ids = opt.gpu_ids 34 | self.isTrain = opt.isTrain 35 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU 36 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir 37 | if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark. 38 | torch.backends.cudnn.benchmark = True 39 | self.loss_names = [] 40 | self.model_names = [] 41 | self.visual_names = [] 42 | self.optimizers = [] 43 | self.image_paths = [] 44 | 45 | self.schedulers = None 46 | 47 | self.metric = 0 # used for learning rate policy 'plateau' 48 | 49 | @staticmethod 50 | def modify_commandline_options(parser, is_train): 51 | """Add new model-specific options, and rewrite default values for existing options. 52 | 53 | Parameters: 54 | parser -- original option parser 55 | is_train (bool) -- whether training phase or test phase. 
You can use this flag to add training-specific or test-specific options. 56 | 57 | Returns: 58 | the modified parser. 59 | """ 60 | return parser 61 | 62 | @abstractmethod 63 | def set_input(self, input): 64 | """Unpack input data from the dataloader and perform necessary pre-processing steps. 65 | 66 | Parameters: 67 | input (dict): includes the data itself and its metadata information. 68 | """ 69 | pass 70 | 71 | @abstractmethod 72 | def forward(self): 73 | """Run forward pass; called by both functions and .""" 74 | pass 75 | 76 | @abstractmethod 77 | def optimize_parameters(self): 78 | """Calculate losses, gradients, and update network weights; called in every training iteration""" 79 | pass 80 | 81 | def setup(self, opt): 82 | """Load and print networks; create schedulers 83 | 84 | Parameters: 85 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 86 | """ 87 | if self.isTrain: 88 | self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] 89 | if not self.isTrain or opt.continue_train: 90 | load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch 91 | self.load_networks(load_suffix) 92 | self.print_networks(opt.verbose) 93 | 94 | def eval(self): 95 | """Make models eval mode during test time""" 96 | for name in self.model_names: 97 | if isinstance(name, str): 98 | net = getattr(self, 'net' + name) 99 | net.eval() 100 | 101 | def train(self): 102 | """Make models eval mode during test time""" 103 | for name in self.model_names: 104 | if isinstance(name, str): 105 | net = getattr(self, 'net' + name) 106 | net.train() 107 | 108 | def test(self): 109 | """Forward function used in test time. 110 | 111 | This function wraps function in no_grad() so we don't save intermediate steps for backprop 112 | It also calls to produce additional visualization results 113 | """ 114 | with torch.no_grad(): 115 | self.forward() 116 | self.compute_visuals() 117 | 118 | def compute_visuals(self): 119 | """Calculate additional output images for visdom and HTML visualization""" 120 | pass 121 | 122 | def get_image_paths(self): 123 | """ Return image paths that are used to load current data""" 124 | return self.image_paths 125 | 126 | def update_learning_rate(self): 127 | """Update learning rates for all the networks; called at the end of every epoch""" 128 | old_lr = self.optimizers[0].param_groups[0]['lr'] 129 | for scheduler in self.schedulers: 130 | if self.opt.lr_policy == 'plateau': 131 | scheduler.step(self.metric) 132 | else: 133 | scheduler.step() 134 | 135 | lr = self.optimizers[0].param_groups[0]['lr'] 136 | print('learning rate %.7f -> %.7f' % (old_lr, lr)) 137 | 138 | def get_current_visuals(self): 139 | """Return visualization images. train.py will display these images with visdom, and save the images to a HTML""" 140 | visual_ret = OrderedDict() 141 | for name in self.visual_names: 142 | if isinstance(name, str): 143 | visual_ret[name] = getattr(self, name) 144 | return visual_ret 145 | 146 | def get_current_losses(self): 147 | """Return traning losses / errors. train.py will print out these errors on console, and save them to a file""" 148 | errors_ret = OrderedDict() 149 | for name in self.loss_names: 150 | if isinstance(name, str): 151 | errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number 152 | return errors_ret 153 | 154 | def save_networks(self, epoch): 155 | """Save all the networks to the disk. 
156 | 157 | Parameters: 158 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 159 | """ 160 | for name in self.model_names: 161 | if isinstance(name, str): 162 | save_filename = '%s_net_%s.pth' % (epoch, name) 163 | save_path = os.path.join(self.save_dir, save_filename) 164 | net = getattr(self, 'net' + name) 165 | 166 | if len(self.gpu_ids) > 0 and torch.cuda.is_available(): 167 | torch.save(net.module.cpu().state_dict(), save_path) 168 | net.cuda(self.gpu_ids[0]) 169 | else: 170 | torch.save(net.cpu().state_dict(), save_path) 171 | 172 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): 173 | """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)""" 174 | key = keys[i] 175 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer 176 | if module.__class__.__name__.startswith('InstanceNorm') and \ 177 | (key == 'running_mean' or key == 'running_var'): 178 | if getattr(module, key) is None: 179 | state_dict.pop('.'.join(keys)) 180 | if module.__class__.__name__.startswith('InstanceNorm') and \ 181 | (key == 'num_batches_tracked'): 182 | state_dict.pop('.'.join(keys)) 183 | else: 184 | self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) 185 | 186 | def load_networks(self, epoch): 187 | """Load all the networks from the disk. 188 | 189 | Parameters: 190 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 191 | """ 192 | for name in self.model_names: 193 | if isinstance(name, str): 194 | load_filename = '%s_net_%s.pth' % (epoch, name) 195 | load_path = os.path.join(self.save_dir, load_filename) 196 | net = getattr(self, 'net' + name) 197 | if isinstance(net, torch.nn.DataParallel): 198 | net = net.module 199 | print('loading the model from %s' % load_path) 200 | # if you are using PyTorch newer than 0.4 (e.g., built from 201 | # GitHub source), you can remove str() on self.device 202 | state_dict = torch.load(load_path, map_location=str(self.device)) 203 | if hasattr(state_dict, '_metadata'): 204 | del state_dict._metadata 205 | 206 | # patch InstanceNorm checkpoints prior to 0.4 207 | for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop 208 | self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) 209 | net.load_state_dict(state_dict) 210 | 211 | def print_networks(self, verbose): 212 | """Print the total number of parameters in the network and (if verbose) network architecture 213 | 214 | Parameters: 215 | verbose (bool) -- if verbose: print the network architecture 216 | """ 217 | print('---------- Networks initialized -------------') 218 | for name in self.model_names: 219 | if isinstance(name, str): 220 | net = getattr(self, 'net' + name) 221 | num_params = 0 222 | for param in net.parameters(): 223 | num_params += param.numel() 224 | if verbose: 225 | print(net) 226 | print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) 227 | print('-----------------------------------------------') 228 | 229 | def set_requires_grad(self, nets, requires_grad=False): 230 | """Set requies_grad=Fasle for all the networks to avoid unnecessary computations 231 | Parameters: 232 | nets (network list) -- a list of networks 233 | requires_grad (bool) -- whether the networks require gradients or not 234 | """ 235 | if not isinstance(nets, list): 236 | nets = [nets] 237 | for net in nets: 238 | if net is not None: 239 | for param in net.parameters(): 240 | param.requires_grad = 
requires_grad 241 | -------------------------------------------------------------------------------- /models/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import functools 5 | from torch.optim import lr_scheduler 6 | import torchfile 7 | import torch.nn.functional as F 8 | import torch 9 | from torch import nn 10 | import math 11 | from einops import rearrange 12 | 13 | class Identity(nn.Module): 14 | def forward(self, x): 15 | return x 16 | 17 | 18 | def get_norm_layer(norm_type='instance'): 19 | if norm_type == 'batch': 20 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True) 21 | elif norm_type == 'instance': 22 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) 23 | elif norm_type == 'none': 24 | def norm_layer(x): return Identity() 25 | else: 26 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 27 | return norm_layer 28 | 29 | 30 | def get_scheduler(optimizer, opt): 31 | if opt.lr_policy == 'linear': 32 | def lambda_rule(epoch): 33 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1) 34 | return lr_l 35 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 36 | elif opt.lr_policy == 'step': 37 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) 38 | elif opt.lr_policy == 'plateau': 39 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 40 | elif opt.lr_policy == 'cosine': 41 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0) 42 | else: 43 | return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) 44 | return scheduler 45 | 46 | 47 | def init_weights(net, init_type='normal', init_gain=0.02): 48 | 49 | def init_func(m): 50 | classname = m.__class__.__name__ 51 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 52 | if init_type == 'normal': 53 | init.normal_(m.weight.data, 0.0, init_gain) 54 | elif init_type == 'xavier': 55 | init.xavier_normal_(m.weight.data, gain=init_gain) 56 | elif init_type == 'kaiming': 57 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 58 | elif init_type == 'orthogonal': 59 | init.orthogonal_(m.weight.data, gain=init_gain) 60 | else: 61 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 62 | if hasattr(m, 'bias') and m.bias is not None: 63 | init.constant_(m.bias.data, 0.0) 64 | elif classname.find('BatchNorm2d') != -1: 65 | init.normal_(m.weight.data, 1.0, init_gain) 66 | init.constant_(m.bias.data, 0.0) 67 | 68 | print('initialize network with %s' % init_type) 69 | net.apply(init_func) 70 | 71 | 72 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): 73 | if len(gpu_ids) > 0: 74 | assert(torch.cuda.is_available()) 75 | net.to(gpu_ids[0]) 76 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs 77 | init_weights(net, init_type, init_gain=init_gain) 78 | return net 79 | 80 | 81 | def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]): 82 | net = None 83 | norm_layer = get_norm_layer(norm_type=norm) 84 | 85 | if netG == 'resnet_9blocks': 86 | net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, 
n_blocks=9) 87 | elif netG == 'resnet_6blocks': 88 | net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6) 89 | elif netG == 'unet_128': 90 | net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout) 91 | elif netG == 'unet_256': 92 | net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout) 93 | elif netG == 'transformer_video': 94 | net = CViT() 95 | elif netG == 'transformer_audio': 96 | net = CAiT() 97 | else: 98 | raise NotImplementedError('Generator model name [%s] is not recognized' % netG) 99 | # return net 100 | return init_net(net, init_type, init_gain, gpu_ids) 101 | 102 | 103 | class ResnetGenerator(nn.Module): 104 | 105 | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'): 106 | assert(n_blocks >= 0) 107 | super(ResnetGenerator, self).__init__() 108 | if type(norm_layer) == functools.partial: 109 | use_bias = norm_layer.func == nn.InstanceNorm2d 110 | else: 111 | use_bias = norm_layer == nn.InstanceNorm2d 112 | 113 | model = [nn.ReflectionPad2d(3), 114 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias), 115 | norm_layer(ngf), 116 | nn.ReLU(True)] 117 | 118 | n_downsampling = 2 119 | for i in range(n_downsampling): # add downsampling layers 120 | mult = 2 ** i 121 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias), 122 | norm_layer(ngf * mult * 2), 123 | nn.ReLU(True)] 124 | 125 | mult = 2 ** n_downsampling 126 | for i in range(n_blocks): # add ResNet blocks 127 | 128 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] 129 | 130 | for i in range(n_downsampling): # add upsampling layers 131 | mult = 2 ** (n_downsampling - i) 132 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), 133 | kernel_size=3, stride=2, 134 | padding=1, output_padding=1, 135 | bias=use_bias), 136 | norm_layer(int(ngf * mult / 2)), 137 | nn.ReLU(True)] 138 | model += [nn.ReflectionPad2d(3)] 139 | model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] 140 | model += [nn.Tanh()] 141 | 142 | self.model = nn.Sequential(*model) 143 | 144 | return self.model(input) 145 | 146 | 147 | class ResnetBlock(nn.Module): 148 | 149 | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): 150 | super(ResnetBlock, self).__init__() 151 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) 152 | 153 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): 154 | conv_block = [] 155 | p = 0 156 | if padding_type == 'reflect': 157 | conv_block += [nn.ReflectionPad2d(1)] 158 | elif padding_type == 'replicate': 159 | conv_block += [nn.ReplicationPad2d(1)] 160 | elif padding_type == 'zero': 161 | p = 1 162 | else: 163 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 164 | 165 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)] 166 | if use_dropout: 167 | conv_block += [nn.Dropout(0.5)] 168 | 169 | p = 0 170 | if padding_type == 'reflect': 171 | conv_block += [nn.ReflectionPad2d(1)] 172 | elif padding_type == 'replicate': 173 | conv_block += [nn.ReplicationPad2d(1)] 174 | elif padding_type == 'zero': 175 | p = 1 176 | else: 177 | raise NotImplementedError('padding [%s] is not implemented' % 
padding_type) 178 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)] 179 | 180 | return nn.Sequential(*conv_block) 181 | 182 | def forward(self, x): 183 | 184 | out = x + self.conv_block(x) 185 | return out 186 | 187 | 188 | class UnetGenerator(nn.Module): 189 | 190 | 191 | def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False): 192 | 193 | super(UnetGenerator, self).__init__() 194 | 195 | 196 | super(UnetGenerator, self).__init__() 197 | 198 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, 199 | innermost=True) # add the innermost layer 200 | for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters 201 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, 202 | norm_layer=norm_layer, use_dropout=use_dropout) 203 | 204 | unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, 205 | norm_layer=norm_layer) 206 | unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, 207 | norm_layer=norm_layer) 208 | unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 209 | self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, 210 | norm_layer=norm_layer) # add the outermost layer 211 | 212 | 213 | def forward(self, input, mode=None): 214 | 215 | return self.model(input, mode=mode) 216 | 217 | 218 | class UnetSkipConnectionBlock(nn.Module): 219 | 220 | 221 | def __init__(self, outer_nc, inner_nc, input_nc=None, 222 | submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, 223 | use_dropout=False, model_split=0): 224 | 225 | self.innernc = inner_nc 226 | super(UnetSkipConnectionBlock, self).__init__() 227 | self.outermost = outermost 228 | self.innermost = innermost 229 | if type(norm_layer) == functools.partial: 230 | use_bias = norm_layer.func == nn.InstanceNorm2d 231 | else: 232 | use_bias = norm_layer == nn.InstanceNorm2d 233 | if input_nc is None: 234 | input_nc = outer_nc 235 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, 236 | stride=2, padding=1, bias=use_bias) 237 | downrelu = nn.LeakyReLU(0.2, True) 238 | downnorm = norm_layer(inner_nc) 239 | uprelu = nn.ReLU(True) 240 | upnorm = norm_layer(outer_nc) 241 | 242 | if outermost: 243 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 244 | kernel_size=4, stride=2, 245 | padding=1) 246 | down = [downconv] 247 | up = [uprelu, upconv, nn.Tanh()] 248 | model = down + [submodule] + up 249 | elif innermost: 250 | upconv = nn.ConvTranspose2d(inner_nc, outer_nc, 251 | kernel_size=4, stride=2, 252 | padding=1, bias=use_bias) 253 | down = [downrelu, downconv] 254 | up = [uprelu, upconv, upnorm] 255 | model = down + up 256 | else: 257 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 258 | kernel_size=4, stride=2, 259 | padding=1, bias=use_bias) 260 | down = [downrelu, downconv, downnorm] 261 | up = [uprelu, upconv, upnorm] 262 | 263 | if use_dropout: 264 | model = down + [submodule] + up + [nn.Dropout(0.5)] 265 | else: 266 | model = down + [submodule] + up 267 | 268 | self.model = nn.Sequential(*model) 269 | 270 | def forward(self, x, mode = None, audio_feat = None): 271 | if self.outermost: 272 | return self.model(x) 273 | else: 274 | return torch.cat([x, self.model(x)], 1) 275 | 276 | 277 | class 
NLayerDiscriminator(nn.Module): 278 | 279 | 280 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d): 281 | 282 | super(NLayerDiscriminator, self).__init__() 283 | if type(norm_layer) == functools.partial: 284 | use_bias = norm_layer.func == nn.InstanceNorm2d 285 | else: 286 | use_bias = norm_layer == nn.InstanceNorm2d 287 | 288 | kw = 4 289 | padw = 1 290 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 291 | nf_mult = 1 292 | nf_mult_prev = 1 293 | for n in range(1, n_layers): 294 | nf_mult_prev = nf_mult 295 | nf_mult = min(2 ** n, 8) 296 | sequence += [ 297 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 298 | norm_layer(ndf * nf_mult), 299 | nn.LeakyReLU(0.2, True) 300 | ] 301 | 302 | nf_mult_prev = nf_mult 303 | nf_mult = min(2 ** n_layers, 8) 304 | sequence += [ 305 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 306 | norm_layer(ndf * nf_mult), 307 | nn.LeakyReLU(0.2, True) 308 | ] 309 | 310 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 311 | self.model = nn.Sequential(*sequence) 312 | 313 | def forward(self, input): 314 | """Standard forward.""" 315 | return self.model(input) 316 | 317 | # import torch.functional as F 318 | 319 | class Residual(nn.Module): 320 | def __init__(self, fn): 321 | super().__init__() 322 | self.fn = fn 323 | 324 | def forward(self, x, **kwargs): 325 | return self.fn(x, **kwargs) + x 326 | 327 | class PreNorm(nn.Module): 328 | def __init__(self, dim, fn): 329 | super().__init__() 330 | self.norm = nn.LayerNorm(dim) 331 | self.fn = fn 332 | 333 | def forward(self, x, **kwargs): 334 | return self.fn(self.norm(x), **kwargs) 335 | 336 | class FeedForward(nn.Module): 337 | def __init__(self, dim, hidden_dim): 338 | super().__init__() 339 | self.net = nn.Sequential( 340 | nn.Linear(dim, hidden_dim), 341 | nn.GELU(), 342 | nn.Linear(hidden_dim, dim) 343 | ) 344 | 345 | def forward(self, x): 346 | return self.net(x) 347 | 348 | class Attention(nn.Module): 349 | def __init__(self, dim, heads=8): 350 | super().__init__() 351 | self.heads = heads 352 | self.scale = dim ** -0.5 353 | 354 | self.to_qkv = nn.Linear(dim, dim * 3, bias=False) 355 | self.to_out = nn.Linear(dim, dim) 356 | 357 | def forward(self, x, mask=None): 358 | b, n, _, h = *x.shape, self.heads 359 | qkv = self.to_qkv(x) 360 | q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv=3, h=h) 361 | dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale 362 | 363 | if mask is not None: 364 | mask = F.pad(mask.flatten(1), (1, 0), value=True) 365 | assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions' 366 | mask = mask[:, None, :] * mask[:, :, None] 367 | dots.masked_fill_(~mask, float('-inf')) 368 | del mask 369 | 370 | attn = dots.softmax(dim=-1) 371 | 372 | out = torch.einsum('bhij,bhjd->bhid', attn, v) 373 | out = rearrange(out, 'b h n d -> b n (h d)') 374 | out = self.to_out(out) 375 | return out 376 | 377 | class Transformer(nn.Module): 378 | def __init__(self, dim, depth, heads, mlp_dim): 379 | super().__init__() 380 | self.layers = nn.ModuleList([]) 381 | for _ in range(depth): 382 | self.layers.append(nn.ModuleList([ 383 | Residual(PreNorm(dim, Attention(dim, heads=heads))), 384 | Residual(PreNorm(dim, FeedForward(dim, mlp_dim))) 385 | ])) 386 | 387 | def forward(self, x, mask=None): 388 | for attn, ff in 
self.layers: 389 | x = attn(x, mask=mask) 390 | x = ff(x) 391 | return x 392 | 393 | class CViT(nn.Module): 394 | def __init__(self, image_size=224, patch_size=7, channels=512, 395 | dim=1024, depth=6, heads=8, mlp_dim=2048): 396 | super().__init__() 397 | assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size' 398 | 399 | self.features = nn.Sequential( 400 | nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1), 401 | nn.BatchNorm2d(num_features=32), 402 | nn.ReLU(), 403 | nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1), 404 | nn.BatchNorm2d(num_features=32), 405 | nn.ReLU(), 406 | nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1), 407 | nn.BatchNorm2d(num_features=32), 408 | nn.ReLU(), 409 | nn.MaxPool2d(kernel_size=2, stride=2), 410 | 411 | nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1), 412 | nn.BatchNorm2d(num_features=64), 413 | nn.ReLU(), 414 | nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), 415 | nn.BatchNorm2d(num_features=64), 416 | nn.ReLU(), 417 | nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), 418 | nn.BatchNorm2d(num_features=64), 419 | nn.ReLU(), 420 | nn.MaxPool2d(kernel_size=2, stride=2), 421 | 422 | nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), 423 | nn.BatchNorm2d(num_features=128), 424 | nn.ReLU(), 425 | nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), 426 | nn.BatchNorm2d(num_features=128), 427 | nn.ReLU(), 428 | nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), 429 | nn.BatchNorm2d(num_features=128), 430 | nn.ReLU(), 431 | nn.MaxPool2d(kernel_size=2, stride=2), 432 | 433 | nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1), 434 | nn.BatchNorm2d(num_features=256), 435 | nn.ReLU(), 436 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), 437 | nn.BatchNorm2d(num_features=256), 438 | nn.ReLU(), 439 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), 440 | nn.BatchNorm2d(num_features=256), 441 | nn.ReLU(), 442 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), 443 | nn.BatchNorm2d(num_features=256), 444 | nn.ReLU(), 445 | nn.MaxPool2d(kernel_size=2, stride=2), 446 | 447 | nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1), 448 | nn.BatchNorm2d(num_features=512), 449 | nn.ReLU(), 450 | nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1), 451 | nn.BatchNorm2d(num_features=512), 452 | nn.ReLU(), 453 | nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1), 454 | nn.BatchNorm2d(num_features=512), 455 | nn.ReLU(), 456 | nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1), 457 | nn.BatchNorm2d(num_features=512), 458 | nn.ReLU(), 459 | nn.MaxPool2d(kernel_size=2, stride=2) 460 | ) 461 | 462 | patch_dim = channels * patch_size ** 2 463 | self.patch_size = patch_size 464 | self.pattern_matrix = nn.Parameter(torch.randn(1, 32, 512)) 465 | self.latent_matrix = nn.Parameter(torch.randn(1, 49, 32)) 466 | self.pos_embedding = nn.Parameter(torch.randn(32, 1, dim)) 467 | 468 | # self.to_embedding = nn.Linear(32*dim, dim) 469 | 470 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 471 | self.transformer = Transformer(dim, depth, heads, mlp_dim) 472 | 473 | self.to_cls_token = nn.Identity() 474 | 475 | self.mlp_head = nn.Sequential( 476 | nn.Linear(dim, mlp_dim), 477 | nn.ReLU(), 478 | ) 479 | 480 | def forward(self, img, mask=None): 481 | p = self.patch_size 482 | x = self.features(img) 483 | y = rearrange(x, 'b c (h p1) (w p2) -> b (c h w) (p1 p2)', p1=p, p2=p) 484 | y = torch.matmul(self.pattern_matrix, y) 485 | y = torch.matmul(y, self.latent_matrix) 
486 | y = rearrange(y, '(b p1) h w -> b p1 (h w)', p1=1) 487 | cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) 488 | x = torch.cat((cls_tokens, y), 1) 489 | shape = x.shape[0] 490 | x += self.pos_embedding[0:shape] 491 | x = self.transformer(x, mask) 492 | x = self.to_cls_token(x[:, 0]) 493 | return self.mlp_head(x) 494 | 495 | class CAiT(nn.Module): 496 | def __init__(self, patch_size=7, dim=1024, depth=6, heads=8, mlp_dim=2048): 497 | super().__init__() 498 | 499 | self.features = nn.Sequential( 500 | nn.Conv2d(1, 96, kernel_size=(7, 7), stride=(2, 2)), 501 | nn.BatchNorm2d(96, track_running_stats=False), 502 | nn.ReLU(inplace=True), 503 | nn.MaxPool2d(kernel_size=(2, 2)), 504 | 505 | nn.Conv2d(96, 256, kernel_size=(5, 5), stride=(2, 2)), 506 | nn.BatchNorm2d(256, track_running_stats=False), 507 | nn.ReLU(inplace=True), 508 | nn.MaxPool2d(kernel_size=(2, 2)), 509 | 510 | nn.Conv2d(256, 256, kernel_size=(3, 3)), 511 | nn.BatchNorm2d(256, track_running_stats=False), 512 | nn.ReLU(inplace=True), 513 | 514 | nn.Conv2d(256, 256, kernel_size=(3, 3)), 515 | nn.BatchNorm2d(256, track_running_stats=False), 516 | nn.ReLU(inplace=True), 517 | 518 | nn.Conv2d(256, 256, kernel_size=(3, 3)), 519 | nn.BatchNorm2d(256, track_running_stats=False), 520 | nn.ReLU(inplace=True), 521 | nn.MaxPool2d(kernel_size=(3, 2)), 522 | ) 523 | 524 | self.patch_size = patch_size 525 | self.pattern_matrix = nn.Parameter(torch.randn(1, 32, 256)) 526 | self.latent_matrix = nn.Parameter(torch.randn(1, 40, 32)) 527 | self.pos_embedding = nn.Parameter(torch.randn(32, 1, dim)) 528 | # self.to_embedding = nn.Linear(32 * dim, dim) 529 | 530 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 531 | self.transformer = Transformer(dim, depth, heads, mlp_dim) 532 | 533 | self.to_cls_token = nn.Identity() 534 | 535 | self.mlp_head = nn.Sequential( 536 | nn.Linear(dim, mlp_dim), 537 | nn.ReLU(), 538 | ) 539 | 540 | def forward(self, img, mask=None): 541 | p = self.patch_size 542 | x = self.features(img) 543 | y = rearrange(x, 'b c (h p1) (w p2) -> b (c h w) (p1 p2)', p1=8, p2=5) 544 | y = torch.matmul(self.pattern_matrix, y) 545 | y = torch.matmul(y, self.latent_matrix) 546 | y = rearrange(y, '(b p1) h w -> b p1 (h w)', p1=1) 547 | cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) 548 | x = torch.cat((cls_tokens, y), 1) 549 | shape = x.shape[0] 550 | x += self.pos_embedding[0:shape] 551 | x = self.transformer(x, mask) 552 | x = self.to_cls_token(x[:, 0]) 553 | return self.mlp_head(x) 554 | 555 | 556 | 557 | 558 | -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from util import util 4 | import torch 5 | import models 6 | import data 7 | 8 | 9 | class BaseOptions(): 10 | 11 | def __init__(self): 12 | self.initialized = False 13 | def initialize(self, parser): 14 | parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)') 15 | parser.add_argument('--image_list', type=str, default='/train_list.txt', help='name of the image_list') 16 | parser.add_argument('--json_path', type=str, default='', help='name of the audio_list') 17 | parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models') 18 | parser.add_argument('--gpu_ids', type=str, default='3', help='gpu ids: e.g. 0 0,1,2, 0,2. 
use -1 for CPU') 19 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 20 | # model parameters 21 | parser.add_argument('--model', type=str, default='DFD', help='chooses which model to use.') 22 | parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale') 23 | parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale') 24 | parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer') 25 | parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer') 26 | parser.add_argument('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator') 27 | parser.add_argument('--netG', type=str, default='unet_256', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]') 28 | parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers') 29 | parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]') 30 | parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]') 31 | parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') 32 | parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator') 33 | parser.add_argument('--train_mode', default='a', help="") 34 | # dataset parameters 35 | parser.add_argument('--dataset_mode', type=str, default='DFDC', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]') 36 | parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA') 37 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 38 | parser.add_argument('--num_threads', default=8, type=int, help='# threads for loading data') 39 | parser.add_argument('--batch_size', type=int, default=1, help='input batch size') 40 | parser.add_argument('--load_size', type=int, default=224, help='scale images to this size') 41 | parser.add_argument('--crop_size', type=int, default=224, help='then crop to this size') 42 | parser.add_argument('--mode', type=str, default='train', help='train/val/test') 43 | parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 44 | parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]') 45 | parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation') 46 | parser.add_argument('--display_winsize', type=int, default=224, help='display window size for both visdom and HTML') 47 | parser.add_argument('--seg_num', type=int, default=100, help='the image number extracted from video') 48 | parser.add_argument('--frames_in_segs', type=int, default=1, 49 | help='image number for reconstruction, 1 for static reconstruction') 50 | parser.add_argument('--audio_length', type=int, default=300, 51 | help='frames from audio, it should be minimal than 1000. Preferably 2^N for the convenience of U-net') 52 | parser.add_argument('--scale_size', type=int, default=0, 53 | help='frames from audio, it should be minimal than 1000. Preferably 2^N for the convenience of U-net') 54 | 55 | parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 56 | parser.add_argument('--load_iter', type=int, default=0, help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]') 57 | parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') 58 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') 59 | self.initialized = True 60 | return parser 61 | 62 | def gather_options(self): 63 | 64 | if not self.initialized: # check if it has been initialized 65 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 66 | parser = self.initialize(parser) 67 | 68 | 69 | opt, _ = parser.parse_known_args() 70 | 71 | 72 | model_name = opt.model 73 | model_option_setter = models.get_option_setter(model_name) 74 | parser = model_option_setter(parser, self.isTrain) 75 | opt, _ = parser.parse_known_args() # parse again with new defaults 76 | 77 | self.parser = parser 78 | return parser.parse_args() 79 | 80 | def print_options(self, opt): 81 | 82 | message = '' 83 | message += '----------------- Options ---------------\n' 84 | for k, v in sorted(vars(opt).items()): 85 | comment = '' 86 | default = self.parser.get_default(k) 87 | if v != default: 88 | comment = '\t[default: %s]' % str(default) 89 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 90 | message += '----------------- End -------------------' 91 | print(message) 92 | 93 | # save to the disk 94 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name) 95 | util.mkdirs(expr_dir) 96 | file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase)) 97 | with open(file_name, 'wt') as opt_file: 98 | opt_file.write(message) 99 | opt_file.write('\n') 100 | 101 | def parse(self): 102 | """Parse our options, create checkpoints directory suffix, and set up gpu device.""" 103 | opt = self.gather_options() 104 | opt.isTrain = self.isTrain # train or test 105 | 106 | # process opt.suffix 107 | if opt.suffix: 108 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' 109 | opt.name = opt.name + suffix 110 | 111 | self.print_options(opt) 112 | 113 | str_ids = 
opt.gpu_ids.split(',') 114 | opt.gpu_ids = [] 115 | for str_id in str_ids: 116 | id = int(str_id) 117 | if id >= 0: 118 | opt.gpu_ids.append(id) 119 | if len(opt.gpu_ids) > 0: 120 | torch.cuda.set_device(opt.gpu_ids[0]) 121 | 122 | self.opt = opt 123 | return self.opt 124 | -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | 2 | from .base_options import BaseOptions 3 | 4 | 5 | class TestOptions(BaseOptions): 6 | """This class includes test options. 7 | 8 | It also includes shared options defined in BaseOptions. 9 | """ 10 | 11 | def initialize(self, parser): 12 | parser = BaseOptions.initialize(self, parser) # define shared options 13 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') 14 | parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') 15 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 16 | # Dropout and BatchNorm have different behavior during training and test. 17 | parser.add_argument('--eval', action='store_true', help='use eval mode during test time.') 18 | parser.add_argument('--num_test', type=int, default=300, help='how many test images to run') 19 | # rewrite default values 20 | parser.set_defaults(model='test') 21 | # To avoid cropping, the load_size should be the same as crop_size 22 | # parser.set_defaults(load_size=parser.get_default('crop_size')) 23 | self.isTrain = False 24 | return parser 25 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TrainOptions(BaseOptions): 5 | """This class includes training options. 6 | 7 | It also includes shared options defined in BaseOptions.
8 | """ 9 | 10 | def initialize(self, parser): 11 | parser = BaseOptions.initialize(self, parser) 12 | # visdom and HTML visualization parameters 13 | parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen') 14 | parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.') 15 | parser.add_argument('--display_id', type=int, default=1, help='window id of the web display') 16 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') 17 | parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")') 18 | parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') 19 | parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html') 20 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') 21 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 22 | # network saving and loading parameters 23 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') 24 | parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs') 25 | parser.add_argument('--save_by_iter', action='store_true', help='whether to save the model by iteration') 26 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 27 | parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...') 28 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') 29 | # training parameters 30 | parser.add_argument('--n_epochs', type=int, default=2, help='number of epochs with the initial learning rate') 31 | parser.add_argument('--n_epochs_decay', type=int, default=20, help='number of epochs to linearly decay learning rate to zero') 32 | parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') 33 | parser.add_argument('--lr', type=float, default=2e-4, help='initial learning rate for adam') 34 | parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla | lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.') 35 | parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images') 36 | parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy.
[linear | step | plateau | cosine]') 37 | parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') 38 | 39 | self.isTrain = True 40 | return parser 41 | -------------------------------------------------------------------------------- /test_DF.py: -------------------------------------------------------------------------------- 1 | import time 2 | from options.test_options import TestOptions 3 | from data import create_dataset 4 | from models import create_model 5 | import os 6 | import math 7 | import cv2 8 | from PIL import Image 9 | from util import util 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | from tqdm import tqdm 13 | import pylab as pl 14 | import torch 15 | import random 16 | from torchvision import transforms 17 | from torchcam.cams import SmoothGradCAMpp 18 | from torchvision.transforms.functional import to_pil_image 19 | from torchcam.utils import overlay_mask 20 | 21 | def auc(real, fake): 22 | label_all = [] 23 | target_all = [] 24 | for ind in real: 25 | target_all.append(1) 26 | label_all.append(-ind) 27 | for ind in fake: 28 | target_all.append(0) 29 | label_all.append(-ind) 30 | 31 | from sklearn.metrics import roc_auc_score 32 | return roc_auc_score(target_all, label_all) 33 | 34 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 35 | 36 | if __name__ == '__main__': 37 | opt = TestOptions().parse() # get test options 38 | # hard-code some parameters for test 39 | opt.num_threads = 4 # number of threads for loading test data 40 | opt.batch_size = 1 # test code only supports batch_size = 1 41 | opt.serial_batches = False # keep data shuffling; set to True if results on images taken in a fixed order are needed. 42 | opt.no_flip = True # no flip; comment this line if results on flipped images are needed. 43 | opt.display_id = -1 # no visdom display; the test code saves the results to an HTML file. 44 | opt.mode = 'test' 45 | dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options 46 | model = create_model(opt) # create a model given opt.model and other options 47 | model.setup(opt) # regular setup: load and print networks; create schedulers 48 | 49 | if opt.eval: 50 | model.eval() 51 | 52 | dataset_size = len(dataset) 53 | print('The number of test images = %d' % dataset_size) 54 | 55 | 56 | total_iters = 0 57 | label = None 58 | real = [] 59 | fake = [] 60 | 61 | with tqdm(total=dataset_size) as pbar: 62 | for i, data in enumerate(dataset): 63 | input_data = {'img_real': data['img_real'], 64 | 'img_fake': data['img_fake'], 65 | 'aud_real': data['aud_real'], 66 | 'aud_fake': data['aud_fake'], 67 | } 68 | model.set_input(input_data) 69 | 70 | dist_AV, dist_VA = model.val() 71 | real.append(dist_AV.item()) 72 | for i in dist_VA: 73 | fake.append(i.item()) 74 | total_iters += 1 75 | pbar.update() 76 | 77 | print(auc(real, fake)) -------------------------------------------------------------------------------- /train_DF.py: -------------------------------------------------------------------------------- 1 | """ 2 | Anonymous release of VFD 3 | Part of the framework is borrowed from 4 | https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix 5 | Many thanks to these authors!
6 | """ 7 | import time 8 | from options.train_options import TrainOptions 9 | from data import create_dataset 10 | from models import create_model 11 | import os 12 | import math 13 | import cv2 14 | from PIL import Image 15 | from util import util 16 | from tqdm import tqdm 17 | import matplotlib.pyplot as plt 18 | import numpy as np 19 | import pylab as pl 20 | import torch 21 | import random 22 | 23 | 24 | def setup_seed(seed): 25 | torch.manual_seed(seed) 26 | torch.cuda.manual_seed_all(seed) 27 | np.random.seed(seed) 28 | random.seed(seed) 29 | torch.backends.cudnn.deterministic = True 30 | setup_seed(20) 31 | 32 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 33 | 34 | def auc(real, fake): 35 | label_all = [] 36 | target_all = [] 37 | for ind in real: 38 | target_all.append(1) 39 | label_all.append(ind) 40 | 41 | for ind in fake: 42 | target_all.append(0) 43 | label_all.append(ind) 44 | 45 | from sklearn.metrics import roc_auc_score 46 | return roc_auc_score(target_all, label_all) 47 | 48 | if __name__ == '__main__': 49 | opt = TrainOptions().parse() 50 | dataset = create_dataset(opt) 51 | dataset_size = len(dataset) 52 | 53 | opt.mode = 'val' 54 | opt.serial_batches = False 55 | dataset_val = create_dataset(opt) 56 | dataset_val_size = len(dataset_val) 57 | print('The number of training images dir = %d' % dataset_size) 58 | print('The number of val images dir = %d' % dataset_val_size) 59 | 60 | model = create_model(opt) # create a model given opt.model and other options 61 | model.setup(opt) # regular setup: load and print networks; create schedulers 62 | total_iters = 0 # the total number of training iterations 63 | 64 | loss_x = [] 65 | loss_y_g = [] 66 | loss_y_l = [] 67 | loss_y_t = [] 68 | loss_y_f = [] 69 | loss_epo = 0 70 | for epoch in range(opt.epoch_count, 71 | opt.n_epochs + opt.n_epochs_decay + 1): 72 | epoch_start_time = time.time() 73 | iter_data_time = time.time() 74 | time_start = epoch_start_time 75 | epoch_iter = 0 76 | iter_start_time = time.time() 77 | if total_iters % opt.print_freq == 0: 78 | t_data = iter_start_time - iter_data_time 79 | loss_G_AV_all = 0 80 | for i, data in enumerate(dataset): 81 | input_data = {'img_real': data['img_real'], 82 | 'img_fake': data['img_fake'], 83 | 'aud_real': data['aud_real'], 84 | 'aud_fake': data['aud_fake'], 85 | } 86 | model.set_input(input_data) 87 | loss_G_AV = model.optimize_parameters() 88 | 89 | loss_G_AV_all += loss_G_AV 90 | loss_epo += loss_G_AV 91 | total_iters += 1 92 | 93 | if total_iters % 10 == 0: 94 | print('epoch %d, total_iters %d: loss_G_AV: %.3f(%.3f), time cost: %.2f s' % 95 | (epoch, total_iters, loss_G_AV, loss_G_AV_all / total_iters, time.time() - time_start)) 96 | time_start = time.time() 97 | 98 | if total_iters % opt.save_latest_freq == 0: 99 | print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters)) 100 | save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest' 101 | model.save_networks(save_suffix) 102 | model.eval() 103 | real = [] 104 | fake = [] 105 | with tqdm(total=len(dataset_val)) as pbar: 106 | for i, data in enumerate(dataset_val): 107 | input_data = {'img_real': data['img_real'], 108 | 'img_fake': data['img_fake'], 109 | 'aud_real': data['aud_real'], 110 | 'aud_fake': data['aud_fake'], 111 | } 112 | model.set_input(input_data) 113 | dist_AV, dist_VA = model.val() 114 | real.append(dist_AV.item()) 115 | for i in dist_VA: 116 | fake.append(i.item()) 117 | pbar.update() 118 | _auc = auc(real, fake) 119 | print('Val auc (for 
reference) %.4f'%(_auc)) 120 | model.train() 121 | 122 | iter_data_time = time.time() 123 | if epoch % opt.save_epoch_freq == 0: 124 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters)) 125 | model.save_networks('latest') 126 | model.save_networks(epoch) 127 | model.eval() 128 | real = [] 129 | fake = [] 130 | for i, data in enumerate(dataset_val): 131 | input_data = {'img_real': data['img_real'], 132 | 'img_fake': data['img_fake'], 133 | 'aud_real': data['aud_real'], 134 | 'aud_fake': data['aud_fake'], 135 | } 136 | model.set_input(input_data) 137 | dist_AV, dist_VA = model.val() 138 | real.append(dist_AV.item()) 139 | for i in dist_VA: 140 | fake.append(i.item()) 141 | _auc = auc(real, fake) 142 | model.train() 143 | loss_epo = 0 144 | 145 | print('End of epoch %d / %d \t Time Taken: %d sec' % ( 146 | epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time)) 147 | model.update_learning_rate() -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | """This module contains simple helper functions """ 2 | from __future__ import print_function 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | import os 7 | 8 | 9 | def tensor2im(input_image, imtype=np.uint8): 10 | """Converts a Tensor array into a numpy image array. 11 | 12 | Parameters: 13 | input_image (tensor) -- the input image tensor array 14 | imtype (type) -- the desired type of the converted numpy array 15 | """ 16 | if not isinstance(input_image, np.ndarray): 17 | if isinstance(input_image, torch.Tensor): # get the data from a variable 18 | image_tensor = input_image.data 19 | else: 20 | return input_image 21 | image_numpy = image_tensor.cpu().float().numpy() # convert it into a numpy array 22 | # print_numpy(image_numpy.shape) 23 | if image_numpy.shape[0] == 1: # grayscale to RGB 24 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 25 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: transpose and scaling 26 | else: # if it is a numpy array, do nothing 27 | image_numpy = input_image 28 | return image_numpy.astype(imtype) 29 | 30 | 31 | def diagnose_network(net, name='network'): 32 | """Calculate and print the mean of average absolute(gradients) 33 | 34 | Parameters: 35 | net (torch network) -- Torch network 36 | name (str) -- the name of the network 37 | """ 38 | mean = 0.0 39 | count = 0 40 | for param in net.parameters(): 41 | if param.grad is not None: 42 | mean += torch.mean(torch.abs(param.grad.data)) 43 | count += 1 44 | if count > 0: 45 | mean = mean / count 46 | print(name) 47 | print(mean) 48 | 49 | 50 | def save_image(image_numpy, image_path, aspect_ratio=1.0): 51 | """Save a numpy image to the disk 52 | 53 | Parameters: 54 | image_numpy (numpy array) -- input numpy array 55 | image_path (str) -- the path of the image 56 | """ 57 | 58 | image_pil = Image.fromarray(image_numpy) 59 | h, w, _ = image_numpy.shape 60 | 61 | if aspect_ratio > 1.0: 62 | image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC) 63 | if aspect_ratio < 1.0: 64 | image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC) 65 | image_pil.save(image_path) 66 | 67 | 68 | def print_numpy(x, val=True, shp=False): 69 | """Print the mean, min, max, median, std, and size of a numpy array 70 | 71 | Parameters: 72 | val (bool) -- if print the values of the numpy array 73 | shp (bool) -- if print the shape of
the numpy array 74 | """ 75 | x = x.astype(np.float64) 76 | if shp: 77 | print('shape,', x.shape) 78 | if val: 79 | x = x.flatten() 80 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 81 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 82 | 83 | 84 | def mkdirs(paths): 85 | """create empty directories if they don't exist 86 | 87 | Parameters: 88 | paths (str list) -- a list of directory paths 89 | """ 90 | if isinstance(paths, list) and not isinstance(paths, str): 91 | for path in paths: 92 | mkdir(path) 93 | else: 94 | mkdir(paths) 95 | 96 | 97 | def mkdir(path): 98 | """create a single empty directory if it didn't exist 99 | 100 | Parameters: 101 | path (str) -- a single directory path 102 | """ 103 | if not os.path.exists(path): 104 | os.makedirs(path) 105 | --------------------------------------------------------------------------------