├── .gitignore ├── Readme.md ├── loaders ├── __init__.py ├── v1_sup_id_loader.py ├── v2_unsup_loader.py ├── v3_triplet_loader.py ├── v4_de_hq3_loader.py ├── v5_voxceleb_cluster_ordered_loader.py ├── v6_voxceleb_loader_for_deepcluster.py └── v7_sup_id_contrastive_loader.py ├── models ├── __init__.py ├── cae_model.py ├── de_hq3_model.py ├── fop_model.py └── my_model.py ├── preprocess ├── Readme.md ├── __init__.py ├── face_extractor │ ├── 1_inception_v1.py │ ├── 2_deepface.py │ ├── 3_facexzoo.py │ ├── __init__.py │ ├── loaders │ │ ├── __init__.py │ │ ├── img_loader.py │ │ └── loader4deepface.py │ └── models │ │ ├── __init__.py │ │ └── incep.py ├── preprocess │ ├── 1_mp4_extract_wav.py │ ├── 2_wav_vad.py │ ├── 3_mp4_extract_frames.py │ ├── 4_face_crop_mtcnn.py │ ├── 5_pose_estimation.py │ └── __init__.py └── voice_extractor │ ├── 1_ecapa_tdnn.py │ ├── 2_resemblizer.py │ ├── __init__.py │ └── loaders │ ├── __init__.py │ └── voice_loader.py ├── scripts ├── 1_verification.py ├── 2_matching.py ├── 3_retrieval.py └── __init__.py ├── utils ├── __init__.py ├── angles_utils.py ├── barlow_loss.py ├── config.py ├── deep_coral_loss.py ├── deepcluster_util.py ├── distance_util.py ├── dlib_util.py ├── eva_emb_full.py ├── eval_shortcut.py ├── faceBlendCommon.py ├── keops_kmeans.py ├── losses │ ├── __init__.py │ ├── barlow_loss.py │ ├── center_loss_eccv16.py │ ├── center_loss_learnableW_L2dist.py │ ├── cmpc_loss.py │ ├── fop_loss.py │ ├── my_pml_infonce_v2.py │ ├── softmax_loss.py │ ├── triplet_hq1.py │ ├── triplet_lafv.py │ ├── unsup_nce.py │ ├── wen_explicit_loss.py │ └── wen_reweight.py ├── map_evaluate.py ├── model_selector.py ├── model_util.py ├── my_git.py ├── my_parser.py ├── my_softmax_loss.py ├── pair_selection_util.py ├── path_util.py ├── pickle_util.py ├── sample_util.py ├── seed_util.py ├── vec_util.py ├── wb_util.py └── worker_util.py ├── works ├── 10_DE_HQ3.py ├── 11_SS_DIM_VFMR_Barlow.py ├── 1_pins.py ├── 2_FV-CME.py ├── 3_LAFV.py ├── 5_Wen.py ├── 6_FOP.py ├── 7_CMPC.py ├── 8_CAE.py └── 9_SL.py └── works_loss_cmp ├── 0_loss_compare.py ├── 1_contrastive_loss.py └── 2_triplet_loss.py /.gitignore: -------------------------------------------------------------------------------- 1 | 4_HQ1_pml.py 2 | 4_HQ1.py 3 | dataset/* 4 | outputs/* 5 | jupyters/ 6 | results/ 7 | configs/wb*.json 8 | 说明.txt 9 | 说明.sh 10 | 停泊/ 11 | tmp/ 12 | data/ 13 | outputs 14 | tmp.py 15 | pp.sh 16 | .DS_Store 17 | result.json 18 | #idea 19 | .idea 20 | wandb/ 21 | z_命令.txt 22 | .wb_config.json 23 | # Byte-compiled / optimized / DLL files 24 | __pycache__/ 25 | *.py[cod] 26 | *$py.class 27 | 28 | # C extensions 29 | *.so 30 | 31 | # Distribution / packaging 32 | .Python 33 | build/ 34 | develop-eggs/ 35 | dist/ 36 | downloads/ 37 | eggs/ 38 | .eggs/ 39 | lib/ 40 | lib64/ 41 | parts/ 42 | sdist/ 43 | var/ 44 | wheels/ 45 | share/python-wheels/ 46 | *.egg-info/ 47 | .installed.cfg 48 | *.egg 49 | MANIFEST 50 | 51 | # PyInstaller 52 | # Usually these files are written by a python script from a template 53 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
54 | *.manifest 55 | *.spec 56 | 57 | # Installer logs 58 | pip-log.txt 59 | pip-delete-this-directory.txt 60 | 61 | # Unit test / coverage reports 62 | htmlcov/ 63 | .tox/ 64 | .nox/ 65 | .coverage 66 | .coverage.* 67 | .cache 68 | nosetests.xml 69 | coverage.xml 70 | *.cover 71 | *.py,cover 72 | .hypothesis/ 73 | .pytest_cache/ 74 | cover/ 75 | 76 | # Translations 77 | *.mo 78 | *.pot 79 | 80 | # Django stuff: 81 | *.log 82 | local_settings.py 83 | db.sqlite3 84 | db.sqlite3-journal 85 | 86 | # Flask stuff: 87 | instance/ 88 | .webassets-cache 89 | 90 | # Scrapy stuff: 91 | .scrapy 92 | 93 | # Sphinx documentation 94 | docs/_build/ 95 | 96 | # PyBuilder 97 | .pybuilder/ 98 | target/ 99 | 100 | # Jupyter Notebook 101 | .ipynb_checkpoints 102 | 103 | # IPython 104 | profile_default/ 105 | ipython_config.py 106 | 107 | # pyenv 108 | # For a library or package, you might want to ignore these files since the code is 109 | # intended to run in multiple environments; otherwise, check them in: 110 | # .python-version 111 | 112 | # pipenv 113 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 114 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 115 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 116 | # install all needed dependencies. 117 | #Pipfile.lock 118 | 119 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 120 | __pypackages__/ 121 | 122 | # Celery stuff 123 | celerybeat-schedule 124 | celerybeat.pid 125 | 126 | # SageMath parsed files 127 | *.sage.py 128 | 129 | # Environments 130 | .env 131 | .venv 132 | env/ 133 | venv/ 134 | ENV/ 135 | env.bak/ 136 | venv.bak/ 137 | 138 | # Spyder project settings 139 | .spyderproject 140 | .spyproject 141 | 142 | # Rope project settings 143 | .ropeproject 144 | 145 | # mkdocs documentation 146 | /site 147 | 148 | # mypy 149 | .mypy_cache/ 150 | .dmypy.json 151 | dmypy.json 152 | 153 | # Pyre type checker 154 | .pyre/ 155 | 156 | # pytype static type analyzer 157 | .pytype/ 158 | 159 | # Cython debug symbols 160 | cython_debug/ -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | # Voice-Face Association Learning Evaluation 2 | 3 | - Reproduce various works based on unified standards 😃 4 | - High-speed training and testing ⚡ 5 | - Easy to extend 💭 6 | 7 | ## Installation 8 | 9 | 1. Clone or download this repository. 10 | 11 | 2. Install the required packages: 12 | 13 | ``` 14 | pytorch>=1.8.1 15 | wandb>=0.12.10 16 | ``` 17 | 18 | 3. Download the dataset: 19 | 20 | The dataset is based on VoxCeleb and is divided into train/valid/test sets according to "Learnable Pins: Crossmodal Embeddings for Person Identity, 2018, ECCV" (901/100/250). 21 | 22 | Download `dataset.zip` from [Google Drive](https://drive.google.com/file/d/1sVQ7I4_9rwWF18vk4VZFVAx-8Inv-wlT/view?usp=sharing) (2.3GB) and unzip it to the project root directory. 
The folder structure should be as follows:
23 | 
24 | ```
25 | dataset
26 | ├── evals
27 | │   ├── test_matching_10.pkl
28 | │   ├── test_matching_g.pkl
29 | │   ├── test_matching.pkl
30 | │   ├── test_retrieval.pkl
31 | │   ├── test_verification_g.pkl
32 | │   ├── test_verification.pkl
33 | │   └── valid_verification.pkl
34 | ├── info
35 | │   ├── name2gender.pkl
36 | │   ├── name2jpgs_wavs.pkl
37 | │   ├── name2movies.pkl
38 | │   ├── name2voice_id.pkl
39 | │   ├── train_valid_test_names.pkl
40 | │   └── works
41 | │       └── wen_weights.txt
42 | ├── face_input.pkl
43 | └── voice_input.pkl
44 | ```
45 | 
46 | 
47 | 
48 | ## Run a Reproduction
49 | 
50 | - Learnable Pins: Crossmodal Embeddings for Person Identity, 2018, ECCV
51 | 
52 | ```
53 | python works/1_pins.py
54 | ```
55 | 
56 | - Face-Voice Matching using Cross-modal Embeddings, MM, 2018
57 | 
58 | ```
59 | python works/2_FV-CME.py
60 | ```
61 | 
62 | - On Learning Associations of Faces and Voices, ACCV, 2018
63 | 
64 | ```
65 | python works/3_LAFV.py
66 | ```
67 | 
68 | - Disjoint Mapping Network for Cross-modal Matching of Voices and Faces, ICLR, 2019
69 | 
70 | ```
71 | python works/11_SS_DIM_VFMR_Barlow.py --name=DIMNet
72 | ```
73 | 
74 | - Voice-Face Cross-modal Matching and Retrieval - A Benchmark, 2019
75 | 
76 | ```
77 | python works/11_SS_DIM_VFMR_Barlow.py --name=VFMR
78 | ```
79 | 
80 | - Seeking the Shape of Sound: An Adaptive Framework for Learning Voice-Face Association, CVPR, 2021
81 | 
82 | ```
83 | python works/5_Wen.py
84 | ```
85 | 
86 | - Fusion and Orthogonal Projection for Improved Face-Voice Association, ICASSP, 2022
87 | 
88 | ```
89 | python works/6_FOP.py
90 | ```
91 | 
92 | - Unsupervised Voice-Face Representation Learning by Cross-Modal Prototype Contrast, IJCAI, 2022
93 | 
94 | ```
95 | python works/7_CMPC.py
96 | ```
97 | 
98 | - Self-Lifting: A Novel Framework for Unsupervised Voice-Face Association Learning, ICMR, 2022
99 | 
100 | ```
101 | python works/9_SL.py
102 | ```
103 | 
104 | for Self-Lifting
105 | 
106 | ```
107 | python works/8_CAE.py
108 | ```
109 | 
110 | for the CCAE baseline
111 | 
112 | ```
113 | python works/11_SS_DIM_VFMR_Barlow.py --name=SL-Barlow
114 | ```
115 | 
116 | for the Barlow Twins baseline
117 | 
118 | ## Integration with Wandb
119 | 
120 | *Use [wandb](https://wandb.ai) to view the training process:*
121 | 
122 | 1. Create a `.wb_config.json` file in the project root with the following content:
123 | 
124 | ```
125 | {
126 |   "WB_KEY": "Your wandb auth key"
127 | }
128 | ```
129 | 
130 | 2. Add `--dryrun=False` to the training command, for example:
131 | 
132 | ```
133 | python works/1_pins.py --dryrun=False
134 | ```
135 | 
136 | ##
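## Extending

The works above all combine a data loader from `loaders/`, an encoder from `models/`, and a loss from `utils/losses/`. As a rough illustration of how a new method could be wired in, here is a hypothetical sketch (this file does not exist in the repository; it assumes the dataset from the Installation section is unpacked, and the batch size, dataset length, and toy cosine objective are placeholders):

```
# hypothetical_new_work.py -- illustrative sketch only, not part of this repo
import torch
from loaders import v1_sup_id_loader        # yields (voice_emb, face_emb, person_id, gender)
from models.my_model import Encoder         # 192-d voice / 512-d face -> 128-d embeddings

model = Encoder().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train_iter = v1_sup_id_loader.get_iter(batch_size=256, full_length=256 * 500)

for voice, face, person_id, gender in train_iter:
    v_emb, f_emb = model(voice.cuda(), face.cuda())
    # toy objective: pull matched voice/face embeddings together
    loss = 1.0 - torch.nn.functional.cosine_similarity(v_emb, f_emb).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```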
--------------------------------------------------------------------------------
/loaders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/my-yy/vfal-eva/c1ca050d22821bf60fcdca096429edb193df2ae6/loaders/__init__.py
--------------------------------------------------------------------------------
/loaders/v1_sup_id_loader.py:
--------------------------------------------------------------------------------
1 | import ipdb
2 | import torch
3 | from utils import pickle_util, sample_util, worker_util
4 | from torch.utils.data import DataLoader
5 | from utils.config import face_emb_dict, voice_emb_dict
6 | 
7 | 
8 | # When sampling by person name, the names are naturally non-repeating; this sampler cannot be used for unsupervised loss functions.
9 | 
10 | def get_iter(batch_size, full_length):
11 |     train_iter = DataLoader(DataSet(full_length), batch_size=batch_size, shuffle=True, pin_memory=True, worker_init_fn=worker_util.worker_init_fn)
12 |     return train_iter
13 | 
14 | 
15 | class DataSet(torch.utils.data.Dataset):
16 | 
17 |     def __init__(self, dataset_length):
18 |         train_names = pickle_util.read_pickle("./dataset/info/train_valid_test_names.pkl")["train"]
19 |         train_names.sort()
20 |         name2id = {train_names[i]: i for i in range(len(train_names))}
21 |         self.train_names = train_names
22 |         self.name2id = name2id
23 | 
24 |         self.name2movies = pickle_util.read_pickle("./dataset/info/name2movies.pkl")
25 |         self.name2gender = pickle_util.read_pickle("./dataset/info/name2gender.pkl")
26 |         self.dataset_length = dataset_length
27 | 
28 |     def __len__(self):
29 |         return self.dataset_length
30 | 
31 |     def __getitem__(self, index):
32 |         # This selection scheme makes the sampled person names naturally non-repeating.
33 |         name = self.train_names[index % len(self.train_names)]
34 | 
35 |         movie_obj = sample_util.random_element(self.name2movies[name])
36 | 
37 |         jpg_path = sample_util.random_element(movie_obj["jpgs"])
38 |         wav_path = sample_util.random_element(movie_obj["wavs"])
39 | 
40 |         voice_tensor = voice_emb_dict[wav_path]
41 |         face_tensor = face_emb_dict[jpg_path]
42 | 
43 |         the_id = torch.as_tensor(self.name2id[name]).long()
44 |         the_gender = torch.as_tensor(self.name2gender[name]).long()
45 |         return voice_tensor, face_tensor, the_id, the_gender
46 | 
--------------------------------------------------------------------------------
/loaders/v2_unsup_loader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from utils import pickle_util, sample_util, worker_util, vec_util
3 | from torch.utils.data import DataLoader
4 | from utils.config import face_emb_dict, voice_emb_dict
5 | import numpy as np
6 | 
7 | 
8 | # Unsupervised sampler: first pick a movie, then pick a clip from it.
9 | 
10 | def get_iter(batch_size, dataset_length):
11 |     train_iter = DataLoader(DataSet(dataset_length), batch_size=batch_size, shuffle=True, pin_memory=True, worker_init_fn=worker_util.worker_init_fn)
12 |     return train_iter
13 | 
14 | 
15 | class DataSet(torch.utils.data.Dataset):
16 | 
17 |     def __init__(self, dataset_length):
18 |         train_names = pickle_util.read_pickle("./dataset/info/train_valid_test_names.pkl")["train"]
19 |         name2movies = pickle_util.read_pickle("./dataset/info/name2movies.pkl")
20 |         train_movies = []
21 |         for name in train_names:
22 |             train_movies += name2movies[name]
23 |         self.train_movies = train_movies
24 |         self.dataset_length = dataset_length
25 | 
26 |     def __len__(self):
27 |         return self.dataset_length
28 | 
29 |     def __getitem__(self, index):
30 |         # find
a movie 31 | movie_id = index % len(self.train_movies) 32 | movie_obj = self.train_movies[movie_id] 33 | 34 | # sample an image and a voice clip 35 | jpg_path = sample_util.random_element(movie_obj["jpgs"]) 36 | wav_path = sample_util.random_element(movie_obj["wavs"]) 37 | 38 | voice_tensor = voice_emb_dict[wav_path] 39 | face_tensor = face_emb_dict[jpg_path] 40 | return voice_tensor, face_tensor, torch.tensor(movie_id).long() 41 | -------------------------------------------------------------------------------- /loaders/v3_triplet_loader.py: -------------------------------------------------------------------------------- 1 | import ipdb 2 | import torch 3 | from utils import pickle_util, sample_util, worker_util 4 | from torch.utils.data import DataLoader 5 | from utils.config import face_emb_dict, voice_emb_dict 6 | 7 | 8 | # 用于LAFV,最后适用于triplet loss 9 | def get_iter(batch_size, full_length): 10 | train_iter = DataLoader(DataSet(full_length), batch_size=batch_size, shuffle=True, pin_memory=True, worker_init_fn=worker_util.worker_init_fn) 11 | return train_iter 12 | 13 | 14 | class DataSet(torch.utils.data.Dataset): 15 | 16 | def __init__(self, dataset_length): 17 | train_names = pickle_util.read_pickle("./dataset/info/train_valid_test_names.pkl")["train"] 18 | train_names.sort() 19 | name2id = {train_names[i]: i for i in range(len(train_names))} 20 | self.train_names = train_names 21 | self.name2id = name2id 22 | 23 | self.name2movies = pickle_util.read_pickle("./dataset/info/name2movies.pkl") 24 | self.name2gender = pickle_util.read_pickle("./dataset/info/name2gender.pkl") 25 | self.dataset_length = dataset_length 26 | 27 | def __len__(self): 28 | return self.dataset_length 29 | 30 | def __getitem__(self, index): 31 | name1, name2 = sample_util.random_elements(self.train_names, 2) 32 | v1, f1 = self.load_one_person(name1) 33 | v2, f2 = self.load_one_person(name2) 34 | return v1, f1, v2, f2 35 | 36 | def load_one_person(self, name1): 37 | movie_obj = sample_util.random_element(self.name2movies[name1]) 38 | wav_name = sample_util.random_element(movie_obj["wavs"]) 39 | jpg_name = sample_util.random_element(movie_obj["jpgs"]) 40 | voice_tensor = torch.FloatTensor(voice_emb_dict[wav_name]) 41 | face_tensor = torch.FloatTensor(face_emb_dict[jpg_name]) 42 | return voice_tensor, face_tensor 43 | -------------------------------------------------------------------------------- /loaders/v4_de_hq3_loader.py: -------------------------------------------------------------------------------- 1 | import ipdb 2 | import torch 3 | from utils import pickle_util, sample_util, worker_util 4 | from torch.utils.data import DataLoader 5 | import numpy as np 6 | from utils.config import face_emb_dict, voice_emb_dict 7 | from utils import vec_util 8 | 9 | 10 | def get_iter(batch_size, full_length): 11 | train_iter = DataLoader(DataSet(full_length), batch_size=batch_size, shuffle=True, pin_memory=True, worker_init_fn=worker_util.worker_init_fn) 12 | return train_iter 13 | 14 | 15 | class DataSet(torch.utils.data.Dataset): 16 | 17 | def __init__(self, dataset_length): 18 | train_names = pickle_util.read_pickle("./dataset/info/train_valid_test_names.pkl")["train"] 19 | train_names.sort() 20 | name2id = {train_names[i]: i for i in range(len(train_names))} 21 | self.train_names = train_names 22 | self.name2id = name2id 23 | 24 | self.name2movies = pickle_util.read_pickle("./dataset/info/name2movies.pkl") 25 | self.name2gender = pickle_util.read_pickle("./dataset/info/name2gender.pkl") 26 | self.dataset_length = 
dataset_length 27 | 28 | def __len__(self): 29 | return self.dataset_length 30 | 31 | def __getitem__(self, index): 32 | name = self.train_names[index % len(self.train_names)] 33 | 34 | movie_obj = sample_util.random_element(self.name2movies[name]) 35 | 36 | jpg_path = sample_util.random_element(movie_obj["jpgs"]) 37 | wav_path = sample_util.random_element(movie_obj["wavs"]) 38 | 39 | voice_tensor = voice_emb_dict[wav_path] 40 | face_tensor = face_emb_dict[jpg_path] 41 | 42 | the_id = torch.as_tensor(self.name2id[name]).long() 43 | return voice_tensor, face_tensor, the_id 44 | 45 | 46 | # def load_voice_data(short_path): 47 | # assert short_path.endswith(".wav") 48 | # local_path = "./dataset/features/voice/" + short_path + ".npy" 49 | # vec = np.load(local_path, allow_pickle=True) 50 | # # to 512dim 51 | # # vec = np.concatenate([vec, vec, vec])[0:512] 52 | # # vec_util.to_unit_vector(vec) 53 | # return torch.FloatTensor(vec) 54 | -------------------------------------------------------------------------------- /loaders/v5_voxceleb_cluster_ordered_loader.py: -------------------------------------------------------------------------------- 1 | import ipdb 2 | import numpy as np 3 | import torch 4 | from utils import pickle_util, worker_util 5 | from utils.path_util import look_up 6 | 7 | from torch.utils.data import DataLoader 8 | import collections 9 | 10 | 11 | def extract_embeddings(name2face_emb, name2voice_emb, model): 12 | face_iter = get_ordered_iter(512, name2face_emb, name2voice_emb, is_face=True) 13 | movies, emb_face = extract_embeddings_core(face_iter, model.face_encoder) 14 | 15 | voice_iter = get_ordered_iter(512, name2face_emb, name2voice_emb, is_face=False) 16 | movies2, emb_voice = extract_embeddings_core(voice_iter, model.voice_encoder) 17 | 18 | assert len(movies2) == len(movies) 19 | final_emb = np.hstack([emb_voice, emb_face]) 20 | return movies, final_emb, emb_voice, emb_face 21 | 22 | 23 | def extract_embeddings_core(ordered_iter, encoder): 24 | # 1.extract embedding 25 | encoder.eval() 26 | the_dict = collections.defaultdict(list) 27 | for data in ordered_iter: 28 | with torch.no_grad(): 29 | batch_movie, tensor = data 30 | # ipdb.set_trace() 31 | batch_emb = encoder(tensor.cuda()).detach().cpu().numpy() 32 | for emb, movie in zip(batch_emb, batch_movie): 33 | the_dict[movie].append(emb) 34 | encoder.train() 35 | 36 | # 2. 
merge embedding by video 37 | final_dict = {} 38 | for key, arr in the_dict.items(): 39 | # arr:[batch,emb] 40 | final_dict[key] = np.mean(arr, axis=0) 41 | 42 | # 3.sort 43 | videos = list(final_dict.keys()) 44 | videos.sort() 45 | emb_array = np.array([final_dict[key] for key in videos]) 46 | 47 | return videos, emb_array 48 | 49 | 50 | def get_ordered_iter(batch_size, name2face_emb, name2voice_emb, is_face): 51 | train_iter = DataLoader(OrderedDataSet(is_face, name2face_emb, name2voice_emb), 52 | batch_size=batch_size, shuffle=False, 53 | pin_memory=True, worker_init_fn=worker_util.worker_init_fn) 54 | return train_iter 55 | 56 | 57 | class OrderedDataSet(torch.utils.data.Dataset): 58 | 59 | def __init__(self, is_face, name2face_emb, name2voice_emb): 60 | name2movies = pickle_util.read_pickle(look_up("./dataset/info/name2movies.pkl")) 61 | train_names = pickle_util.read_pickle(look_up("./dataset/info/train_valid_test_names.pkl"))["train"] 62 | 63 | # 3.数据 64 | all_jpgs = [] 65 | all_wavs = [] 66 | 67 | for name in train_names: 68 | movies = name2movies[name] 69 | for movie in movies: 70 | # movie.keys: ['jpgs', 'wavs', 'movie', 'person'] 71 | wavs = movie["wavs"] 72 | jpgs = movie["jpgs"] 73 | # wav: 'A.J._Buckley/J9lHsKG98U8/00025.wav' 74 | # jpg: 'A.J._Buckley/J9lHsKG98U8/0010225.jpg' 75 | for short_path in jpgs: 76 | movie_name = "/".join(short_path.split("/")[0:2]) 77 | # A.J._Buckley/J9lHsKG98U8 78 | all_jpgs.append([movie_name, short_path]) 79 | for short_path in wavs: 80 | movie_name = "/".join(short_path.split("/")[0:2]) 81 | all_wavs.append([movie_name, short_path]) 82 | 83 | if is_face: 84 | self.data = all_jpgs 85 | self.name2emb = name2face_emb 86 | else: 87 | self.data = all_wavs 88 | self.name2emb = name2voice_emb 89 | 90 | self.is_face = is_face 91 | 92 | def __len__(self): 93 | return len(self.data) 94 | 95 | def __getitem__(self, index): 96 | movie, short_path = self.data[index] 97 | tensor = torch.FloatTensor(self.name2emb[short_path]) 98 | if self.is_face: 99 | assert len(tensor) == 512 100 | # ipdb.set_trace() 101 | return movie, tensor 102 | -------------------------------------------------------------------------------- /loaders/v6_voxceleb_loader_for_deepcluster.py: -------------------------------------------------------------------------------- 1 | import ipdb 2 | import numpy as np 3 | import torch 4 | from utils import pickle_util, sample_util, worker_util, vec_util 5 | from torch.utils.data import DataLoader 6 | 7 | 8 | def get_iter(batch_size, full_length, name2face_emb, name2voice_emb, movie2label): 9 | train_iter = DataLoader(DataSet(name2face_emb, name2voice_emb, full_length, movie2label), 10 | batch_size=batch_size, shuffle=False, pin_memory=True, worker_init_fn=worker_util.worker_init_fn) 11 | return train_iter 12 | 13 | 14 | class DataSet(torch.utils.data.Dataset): 15 | 16 | def __init__(self, name2face_emb, name2voice_emb, full_length, movie2label): 17 | self.train_movie_list = list(movie2label.keys()) 18 | self.full_length = full_length 19 | self.name2face_emb = name2face_emb 20 | self.name2voice_emb = name2voice_emb 21 | self.movie2label = movie2label 22 | 23 | # create movie2jpg, movie2wav dict 24 | self.movie2jpg_path = {} 25 | self.movie2wav_path = {} 26 | name2movies = pickle_util.read_pickle("./dataset/info/name2movies.pkl") 27 | for name, movie_list in name2movies.items(): 28 | for movie_obj in movie_list: 29 | movie_name = movie_obj['person'] + "/" + movie_obj["movie"] 30 | # A.J._Buckley/J9lHsKG98U8 31 | self.movie2wav_path[movie_name] = 
movie_obj["wavs"] 32 | self.movie2jpg_path[movie_name] = movie_obj["jpgs"] 33 | 34 | def __len__(self): 35 | return self.full_length 36 | 37 | def __getitem__(self, index): 38 | movie = sample_util.random_element(self.train_movie_list) 39 | label = self.movie2label[movie] 40 | 41 | img = sample_util.random_element(self.movie2jpg_path[movie]) 42 | wav = sample_util.random_element(self.movie2wav_path[movie]) 43 | wav, img = self.to_tensor([wav, img]) 44 | 45 | return wav, img, torch.LongTensor([label]) 46 | 47 | def to_tensor(self, path_arr): 48 | ans = [] 49 | for path in path_arr: 50 | if ".wav" in path: 51 | emb = self.name2voice_emb[path] 52 | else: 53 | emb = self.name2face_emb[path] 54 | emb = torch.FloatTensor(emb) 55 | ans.append(emb) 56 | return ans 57 | -------------------------------------------------------------------------------- /loaders/v7_sup_id_contrastive_loader.py: -------------------------------------------------------------------------------- 1 | import ipdb 2 | import torch 3 | from utils import pickle_util, sample_util, worker_util 4 | from torch.utils.data import DataLoader 5 | from utils.config import face_emb_dict, voice_emb_dict 6 | 7 | 8 | # 用于配合pytorch的默认contrastive loss 9 | 10 | def get_iter(batch_size, full_length): 11 | train_iter = DataLoader(DataSet(full_length), batch_size=batch_size, shuffle=True, pin_memory=True, worker_init_fn=worker_util.worker_init_fn) 12 | return train_iter 13 | 14 | 15 | class DataSet(torch.utils.data.Dataset): 16 | 17 | def __init__(self, dataset_length): 18 | train_names = pickle_util.read_pickle("./dataset/info/train_valid_test_names.pkl")["train"] 19 | train_names.sort() 20 | name2id = {train_names[i]: i for i in range(len(train_names))} 21 | self.train_names = train_names 22 | self.name2id = name2id 23 | 24 | self.name2movies = pickle_util.read_pickle("./dataset/info/name2movies.pkl") 25 | self.name2gender = pickle_util.read_pickle("./dataset/info/name2gender.pkl") 26 | self.dataset_length = dataset_length 27 | 28 | def __len__(self): 29 | return self.dataset_length 30 | 31 | def __getitem__(self, index): 32 | is_same_person = index % 2 33 | if is_same_person == 1: 34 | name1 = sample_util.random_element(self.train_names) 35 | name2 = name1 36 | else: 37 | name1, name2 = sample_util.random_elements(self.train_names, 2) 38 | 39 | jpg_path = sample_util.random_element(sample_util.random_element(self.name2movies[name1])["jpgs"]) 40 | wav_path = sample_util.random_element(sample_util.random_element(self.name2movies[name2])["wavs"]) 41 | 42 | voice_tensor = voice_emb_dict[wav_path] 43 | face_tensor = face_emb_dict[jpg_path] 44 | 45 | is_same_person = torch.as_tensor(is_same_person).long() 46 | return voice_tensor, face_tensor, is_same_person 47 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/my-yy/vfal-eva/c1ca050d22821bf60fcdca096429edb193df2ae6/models/__init__.py -------------------------------------------------------------------------------- /models/cae_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models.my_model import Encoder 3 | 4 | 5 | class Decoder(torch.nn.Module): 6 | def __init__(self, voice_size=192, face_size=512, embedding_size=128, shared=True): 7 | super(Decoder, self).__init__() 8 | # 128->Drop-fc256-Relu-fc256-Relu-xxx 9 | 10 | # 共享层 11 | mid_dim = 256 12 | 13 | def 
create_rare():
14 |             return torch.nn.Sequential(
15 |                 torch.nn.Dropout(),
16 |                 torch.nn.Linear(embedding_size, mid_dim),
17 |                 torch.nn.ReLU(),
18 |                 torch.nn.Linear(mid_dim, mid_dim),
19 |                 torch.nn.ReLU(),
20 |             )
21 | 
22 |         face_rare = create_rare()
23 |         if shared:
24 |             voice_rare = face_rare
25 |         else:
26 |             voice_rare = create_rare()
27 | 
28 |         self.dec_face = torch.nn.Sequential(
29 |             face_rare,
30 |             torch.nn.Linear(mid_dim, voice_size),
31 |         )
32 |         self.dec_voice = torch.nn.Sequential(
33 |             voice_rare,
34 |             torch.nn.Linear(mid_dim, face_size)
35 |         )
36 | 
37 |     def forward(self, v_emb, f_emb):
38 |         f_out = self.dec_voice(v_emb)
39 |         v_out = self.dec_face(f_emb)
40 |         return v_out, f_out
41 | 
42 | 
43 | class CAE(torch.nn.Module):
44 |     def __init__(self):
45 |         super(CAE, self).__init__()
46 |         self.encoder = Encoder(shared=True)
47 |         self.face_encoder = self.encoder.face_encoder
48 |         self.voice_encoder = self.encoder.voice_encoder
49 | 
50 |         self.decoder = Decoder(shared=False)
51 |         self.fun_loss_mse = torch.nn.MSELoss()
52 | 
53 |     def forward(self, voice_data, face_data, only_emb=False):
54 |         v_emb = self.voice_encoder(voice_data)
55 |         f_emb = self.face_encoder(face_data)
56 |         if only_emb:
57 |             return v_emb, f_emb
58 |         v_out, f_out = self.decoder(v_emb, f_emb)
59 | 
60 |         fun_loss_mse = self.fun_loss_mse
61 |         loss_dec = fun_loss_mse(voice_data, v_out) + fun_loss_mse(face_data, f_out)
62 |         loss_emb = fun_loss_mse(v_emb, f_emb)
63 |         return loss_emb, loss_dec
64 | 
--------------------------------------------------------------------------------
/models/de_hq3_model.py:
--------------------------------------------------------------------------------
1 | import ipdb
2 | import torch
3 | import numpy as np
4 | from torch import nn
5 | import torch
6 | from torch.nn import functional as F
7 | 
8 | 
9 | class MyL2(nn.Module):
10 |     def __init__(self, scale=1.0):
11 |         super(MyL2, self).__init__()
12 |         self.scale = scale
13 | 
14 |     def forward(self, x):
15 |         return torch.nn.functional.normalize(x, dim=1, p=2) * self.scale
16 | 
17 | 
18 | def create_encoder(in_dim=512):
19 |     return torch.nn.Sequential(
20 |         torch.nn.Linear(in_dim, 256),
21 |         torch.nn.ReLU(),
22 |         torch.nn.Linear(256, 128),
23 |         torch.nn.ReLU(),
24 |     )
25 | 
26 | 
27 | def create_decoder(out_dim=512):
28 |     return torch.nn.Sequential(
29 |         MyL2(),
30 |         torch.nn.Linear(128, 256),
31 |         torch.nn.Tanh(),
32 |         torch.nn.Linear(256, out_dim),
33 |         torch.nn.Tanh(),
34 |     )
35 | 
36 | 
37 | class Model(nn.Module):
38 |     def __init__(self, num_user, args):
39 |         super(Model, self).__init__()
40 |         self.args = args
41 |         # encoders
42 |         self.face_encoder_common = create_encoder()
43 |         self.face_encoder_private = create_encoder()
44 |         self.voice_encoder_common = create_encoder(in_dim=192)
45 |         self.voice_encoder_private = create_encoder(in_dim=192)
46 | 
47 |         # decoders
48 |         self.face_decoder = create_decoder()
49 |         self.voice_decoder = create_decoder(out_dim=192)
50 | 
51 |         # losses
52 |         self.id_classifier = torch.nn.Linear(128, num_user)
53 |         self.fun_mse = torch.nn.MSELoss()
54 |         self.fun_cross_entropy = torch.nn.CrossEntropyLoss()
55 | 
56 |     def face_encoder(self, f):
57 |         return self.face_encoder_common(f)
58 | 
59 |     def voice_encoder(self, v):
60 |         return self.voice_encoder_common(v)
61 | 
62 |     def forward(self, v, f, label, step):
63 |         # 1. map the inputs to embeddings
64 |         f_emb_common = self.face_encoder_common(f)
65 |         f_emb_private = self.face_encoder_private(f)
66 |         v_emb_common = self.voice_encoder_common(v)
67 |         v_emb_private = self.voice_encoder_private(v)
68 | 
69 |         # ============= identity classifier
70 |         id_logits_f = self.id_classifier(f_emb_common)
71 |         id_logits_v = self.id_classifier(v_emb_common)
72 |         loss_id = self.fun_cross_entropy(id_logits_f, label) + self.fun_cross_entropy(id_logits_v, label)
73 | 
74 |         # ============= disentanglement
75 |         loss_instance_level = horizontal_cosine_similarity(v_emb_common, v_emb_private) \
76 |                               + horizontal_cosine_similarity(f_emb_common, f_emb_private)
77 |         loss_mutual_level = frobenius_norm(f_emb_common @ f_emb_private.T) \
78 |                             + frobenius_norm(v_emb_common @ v_emb_private.T)
79 | 
80 |         emb_and = f_emb_common * v_emb_common
81 |         emb_or = f_emb_private + v_emb_private
82 |         loss_and_or = frobenius_norm(emb_and @ emb_or.T)
83 | 
84 |         loss_orth = loss_instance_level + loss_mutual_level + loss_and_or
85 | 
86 |         # ============= reconstruction loss
87 |         # each reconstruction must include the modality's own private part
88 |         f_enc1 = self.face_decoder(f_emb_common + f_emb_private)
89 |         f_enc2 = self.face_decoder(v_emb_common + f_emb_private)
90 |         v_enc1 = self.voice_decoder(v_emb_common + v_emb_private)
91 |         v_enc2 = self.voice_decoder(f_emb_common + v_emb_private)
92 |         mse = self.fun_mse
93 |         loss_rec = mse(f, f_enc1) + mse(f, f_enc2) + mse(v, v_enc1) + mse(v, v_enc2)
94 | 
95 |         # final loss aggregation
96 |         loss = loss_id + self.args.ratio_rec * loss_rec + self.args.ratio_orth * loss_orth
97 | 
98 |         info = {
99 |             "loss_ratio_rec": (self.args.ratio_rec * loss_rec).item() / loss.item(),
100 |             "loss_ratio_orth": (self.args.ratio_orth * loss_orth).item() / loss.item(),
101 |             # "f_emb_common": f_emb_common,
102 |             # "f_emb_private": f_emb_private,
103 |             # "v_emb_common": v_emb_common,
104 |             # "v_emb_private": v_emb_private,
105 |         }
106 | 
107 |         return loss, info
108 | 
109 | 
110 | def horizontal_cosine_similarity(emb1, emb2):
111 |     unit_emb1 = F.normalize(emb1)
112 |     unit_emb2 = F.normalize(emb2)
113 |     ans = torch.sum(unit_emb1 * unit_emb2, dim=1)
114 |     return ans.sum()
115 | 
116 | 
117 | def frobenius_norm(A):
118 |     # if torch.sum(A ** 2).item() < 0:
119 |     #     ipdb.set_trace()
120 |     return torch.norm(A, p='fro')
121 |     # return torch.sqrt(torch.sum(A ** 2))
122 |     # return torch.sqrt(torch.abs(torch.sum(A ** 2)))
123 | 
--------------------------------------------------------------------------------
/models/fop_model.py:
--------------------------------------------------------------------------------
1 | # Fusion and Orthogonal Projection for Improved Face-Voice Association, ICASSP, 2022
2 | # FROM https://github.com/msaadsaeed/FOP
3 | import torch
4 | import torch.nn as nn
5 | 
6 | class FopModel(nn.Module):
7 |     def __init__(self):
8 |         super(FopModel, self).__init__()
9 |         self.encoder_v = EmbedBranch(192, 128)
10 |         self.encoder_f = EmbedBranch(512, 128)
11 |         self.gated_fusion = GatedFusion(embed_dim_in=128, mid_att_dim=128, emb_dim_out=128)
12 |         self.tanh_mode = False
13 | 
14 |     def forward(self, voice_input, face_input):
15 |         tmp_v_emb = self.encoder_v(voice_input)
16 |         tmp_f_emb = self.encoder_f(face_input)
17 |         return self.gated_fusion(tmp_v_emb, tmp_f_emb)
18 | 
19 |     def face_encoder(self, data):
20 |         return torch.tanh(self.encoder_f(data))
21 | 
22 |     def voice_encoder(self, data):
23 |         return torch.tanh(self.encoder_v(data))
24 | 
25 | 
26 | class EmbedBranch(nn.Module):
27 |     def __init__(self, feat_dim, embedding_dim):
28 |         super(EmbedBranch, self).__init__()
29 |         self.fc1 = nn.Sequential(
30 |             nn.Linear(feat_dim, embedding_dim),
31 |             nn.BatchNorm1d(embedding_dim),
32 |             nn.ReLU(inplace=True),
33 |             nn.Dropout(p=0.5)
34 |         )
35 | 
36 |     def forward(self, x):
37 |         x = self.fc1(x)
38 |         x = nn.functional.normalize(x)
39 |         return x
40 | 
41 | 
42 | class
GatedFusion(nn.Module): 43 | def __init__(self, embed_dim_in, mid_att_dim, emb_dim_out): 44 | super(GatedFusion, self).__init__() 45 | self.attention = nn.Sequential( 46 | nn.Linear(embed_dim_in * 2, mid_att_dim), 47 | nn.BatchNorm1d(mid_att_dim), 48 | nn.ReLU(), 49 | nn.Dropout(p=0), 50 | nn.Linear(mid_att_dim, emb_dim_out) 51 | ) 52 | 53 | def forward(self, voice_input, face_input): 54 | concat = torch.cat((face_input, voice_input), dim=1) 55 | attention_out = torch.sigmoid(self.attention(concat)) 56 | face_emb = torch.tanh(face_input) 57 | voice_emb = torch.tanh(voice_input) 58 | fused_emb = face_emb * attention_out + (1.0 - attention_out) * voice_emb 59 | return voice_emb, face_emb, fused_emb 60 | 61 | 62 | 63 | 64 | # class FopModel(nn.Module): 65 | # def __init__(self, encoder): 66 | # super(FopModel, self).__init__() 67 | # self.encoder = encoder 68 | # self.gated_fusion = GatedFusion(embed_dim_in=128, mid_att_dim=128, emb_dim_out=128) 69 | # 70 | # def forward(self, voice_input, face_input): 71 | # v_emb0, f_emb0 = self.encoder(voice_input, face_input) 72 | # return self.gated_fusion(v_emb0, f_emb0) 73 | # 74 | # def face_encoder(self, data): 75 | # return torch.tanh(self.encoder.face_encoder(data)) 76 | # 77 | # def voice_encoder(self, data): 78 | # return torch.tanh(self.encoder.voice_encoder(data)) 79 | 80 | -------------------------------------------------------------------------------- /models/my_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class Encoder(torch.nn.Module): 6 | def __init__(self, voice_size=192, face_size=512, mid_dim=256, embedding_size=128, shared=True): 7 | super(Encoder, self).__init__() 8 | 9 | if shared: 10 | face_rare = self.create_rare(mid_dim, embedding_size) 11 | voice_rare = face_rare 12 | else: 13 | face_rare = self.create_rare(mid_dim, embedding_size) 14 | voice_rare = self.create_rare(mid_dim, embedding_size) 15 | 16 | self.face_encoder = torch.nn.Sequential( 17 | torch.nn.Dropout(), 18 | torch.nn.Linear(face_size, mid_dim), 19 | torch.nn.ReLU(), 20 | face_rare 21 | ) 22 | 23 | self.voice_encoder = torch.nn.Sequential( 24 | torch.nn.Dropout(), 25 | torch.nn.Linear(voice_size, mid_dim), 26 | torch.nn.ReLU(), 27 | voice_rare 28 | ) 29 | 30 | def create_rare(self, mid_dim, embedding_size): 31 | return torch.nn.Sequential( 32 | torch.nn.Linear(mid_dim, mid_dim), 33 | torch.nn.ReLU(), 34 | torch.nn.Linear(mid_dim, embedding_size), 35 | ) 36 | 37 | def forward(self, voice_data, face_data): 38 | v_emb = self.voice_encoder(voice_data) 39 | f_emb = self.face_encoder(face_data) 40 | return v_emb, f_emb 41 | 42 | 43 | class EncoderWithProjector(torch.nn.Module): 44 | def __init__(self, voice_size=192, face_size=512, mid_dim=256, embedding_size=128, shared=True): 45 | super(EncoderWithProjector, self).__init__() 46 | 47 | if shared: 48 | face_rare = self.create_rare(mid_dim, embedding_size) 49 | voice_rare = face_rare 50 | else: 51 | face_rare = self.create_rare(mid_dim, embedding_size) 52 | voice_rare = self.create_rare(mid_dim, embedding_size) 53 | 54 | self.face_encoder = torch.nn.Sequential( 55 | torch.nn.Dropout(), 56 | torch.nn.Linear(face_size, mid_dim), 57 | torch.nn.ReLU(), 58 | face_rare 59 | ) 60 | 61 | self.voice_encoder = torch.nn.Sequential( 62 | torch.nn.Dropout(), 63 | torch.nn.Linear(voice_size, mid_dim), 64 | torch.nn.ReLU(), 65 | voice_rare 66 | ) 67 | 68 | self.face_projector = torch.nn.Linear(embedding_size, embedding_size) 69 | 
self.voice_projector = torch.nn.Linear(embedding_size, embedding_size)
70 | 
71 |     def create_rare(self, mid_dim, embedding_size):
72 |         return torch.nn.Sequential(
73 |             torch.nn.Linear(mid_dim, mid_dim),
74 |             torch.nn.ReLU(),
75 |             torch.nn.Linear(mid_dim, embedding_size),
76 |         )
77 | 
78 |     def forward(self, voice_data, face_data, need_projector=False):
79 |         v_emb = self.voice_encoder(voice_data)
80 |         f_emb = self.face_encoder(face_data)
81 | 
82 |         if need_projector:
83 |             pf = self.face_projector(f_emb)
84 |             pv = self.voice_projector(v_emb)
85 |             return v_emb, f_emb, pv, pf
86 |         return v_emb, f_emb
87 | 
--------------------------------------------------------------------------------
/preprocess/Readme.md:
--------------------------------------------------------------------------------
1 | # Preprocessing Scripts
2 | 
3 | ## 1. Data Preprocessing
4 | 
5 | ### Voice
6 | - Extract .wav audio from .mp4: `preprocess/preprocess/1_mp4_extract_wav.py`
7 | - Perform Voice Activity Detection: `preprocess/preprocess/2_wav_vad.py`
8 | 
9 | ### Face
10 | - Extract frames from .mp4: `preprocess/preprocess/3_mp4_extract_frames.py`
11 | - MTCNN face cropping: `preprocess/preprocess/4_face_crop_mtcnn.py`
12 | - Pose Estimation: `preprocess/preprocess/5_pose_estimation.py`
13 | 
14 | ## 2. Feature Extraction
15 | 
16 | ### Voice
17 | - ECAPA-TDNN: `preprocess/voice_extractor/1_ecapa_tdnn.py`
18 | - Resemblyzer: `preprocess/voice_extractor/2_resemblizer.py`
19 | 
20 | ### Face
21 | - Inception-v1: `preprocess/face_extractor/1_inception_v1.py`
22 | 
23 | - Deepface: `preprocess/face_extractor/2_deepface.py`
24 |   - Supports: "VGG-Face", "Facenet", "OpenFace", "DeepFace", "DeepID", "ArcFace"
25 | 
26 | - Face-X-Zoo: `preprocess/face_extractor/3_facexzoo.py`
27 | 
28 | Example commands for running Face-X-Zoo feature extraction:
29 | 
30 | - MobileFaceNet: `python script_name.py --backbone_type MobileFaceNet --model_pkl 1_MobileFaceNet --batch_size 2048 --size 112 --save_name MobileFaceNet.pkl`
31 | - ResNet: `python script_name.py --backbone_type ResNet --model_pkl 5_Resnet152-irse --batch_size 512 --size 112 --save_name ResNet.pkl`
32 | - AttentionNet: `python script_name.py --backbone_type AttentionNet --model_pkl 2_Attention56 --batch_size 512 --size 112 --save_name AttentionNet.pkl`
33 | 
34 | ##
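For reference, a typical preprocessing pass runs the data-preprocessing scripts in the order listed above (illustrative only: the input/output paths for steps 1-4 are edited inside each script, and the two folders passed to the pose-estimation step are its positional arguments):

```
python preprocess/preprocess/1_mp4_extract_wav.py
python preprocess/preprocess/2_wav_vad.py
python preprocess/preprocess/3_mp4_extract_frames.py
python preprocess/preprocess/4_face_crop_mtcnn.py
python preprocess/preprocess/5_pose_estimation.py /path/to/cropped_faces /path/to/filtered_faces
```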
--------------------------------------------------------------------------------
/preprocess/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/my-yy/vfal-eva/c1ca050d22821bf60fcdca096429edb193df2ae6/preprocess/__init__.py
--------------------------------------------------------------------------------
/preprocess/face_extractor/1_inception_v1.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | from torch.utils.data import DataLoader
4 | from .models import incep
5 | from .loaders.img_loader import Dataset
6 | from utils import pickle_util
7 | import glob
8 | 
9 | the_dict = {}
10 | 
11 | 
12 | def handle_emb_batch(all_data, batch_emb, indexies):
13 |     batch_emb = batch_emb.detach().cpu().numpy().squeeze()
14 |     assert len(batch_emb.shape) == 2
15 |     indexies = indexies.detach().cpu().numpy().tolist()
16 |     for idx, emb in zip(indexies, batch_emb):
17 |         filepath = all_data[idx]
18 |         the_dict[filepath] = emb
19 | 
20 | 
21 | def fun(num_workers, all_img_data, batch_size):
22 |     start_time = time.time()
23 |     the_iter = DataLoader(Dataset(all_img_data), num_workers=num_workers, batch_size=batch_size, shuffle=False,
24 |                           pin_memory=True)
25 |     all_data = the_iter.dataset.all_image_files
26 | 
27 |     total_batch = int(len(all_data) / batch_size) + 1
28 |     counter = 0
29 |     with torch.no_grad():
30 |         for image_tensor, indexies in the_iter:
31 |             counter += 1
32 |             emb_vec = model(image_tensor.cuda())
33 |             handle_emb_batch(all_data, emb_vec, indexies)
34 |             time_cost_h = (time.time() - start_time) / 3600.0
35 |             progress = (counter + 1) / total_batch
36 |             full_time = time_cost_h / progress
37 |             print(counter, progress, "full:", full_time)
38 | 
39 | 
40 | if __name__ == '__main__':
41 |     # 1.load model
42 |     model = incep.InceptionResnetV1(pretrained="vggface2", classify=True)
43 |     model.cuda()
44 |     model.eval()
45 | 
46 |     # 2.get all img list
47 |     all_jpgs = glob.glob("/your_path/*.jpg")
48 | 
49 |     # 3.processing
50 |     fun(8, all_jpgs, batch_size=2048)
51 | 
52 |     # 4.save
53 |     pickle_util.save_pickle("face_emb.pkl", the_dict)
54 | 
--------------------------------------------------------------------------------
/preprocess/face_extractor/2_deepface.py:
--------------------------------------------------------------------------------
1 | from deepface import DeepFace
2 | import ipdb
3 | import torch
4 | import time
5 | from torch.utils.data import DataLoader
6 | from .loaders.loader4deepface import Dataset
7 | from utils import pickle_util
8 | 
9 | 
10 | def handle_one(num_workers, model_name):
11 |     model = DeepFace.build_model(model_name)
12 | 
13 |     start_time = time.time()
14 |     batch_size = 128
15 | 
16 |     the_iter = DataLoader(Dataset(root_path, mode="deepface"),
17 |                           num_workers=num_workers, batch_size=batch_size,
18 |                           shuffle=False,
19 |                           pin_memory=True)
20 |     all_data = the_iter.dataset.all_image_files
21 | 
22 |     total_batch = int(len(all_data) / batch_size) + 1
23 |     people2emb = {}
24 |     counter = 0
25 |     vec_dim = 0
26 |     with torch.no_grad():
27 |         for image_tensor, indexies in the_iter:
28 |             counter += 1
29 |             # get the embeddings:
30 |             batch_emb = model.predict(image_tensor.detach().cpu().numpy())
31 |             assert len(batch_emb.shape) == 2
32 | 
33 |             indexies = indexies.detach().cpu().numpy().tolist()
34 |             for idx, emb in zip(indexies, batch_emb):
35 |                 filepath = all_data[idx]
36 |                 # /home/my/datasets/2_VFMR/4_vgg1_mtcnn/Zack_Snyder/0013800.jpg
37 |                 tmp_arr = filepath.split("/")
38 |                 short_path = tmp_arr[-2] + "/" + tmp_arr[-1]
39 |                 people2emb[short_path] = emb
40 |                 vec_dim = len(emb)
41 | 
42 |             time_cost_h = (time.time() - start_time) / 3600.0
43 |             progress = (counter + 1) / total_batch
44 |             full_time = time_cost_h / progress
45 |             print(counter, progress, "full:", full_time)
46 | 
47 |     save_name = "deepface_%s_dim%d.pkl" % (model_name, vec_dim)
48 | 
49 |     pickle_util.save_pickle(save_name, people2emb)
50 |     print(save_name)
51 | 
52 | 
53 | if __name__ == '__main__':
54 |     import argparse
55 | 
56 |     parser = argparse.ArgumentParser()
57 |     parser.add_argument("--index", default=0, type=int)
58 |     args = parser.parse_args()
59 | 
60 |     root_path = "./jpgs/"
61 |     models = ["VGG-Face", "Facenet", "OpenFace", "DeepFace", "DeepID", "ArcFace"]
62 |     handle_one(8, models[args.index])
63 | 
--------------------------------------------------------------------------------
/preprocess/face_extractor/3_facexzoo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from torch.utils.data import DataLoader
4 | from utils.model_loader import ModelLoader
5 | from utils.extractor.feature_extractor import CommonExtractor
6 | from utils import pickle_util
7 | from data_processor.test_dataset2_resize112 import CommonTestDataset
8 | 
from backbone.backbone_def import BackboneFactory 9 | 10 | 11 | def load_model(model_pkl_folder_root): 12 | # 1.模型 13 | backbone_type = args.backbone_type 14 | backbone_conf_file = "backbone_conf.yaml" 15 | backbone_factory = BackboneFactory(backbone_type, backbone_conf_file) 16 | model_loader = ModelLoader(backbone_factory) 17 | 18 | # pkl所在文件夹: 19 | model_pkl_folder = os.path.join(model_pkl_folder_root, args.model_pkl) 20 | assert os.path.exists(model_pkl_folder) 21 | 22 | pt_name_list = [i for i in os.listdir(model_pkl_folder) if i.endswith(".pt")] 23 | assert len(pt_name_list) == 1 24 | # 加载参数: 25 | model_path = os.path.join(model_pkl_folder, pt_name_list[0]) 26 | model = model_loader.load_model(model_path) 27 | return model 28 | 29 | 30 | def load_dataloader(cropped_face_folder): 31 | image_list_file_path = "1fps_pathlist.txt" 32 | batch_size = args.batch_size 33 | data_loader = DataLoader(CommonTestDataset(cropped_face_folder, 34 | image_list_file_path, args.size, False), 35 | batch_size=batch_size, num_workers=8, shuffle=False) 36 | 37 | return data_loader 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | 43 | parser.add_argument("--backbone_type", default="AttentionNet") 44 | parser.add_argument("--model_pkl", default="2_Attention56") 45 | parser.add_argument("--batch_size", default=512, type=int) 46 | parser.add_argument("--size", default=112, type=int) 47 | parser.add_argument("--save_name", default="AttentionNet.pkl") 48 | 49 | args = parser.parse_args() 50 | 51 | assert len(args.save_name) > 0 52 | 53 | model = load_model("/home/my/projects/124_FaceX-Zoo/models") 54 | 55 | data_loader = load_dataloader("faces/") 56 | 57 | # 抽取特征 58 | feature_extractor = CommonExtractor('cuda:0') 59 | image_name2feature = feature_extractor.extract_online(model, data_loader) 60 | 61 | # 保存: 62 | dim_len = -1 63 | for k, v in image_name2feature.items(): 64 | dim_len = len(v) 65 | break 66 | 67 | save_name = args.save_name.replace(".pkl", "dim%d.pkl" % (dim_len)) 68 | save_path = os.path.join("/ssd2/1_Voxceleb2/4_faceX_Zoo_extracted_feature/1fps", save_name) 69 | pickle_util.save_pickle(save_path, image_name2feature) 70 | print(save_path) 71 | -------------------------------------------------------------------------------- /preprocess/face_extractor/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/my-yy/vfal-eva/c1ca050d22821bf60fcdca096429edb193df2ae6/preprocess/face_extractor/__init__.py -------------------------------------------------------------------------------- /preprocess/face_extractor/loaders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/my-yy/vfal-eva/c1ca050d22821bf60fcdca096429edb193df2ae6/preprocess/face_extractor/loaders/__init__.py -------------------------------------------------------------------------------- /preprocess/face_extractor/loaders/img_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | import torchvision 4 | 5 | 6 | class Dataset(torch.utils.data.Dataset): 7 | 8 | def __init__(self, all_image_files): 9 | self.all_image_files = all_image_files 10 | resize_size = 128 11 | self.transform_fn = torchvision.transforms.Compose([ 12 | torchvision.transforms.Resize(size=(resize_size, resize_size)), 13 | torchvision.transforms.ToTensor() 14 | ]) 15 | 16 | def __len__(self): 17 | return 
len(self.all_image_files) 18 | 19 | def __getitem__(self, index): 20 | file_path = self.all_image_files[index] 21 | img_PIL = Image.open(file_path) 22 | if img_PIL.mode != "RGB": 23 | img_PIL = img_PIL.convert("RGB") 24 | 25 | data = self.transform_fn(img_PIL) 26 | assert data.shape == (3, 128, 128), file_path 27 | return data, index 28 | -------------------------------------------------------------------------------- /preprocess/face_extractor/loaders/loader4deepface.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from deepface import DeepFace 4 | import torchvision 5 | from torchvision import transforms 6 | import cv2 7 | from PIL import Image 8 | 9 | 10 | class Dataset(torch.utils.data.Dataset): 11 | 12 | def __init__(self, root_path, mode=None): 13 | all_image_files = [] 14 | for person_name in os.listdir(root_path): 15 | for img in os.listdir(os.path.join(root_path, person_name)): 16 | if ".jpg" not in img: 17 | continue 18 | all_image_files.append(os.path.join(root_path, person_name, img)) 19 | self.all_image_files = all_image_files 20 | self.mode = mode 21 | 22 | def __len__(self): 23 | return len(self.all_image_files) 24 | 25 | def __getitem__(self, index): 26 | file_path = self.all_image_files[index] 27 | if self.mode == "facenet": 28 | img_PIL = Image.open(file_path) 29 | resize_size = 128 30 | transform_fn = torchvision.transforms.Compose([ 31 | torchvision.transforms.Resize(size=(resize_size, resize_size)), 32 | torchvision.transforms.ToTensor() 33 | ]) 34 | data = transform_fn(img_PIL) 35 | 36 | elif self.mode == "deepface": 37 | resize_size = 224 38 | img = DeepFace.functions.preprocess_face(img=file_path, target_size=(resize_size, resize_size), 39 | enforce_detection=False) 40 | img = img.squeeze(axis=0) 41 | data = torch.FloatTensor(img) 42 | else: 43 | data = load(file_path, 128) 44 | return data, index 45 | 46 | 47 | def load(image_path, image_size): 48 | return trans_frame(cv2.imread(image_path), image_size) 49 | 50 | 51 | def trans_frame(frame_npy, image_size): 52 | frame_pil = Image.fromarray(frame_npy) 53 | 54 | trans = torchvision.transforms.Compose([ 55 | transforms.Resize(size=(image_size, image_size)), 56 | transforms.ToTensor(), 57 | ]) 58 | 59 | image = trans(frame_pil) 60 | return image 61 | -------------------------------------------------------------------------------- /preprocess/face_extractor/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/my-yy/vfal-eva/c1ca050d22821bf60fcdca096429edb193df2ae6/preprocess/face_extractor/models/__init__.py -------------------------------------------------------------------------------- /preprocess/face_extractor/models/incep.py: -------------------------------------------------------------------------------- 1 | from facenet_pytorch.models.inception_resnet_v1 import * 2 | 3 | 4 | class InceptionResnetV1(nn.Module): 5 | """Inception Resnet V1 model with optional loading of pretrained weights. 6 | 7 | Model parameters can be loaded based on pretraining on the VGGFace2 or CASIA-Webface 8 | datasets. Pretrained state_dicts are automatically downloaded on model instantiation if 9 | requested and cached in the torch cache. Subsequent instantiations use the cache rather than 10 | redownloading. 11 | 12 | Keyword Arguments: 13 | pretrained {str} -- Optional pretraining dataset. Either 'vggface2' or 'casia-webface'. 
14 | (default: {None}) 15 | classify {bool} -- Whether the model should output classification probabilities or feature 16 | embeddings. (default: {False}) 17 | num_classes {int} -- Number of output classes. If 'pretrained' is set and num_classes not 18 | equal to that used for the pretrained model, the final linear layer will be randomly 19 | initialized. (default: {None}) 20 | dropout_prob {float} -- Dropout probability. (default: {0.6}) 21 | """ 22 | 23 | def __init__(self, pretrained=None, classify=False, num_classes=None, dropout_prob=0.6, device=None): 24 | super().__init__() 25 | 26 | # Set simple attributes 27 | self.pretrained = pretrained 28 | self.classify = classify 29 | self.num_classes = num_classes 30 | 31 | if pretrained == 'vggface2': 32 | tmp_classes = 8631 33 | elif pretrained == 'casia-webface': 34 | tmp_classes = 10575 35 | elif pretrained is None and self.classify and self.num_classes is None: 36 | raise Exception('If "pretrained" is not specified and "classify" is True, "num_classes" must be specified') 37 | 38 | # Define layers 39 | self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2) 40 | self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1) 41 | self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1) 42 | self.maxpool_3a = nn.MaxPool2d(3, stride=2) 43 | self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1) 44 | self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1) 45 | self.conv2d_4b = BasicConv2d(192, 256, kernel_size=3, stride=2) 46 | self.repeat_1 = nn.Sequential( 47 | Block35(scale=0.17), 48 | Block35(scale=0.17), 49 | Block35(scale=0.17), 50 | Block35(scale=0.17), 51 | Block35(scale=0.17), 52 | ) 53 | self.mixed_6a = Mixed_6a() 54 | self.repeat_2 = nn.Sequential( 55 | Block17(scale=0.10), 56 | Block17(scale=0.10), 57 | Block17(scale=0.10), 58 | Block17(scale=0.10), 59 | Block17(scale=0.10), 60 | Block17(scale=0.10), 61 | Block17(scale=0.10), 62 | Block17(scale=0.10), 63 | Block17(scale=0.10), 64 | Block17(scale=0.10), 65 | ) 66 | self.mixed_7a = Mixed_7a() 67 | self.repeat_3 = nn.Sequential( 68 | Block8(scale=0.20), 69 | Block8(scale=0.20), 70 | Block8(scale=0.20), 71 | Block8(scale=0.20), 72 | Block8(scale=0.20), 73 | ) 74 | self.block8 = Block8(noReLU=True) 75 | self.avgpool_1a = nn.AdaptiveAvgPool2d(1) 76 | self.dropout = nn.Dropout(dropout_prob) 77 | self.last_linear = nn.Linear(1792, 512, bias=False) 78 | self.last_bn = nn.BatchNorm1d(512, eps=0.001, momentum=0.1, affine=True) 79 | 80 | if pretrained is not None: 81 | self.logits = nn.Linear(512, tmp_classes) 82 | load_weights(self, pretrained) 83 | 84 | if self.classify and self.num_classes is not None: 85 | self.logits = nn.Linear(512, self.num_classes) 86 | 87 | self.device = torch.device('cpu') 88 | if device is not None: 89 | self.device = device 90 | self.to(device) 91 | 92 | def forward(self, x): 93 | """Calculate embeddings or logits given a batch of input image tensors. 94 | 95 | Arguments: 96 | x {torch.tensor} -- Batch of image tensors representing faces. 97 | 98 | Returns: 99 | torch.tensor -- Batch of embedding vectors or multinomial logits. 
100 | """ 101 | x = self.conv2d_1a(x) 102 | x = self.conv2d_2a(x) 103 | x = self.conv2d_2b(x) 104 | x = self.maxpool_3a(x) 105 | x = self.conv2d_3b(x) 106 | x = self.conv2d_4a(x) 107 | x = self.conv2d_4b(x) 108 | x = self.repeat_1(x) 109 | x = self.mixed_6a(x) 110 | x = self.repeat_2(x) 111 | x = self.mixed_7a(x) 112 | x = self.repeat_3(x) 113 | x = self.block8(x) 114 | x = self.avgpool_1a(x) 115 | x = self.dropout(x) 116 | x = self.last_linear(x.view(x.shape[0], -1)) 117 | x = self.last_bn(x) 118 | # x1 = self.logits(x) 119 | # x2 = F.normalize(x, p=2, dim=1) 120 | # return x2, x1 121 | return x 122 | -------------------------------------------------------------------------------- /preprocess/preprocess/1_mp4_extract_wav.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import os 3 | from concurrent.futures import wait, ProcessPoolExecutor 4 | import time 5 | import glob 6 | 7 | 8 | def core(index): 9 | input_file = all_mp4[index] 10 | assert os.path.exists(input_file) 11 | output_file = os.path.join(save_root, os.path.basename(input_file)) + ".wav" 12 | 13 | cmd = "ffmpeg -y -i %s -ab 160k -ac 1 -ar 16000 -vn %s" % (input_file, output_file) 14 | ans = subprocess.call(cmd, shell=True) 15 | if ans != 0: 16 | print("error:", input_file, ans) 17 | 18 | if index > 0 and index % 100 == 0: 19 | time_cost = time.time() - start_time 20 | percentage = index / len(all_mp4) 21 | total_time = time_cost / percentage / 3600 22 | print(index, "total time:", total_time, "progress:", percentage) 23 | return output_file 24 | 25 | 26 | if __name__ == "__main__": 27 | all_mp4 = glob.glob("data/test/*.mp4") 28 | save_root = "wav_output/" 29 | pool_size = 8 30 | 31 | print("start!") 32 | start_time = time.time() 33 | pool = ProcessPoolExecutor(pool_size) 34 | tasks = [pool.submit(core, i) for i in range(len(all_mp4))] 35 | wait(tasks) 36 | print('done') 37 | -------------------------------------------------------------------------------- /preprocess/preprocess/2_wav_vad.py: -------------------------------------------------------------------------------- 1 | # https://colab.research.google.com/github/snakers4/silero-vad/blob/master/silero-vad.ipynb#scrollTo=nd2zX-kJ84bb 2 | import glob 3 | import torch 4 | import tqdm 5 | 6 | SAMPLING_RATE = 16000 7 | 8 | torch.set_num_threads(1) 9 | 10 | USE_ONNX = False # change this to True if you want to test onnx model 11 | model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', 12 | model='silero_vad', 13 | force_reload=True, 14 | onnx=USE_ONNX) 15 | 16 | 17 | def fun(input_wav, output_wav): 18 | (get_speech_timestamps, 19 | save_audio, 20 | read_audio, 21 | VADIterator, 22 | collect_chunks) = utils 23 | 24 | wav = read_audio(input_wav, sampling_rate=SAMPLING_RATE) 25 | # get speech timestamps from full audio file 26 | speech_timestamps = get_speech_timestamps(wav, model, sampling_rate=SAMPLING_RATE) 27 | save_audio(output_wav, collect_chunks(speech_timestamps, wav), sampling_rate=SAMPLING_RATE) 28 | 29 | 30 | if __name__ == '__main__': 31 | all_wavs = glob.glob("./data/test/*.wav") 32 | for wav_path in tqdm.tqdm(all_wavs): 33 | fun(wav_path, wav_path + "_vad.wav") 34 | -------------------------------------------------------------------------------- /preprocess/preprocess/3_mp4_extract_frames.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import glob 4 | 5 | 6 | def extract_frames(video_path, output_dir): 7 | # Create output 
directory if it doesn't exist 8 | if not os.path.exists(output_dir): 9 | os.makedirs(output_dir) 10 | 11 | # Run ffmpeg command to extract frames 12 | command = f'ffmpeg -i "{video_path}" -vf fps=1 "{output_dir}/%d.jpg"' 13 | subprocess.call(command, shell=True) 14 | 15 | 16 | def main(input_dir, output_dir): 17 | # Get list of mp4 files in input directory 18 | mp4_files = glob.glob(os.path.join(input_dir, '*.mp4')) 19 | 20 | for mp4_file in mp4_files: 21 | filename = os.path.splitext(os.path.basename(mp4_file))[0] 22 | output_subdir = os.path.join(output_dir, filename) 23 | extract_frames(mp4_file, output_subdir) 24 | 25 | 26 | if __name__ == "__main__": 27 | input_directory = '/path/to/input/directory' 28 | output_directory = '/path/to/output/directory' 29 | main(input_directory, output_directory) 30 | -------------------------------------------------------------------------------- /preprocess/preprocess/4_face_crop_mtcnn.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from facenet_pytorch import MTCNN 4 | import torch 5 | 6 | 7 | def extract_faces_from_jpg_folder(input_folder, output_folder, mtcnn): 8 | # Get a list of all jpg files in the input folder 9 | jpg_files = [f for f in os.listdir(input_folder) if f.endswith('.jpg')] 10 | 11 | # Create output folder if it doesn't exist 12 | if not os.path.exists(output_folder): 13 | os.makedirs(output_folder) 14 | 15 | for jpg_file in jpg_files: 16 | jpg_path = os.path.join(input_folder, jpg_file) 17 | image = Image.open(jpg_path) 18 | 19 | # Detect faces using MTCNN 20 | boxes, _ = mtcnn.detect(image) 21 | 22 | if boxes is None: 23 | print(f"No faces detected in {jpg_file}") 24 | continue 25 | 26 | for i, box in enumerate(boxes): 27 | # Convert box coordinates to integers 28 | box = box.astype(int) 29 | # Crop face from image 30 | cropped_face = image.crop((box[0], box[1], box[2], box[3])) 31 | # Save cropped face as new image 32 | output_path = os.path.join(output_folder, f"{jpg_file[:-4]}_face_{i}.jpg") 33 | cropped_face.save(output_path) 34 | print(f"Face {i + 1} from {jpg_file} saved to {output_path}") 35 | 36 | 37 | def load_mtcnn(): 38 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 39 | mtcnn = MTCNN( 40 | image_size=160, 41 | margin=20, 42 | min_face_size=20, 43 | thresholds=[0.6, 0.7, 0.7], 44 | factor=0.709, 45 | post_process=False, # do not perform tensor normalization 46 | keep_all=True, 47 | device=device 48 | ) 49 | return mtcnn 50 | 51 | 52 | if __name__ == "__main__": 53 | input_folder = '/path/to/input/folder' 54 | output_folder = '/path/to/output/folder' 55 | 56 | # Initialize MTCNN 57 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 58 | mtcnn = load_mtcnn() 59 | 60 | # Extract faces from jpg files in input folder 61 | extract_faces_from_jpg_folder(input_folder, output_folder, mtcnn) 62 | -------------------------------------------------------------------------------- /preprocess/preprocess/5_pose_estimation.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from headpose_estimation import Headpose 3 | import os 4 | import glob 5 | from shutil import copyfile 6 | import argparse 7 | 8 | headpose = Headpose() 9 | 10 | 11 | def fun(img_path): 12 | # img = cv2.imread("tmp2.jpg") 13 | img = cv2.imread(img_path) 14 | detections, image = headpose.run(img) 15 | dic = detections[0] 16 | # {'bbox': array([228, 139, 530, 571]), 17 | # 'yaw': 17.79396, 
18 | # 'pitch': -12.596962, 19 | # 'roll': 2.096115} 20 | # print(dic['yaw'], dic['pitch'], dic['roll']) 21 | return dic 22 | 23 | 24 | def copy_images_with_small_poses(input_folder, output_folder): 25 | # Create output folder if it doesn't exist 26 | if not os.path.exists(output_folder): 27 | os.makedirs(output_folder) 28 | 29 | # Get list of jpg files in input folder 30 | jpg_files = glob.glob(os.path.join(input_folder, '*.jpg')) 31 | 32 | for jpg_file in jpg_files: 33 | # Perform pose estimation 34 | pose = fun(jpg_file) 35 | 36 | # Check if yaw, pitch, and roll are all within the threshold 37 | if abs(pose['yaw']) < 25 and abs(pose['pitch']) < 25 and abs(pose['roll']) < 25: 38 | # Copy the image to the output folder 39 | filename = os.path.basename(jpg_file) 40 | output_path = os.path.join(output_folder, filename) 41 | copyfile(jpg_file, output_path) 42 | print(f"Image {filename} copied to {output_path}") 43 | 44 | 45 | if __name__ == "__main__": 46 | parser = argparse.ArgumentParser(description="Copy images with small poses.") 47 | parser.add_argument("input_folder", type=str, help="Path to input folder containing images.") 48 | parser.add_argument("output_folder", type=str, help="Path to output folder for copied images.") 49 | args = parser.parse_args() 50 | copy_images_with_small_poses(args.input_folder, args.output_folder) 51 | -------------------------------------------------------------------------------- /preprocess/preprocess/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/my-yy/vfal-eva/c1ca050d22821bf60fcdca096429edb193df2ae6/preprocess/preprocess/__init__.py -------------------------------------------------------------------------------- /preprocess/voice_extractor/1_ecapa_tdnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils import model_util, pickle_util 3 | from .loaders import voice_loader 4 | import time 5 | import glob 6 | 7 | 8 | def generate_emb_dict(wav_list, batch_size=16): 9 | loader = voice_loader.get_loader(4, batch_size, wav_list) 10 | the_dict = {} 11 | counter = 0 12 | start_time = time.time() 13 | for data, lens, keys in loader: 14 | try: 15 | core_step(data, lens, model, keys, the_dict) 16 | except Exception as e: 17 | print("error:", e) 18 | continue 19 | 20 | counter += 1 21 | if counter % 10 == 0: 22 | processed = len(the_dict) 23 | progress = processed / len(loader.dataset) 24 | time_cost = time.time() - start_time 25 | total_time = time_cost / progress / 3600.0 26 | print("progress:", progress, "total_time:", total_time) 27 | return the_dict 28 | 29 | 30 | def core_step(wavs, lens, model, keys, the_dict): 31 | with torch.no_grad(): 32 | feats = fun_compute_features(wavs.cuda()) 33 | feats = fun_mean_var_norm(feats, lens) 34 | embedding = model(feats, lens) 35 | embedding_npy = embedding.detach().cpu().numpy().squeeze() 36 | # (batch,192) 37 | for key, emb in zip(keys, embedding_npy): 38 | the_dict[key] = emb 39 | 40 | 41 | def get_ecapa_model(): 42 | from speechbrain.lobes.models.ECAPA_TDNN import ECAPA_TDNN 43 | n_mels = 80 44 | channels = [1024, 1024, 1024, 1024, 3072] 45 | kernel_sizes = [5, 3, 3, 3, 1] 46 | dilations = [1, 2, 3, 4, 1] 47 | attention_channels = 128 48 | lin_neurons = 192 49 | model = ECAPA_TDNN(input_size=n_mels, channels=channels, 50 | kernel_sizes=kernel_sizes, dilations=dilations, 51 | attention_channels=attention_channels, 52 | lin_neurons=lin_neurons 53 | ) 54 | # print(model) 55 | 
return model 56 | 57 | 58 | def get_fun_compute_features(): 59 | from speechbrain.lobes.features import Fbank 60 | 61 | n_mels = 80 62 | left_frames = 0 63 | right_frames = 0 64 | deltas = False 65 | compute_features = Fbank(n_mels=n_mels, left_frames=left_frames, right_frames=right_frames, deltas=deltas) 66 | return compute_features 67 | 68 | 69 | def get_fun_norm(): 70 | from speechbrain.processing.features import InputNormalization 71 | return InputNormalization(norm_type="sentence", std_norm=False) 72 | 73 | 74 | if __name__ == "__main__": 75 | # 1.get model 76 | model = get_ecapa_model().cuda() 77 | pkl_path = "ecapa_acc0.9854.pkl" 78 | model_util.load_model(pkl_path, model) 79 | model.eval() 80 | 81 | fun_compute_features = get_fun_compute_features().cuda() 82 | fun_mean_var_norm = get_fun_norm().cuda() 83 | 84 | # 2.get all wav files 85 | wav_list = glob.glob("/your_path/*.wav") 86 | 87 | the_dict = generate_emb_dict(wav_list) 88 | pickle_util.save_pickle("voice_emb.pkl", the_dict) 89 | -------------------------------------------------------------------------------- /preprocess/voice_extractor/2_resemblizer.py: -------------------------------------------------------------------------------- 1 | import glob 2 | from resemblyzer import VoiceEncoder, preprocess_wav 3 | from pathlib import Path 4 | import time 5 | from utils import pickle_util 6 | import numpy as np 7 | 8 | 9 | def get_emb(the_path): 10 | fpath = Path(the_path) 11 | wav = preprocess_wav(fpath) 12 | embed = encoder.embed_utterance(wav) 13 | return embed.tolist() 14 | 15 | 16 | def core(all_wavs): 17 | start_tiem = time.time() 18 | result_dict = {} 19 | counter = 0 20 | for wav_path in all_wavs: 21 | counter += 1 22 | try: 23 | emb = get_emb(wav_path) 24 | except Exception as e: 25 | print("error:", e, wav_path) 26 | continue 27 | 28 | result_dict[wav_path] = np.array(emb) 29 | 30 | if counter % 100 == 0: 31 | progress = counter / len(all_wavs) 32 | time_cost = (time.time() - start_tiem) / 3600.0 33 | total_time = time_cost / progress 34 | print("total_time:%.1fh;progress:%.3f" % (total_time, progress)) 35 | return result_dict 36 | 37 | 38 | if __name__ == '__main__': 39 | encoder = VoiceEncoder("cuda") 40 | all_wavs = glob.glob("wav/*.wav") 41 | result_dict = core(all_wavs) 42 | pickle_util.save_pickle("resemblyzer.pkl", result_dict) 43 | -------------------------------------------------------------------------------- /preprocess/voice_extractor/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/my-yy/vfal-eva/c1ca050d22821bf60fcdca096429edb193df2ae6/preprocess/voice_extractor/__init__.py -------------------------------------------------------------------------------- /preprocess/voice_extractor/loaders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/my-yy/vfal-eva/c1ca050d22821bf60fcdca096429edb193df2ae6/preprocess/voice_extractor/loaders/__init__.py -------------------------------------------------------------------------------- /preprocess/voice_extractor/loaders/voice_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | import torchaudio 4 | import numpy as np 5 | 6 | 7 | class Dataset(torch.utils.data.Dataset): 8 | 9 | def __init__(self, all_wavs): 10 | self.all_wavs = all_wavs 11 | 12 | def __len__(self): 13 | return len(self.all_wavs) 14 | 15 | def 
__getitem__(self, index): 16 | return { 17 | "key": self.all_wavs[index], 18 | "data": torchaudio.load(self.all_wavs[index])[0] 19 | } 20 | 21 | 22 | def collate_fn(item_list): 23 | data_list = [i['data'] for i in item_list] 24 | the_lengths = np.array([i.shape[-1] for i in data_list]) 25 | max_len = np.max(the_lengths) 26 | len_ratio = the_lengths / max_len 27 | 28 | batch_size = len(item_list) 29 | output = torch.zeros([batch_size, max_len]) 30 | for i in range(batch_size): 31 | cur = data_list[i] 32 | cur_len = data_list[i].shape[-1] 33 | output[i, :cur_len] = cur.squeeze() 34 | 35 | len_ratio = torch.FloatTensor(len_ratio) 36 | keys = [i['key'] for i in item_list] 37 | return output, len_ratio, keys 38 | 39 | 40 | def get_loader(num_workers, batch_size, all_wavs): 41 | loader = DataLoader(Dataset(all_wavs), 42 | num_workers=num_workers, batch_size=batch_size, 43 | shuffle=False, pin_memory=True, collate_fn=collate_fn) 44 | return loader 45 | 46 | 47 | if __name__ == "__main__": 48 | pass 49 | -------------------------------------------------------------------------------- /scripts/1_verification.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from utils import pickle_util, sample_util, seed_util 3 | 4 | 5 | def fun(names, count, name2wav, name2jpg): 6 | seed_util.set_seed(100) 7 | 8 | result = [] 9 | jpg_set = set() 10 | wav_set = set() 11 | 12 | for i in range(count): 13 | label = np.random.randint(0, 2) 14 | if label == 1: 15 | # same person 16 | name1 = sample_util.random_element(names) 17 | name2 = name1 18 | else: 19 | name1, name2 = sample_util.random_elements(names, 2) 20 | 21 | wav = sample_util.random_element(name2wav[name1]) 22 | face = sample_util.random_element(name2jpg[name2]) 23 | obj = [wav, face, label] 24 | wav_set.add(wav) 25 | jpg_set.add(face) 26 | result.append(obj) 27 | 28 | obj = { 29 | "wav_set": wav_set, 30 | "jpg_set": jpg_set, 31 | "list": result, 32 | "script": open(__file__).read(), 33 | } 34 | return obj 35 | 36 | 37 | def is_male(gender): 38 | if gender in ['f', 0]: 39 | return False 40 | 41 | assert gender in ['m', 1] 42 | return True 43 | 44 | 45 | def fun_g(names, count, name2wav, name2jpg, name2gender): 46 | seed_util.set_seed(100) 47 | result = [] 48 | jpg_set = set() 49 | wav_set = set() 50 | 51 | male_names = [] 52 | female_names = [] 53 | 54 | for name in names: 55 | gender = name2gender[name] 56 | if is_male(gender): 57 | male_names.append(name) 58 | else: 59 | female_names.append(name) 60 | 61 | for i in range(count): 62 | label = np.random.randint(0, 2) 63 | if i % 2 == 0: 64 | the_names = female_names 65 | else: 66 | the_names = male_names 67 | 68 | if label == 1: 69 | # same person 70 | name1 = sample_util.random_element(the_names) 71 | name2 = name1 72 | else: 73 | name1, name2 = sample_util.random_elements(the_names, 2) 74 | 75 | wav = sample_util.random_element(name2wav[name1]) 76 | face = sample_util.random_element(name2jpg[name2]) 77 | obj = [wav, face, label] 78 | wav_set.add(wav) 79 | jpg_set.add(face) 80 | result.append(obj) 81 | 82 | obj = { 83 | "wav_set": wav_set, 84 | "jpg_set": jpg_set, 85 | "list": result, 86 | "script": open(__file__).read(), 87 | } 88 | return obj 89 | 90 | 91 | if __name__ == '__main__': 92 | name_list_dict = pickle_util.read_pickle("./dataset/info/train_valid_test_names.pkl") 93 | name2jpgs_wavs = pickle_util.read_pickle("./dataset/info/name2jpgs_wavs.pkl") 94 | name2gender = pickle_util.read_pickle("./dataset/info/name2gender.pkl") 95 | 96 | count = 10000 97
| pickle_util.save_pickle("./dataset/evals/valid_verification.pkl", fun(name_list_dict["valid"], count, name2jpgs_wavs["name2wavs"], name2jpgs_wavs["name2jpgs"])) 98 | pickle_util.save_pickle("./dataset/evals/test_verification.pkl", fun(name_list_dict["test"], count, name2jpgs_wavs["name2wavs"], name2jpgs_wavs["name2jpgs"])) 99 | pickle_util.save_pickle("./dataset/evals/test_verification_g.pkl", fun_g(name_list_dict["test"], count, name2jpgs_wavs["name2wavs"], name2jpgs_wavs["name2jpgs"], name2gender)) 100 | -------------------------------------------------------------------------------- /scripts/2_matching.py: -------------------------------------------------------------------------------- 1 | import ipdb 2 | import numpy as np 3 | import ipdb 4 | import os 5 | from utils import pickle_util, sample_util, seed_util 6 | 7 | 8 | def fun(names, iter_num, name2wav, name2pic): 9 | seed_util.set_seed(100) 10 | data = [] 11 | wav_set = set() 12 | jpg_set = set() 13 | for i in range(iter_num): 14 | # 1. pick two different people 15 | name1, name2 = np.random.choice(names, 2, replace=False) 16 | 17 | voice1 = sample_util.random_element(name2wav[name1]) 18 | voice2 = sample_util.random_element(name2wav[name2]) 19 | 20 | face1 = sample_util.random_element(name2pic[name1]) 21 | face2 = sample_util.random_element(name2pic[name2]) 22 | 23 | # record the sampled items 24 | assert name1 != name2 25 | assert voice1 != voice2 26 | assert face1 != face2 27 | obj = (name1, voice1, face1, name2, voice2, face2) 28 | data.append(obj) 29 | 30 | wav_set.add(voice1) 31 | wav_set.add(voice2) 32 | jpg_set.add(face1) 33 | jpg_set.add(face2) 34 | 35 | obj = { 36 | "wav_set": wav_set, 37 | "jpg_set": jpg_set, 38 | "match_list": data 39 | } 40 | return obj 41 | 42 | 43 | def fun_g(names, iter_num, name2wav, name2pic): 44 | seed_util.set_seed(100) 45 | 46 | male_names = [name for name in names if is_male(name2gender[name])] 47 | female_names = [name for name in names if not is_male(name2gender[name])] 48 | 49 | data = [] 50 | wav_set = set() 51 | jpg_set = set() 52 | for i in range(iter_num): 53 | # 1. pick two different people of the same gender 54 | if i % 2 == 0: 55 | name1, name2 = np.random.choice(female_names, 2, replace=False) 56 | else: 57 | name1, name2 = np.random.choice(male_names, 2, replace=False) 58 | 59 | voice1 = sample_util.random_element(name2wav[name1]) 60 | voice2 = sample_util.random_element(name2wav[name2]) 61 | 62 | face1 = sample_util.random_element(name2pic[name1]) 63 | face2 = sample_util.random_element(name2pic[name2]) 64 | 65 | # record the sampled items 66 | assert name1 != name2 67 | assert voice1 != voice2 68 | assert face1 != face2 69 | obj = (name1, voice1, face1, name2, voice2, face2) 70 | data.append(obj) 71 | 72 | wav_set.add(voice1) 73 | wav_set.add(voice2) 74 | jpg_set.add(face1) 75 | jpg_set.add(face2) 76 | # ipdb.set_trace() 77 | 78 | obj = { 79 | "wav_set": wav_set, 80 | "jpg_set": jpg_set, 81 | "match_list": data 82 | } 83 | return obj 84 | 85 | 86 | def is_male(gender): 87 | if gender in ['f', 0]: 88 | return False 89 | 90 | assert gender in ['m', 1] 91 | return True 92 | 93 | 94 | def fun_1n(names, iter_num, name2wav, name2pic, N): 95 | data = [] 96 | wav_set = set() 97 | jpg_set = set() 98 | for i in range(iter_num): 99 | # 1. pick N people 100 | name_list = sample_util.random_elements(names, N) 101 | # 2. sample one voice and one face per person (aligned by position) 102 | voices = [sample_util.random_element(name2wav[name]) for name in name_list] 103 | faces = [sample_util.random_element(name2pic[name]) for name in name_list] 104 | data.append([voices, faces]) 105 | for v in voices: 106 | wav_set.add(v) 107 | 108 | for f in
faces: 109 | jpg_set.add(f) 110 | 111 | obj = { 112 | "wav_set": wav_set, 113 | "jpg_set": jpg_set, 114 | "match_list": data 115 | } 116 | return obj 117 | 118 | 119 | if __name__ == "__main__": 120 | name_list_dict = pickle_util.read_pickle("./dataset/info/train_valid_test_names.pkl") 121 | name2jpgs_wavs = pickle_util.read_pickle("./dataset/info/name2jpgs_wavs.pkl") 122 | name2gender = pickle_util.read_pickle("./dataset/info/name2gender.pkl") 123 | name2wav = name2jpgs_wavs["name2wavs"] 124 | name2jpg = name2jpgs_wavs["name2jpgs"] 125 | 126 | # basic tests 127 | count = 10000 128 | pickle_util.save_pickle("./dataset/evals/test_matching.pkl", fun(name_list_dict["test"], count, name2wav, name2jpg)) 129 | pickle_util.save_pickle("./dataset/evals/test_matching_g.pkl", fun_g(name_list_dict["test"], count, name2wav, name2jpg)) 130 | pickle_util.save_pickle("./dataset/evals/test_matching_10.pkl", fun_1n(name_list_dict["test"], count, name2wav, name2jpg, 10)) 131 | -------------------------------------------------------------------------------- /scripts/3_retrieval.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from utils import pickle_util 3 | from utils import seed_util, sample_util 4 | import random 5 | 6 | 7 | def one_group(names, name2wav, name2jpg, jpg_set, wav_set, max_person): 8 | if max_person is not None: 9 | names = sample_util.random_elements(names, max_person) 10 | print("resampled number of people:", max_person) 11 | 12 | result = [] 13 | 14 | # 10 voices and 10 faces per person 15 | wav_size = 10 16 | jpg_size = 10 17 | 18 | for name in names: 19 | wavs = name2wav[name] 20 | random.shuffle(wavs) 21 | wavs = wavs[0:wav_size] 22 | # pick 10 voices 23 | 24 | jpgs = name2jpg[name] 25 | random.shuffle(jpgs) 26 | jpgs = jpgs[0:jpg_size] 27 | # pick 10 faces 28 | 29 | for wav, jpg in zip(wavs, jpgs): 30 | tup = (wav, jpg, name) 31 | result.append(tup) 32 | jpg_set.add(jpg) 33 | wav_set.add(wav) 34 | return result 35 | 36 | 37 | def fun(names, name2wav, name2jpg, group=3, max_person=None): 38 | seed_util.set_seed(100) 39 | 40 | jpg_set = set() 41 | wav_set = set() 42 | 43 | # build `group` retrieval groups (3 by default) 44 | result = [] 45 | for i in range(group): 46 | arr = one_group(names, name2wav, name2jpg, jpg_set, wav_set, max_person) 47 | result.append(arr) 48 | 49 | obj = { 50 | "wav_set": wav_set, 51 | "jpg_set": jpg_set, 52 | "retrieval_lists": result 53 | } 54 | 55 | return obj 56 | 57 | 58 | if __name__ == "__main__": 59 | name_list_dict = pickle_util.read_pickle("./dataset/info/train_valid_test_names.pkl") 60 | name2jpgs_wavs = pickle_util.read_pickle("./dataset/info/name2jpgs_wavs.pkl") 61 | name2gender = pickle_util.read_pickle("./dataset/info/name2gender.pkl") 62 | name2wav = name2jpgs_wavs["name2wavs"] 63 | name2jpg = name2jpgs_wavs["name2jpgs"] 64 | pickle_util.save_pickle("./dataset/evals/test_retrieval.pkl", fun(name_list_dict["test"], name2wav, name2jpg)) 65 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/my-yy/vfal-eva/c1ca050d22821bf60fcdca096429edb193df2ae6/scripts/__init__.py -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/my-yy/vfal-eva/c1ca050d22821bf60fcdca096429edb193df2ae6/utils/__init__.py
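The three scripts above only need to be run once; the evaluation pipeline in `utils/eva_emb_full.py` later reads the generated pickles from `./dataset/evals/`. As a quick sanity check, the files can be inspected directly — a minimal sketch (not part of the repository), assuming the scripts above were already run from the project root:

```
# Hypothetical sanity-check snippet; the structures follow the scripts above.
from utils import pickle_util

# verification: each entry is [wav_path, jpg_path, same_person_label]
ver = pickle_util.read_pickle("./dataset/evals/test_verification.pkl")
print(len(ver["list"]), "pairs,", len(ver["wav_set"]), "wavs,", len(ver["jpg_set"]), "jpgs")

# 1:2 matching: each entry is (name1, voice1, face1, name2, voice2, face2)
mat = pickle_util.read_pickle("./dataset/evals/test_matching.pkl")
print(len(mat["match_list"]), "matching tuples")

# retrieval: `group` lists of (wav, jpg, name) triples
ret = pickle_util.read_pickle("./dataset/evals/test_retrieval.pkl")
print([len(g) for g in ret["retrieval_lists"]], "triples per group")
```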
-------------------------------------------------------------------------------- /utils/angles_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import math 4 | 5 | 6 | def calculate_pitch_yaw_roll(landmarks_2D, cam_w=256, cam_h=256): 7 | """ Return the the pitch yaw and roll angles associated with the input image. 8 | @param radians When True it returns the angle in radians, otherwise in degrees. 9 | """ 10 | c_x = cam_w/2 11 | c_y = cam_h/2 12 | f_x = c_x / np.tan(60/2 * np.pi / 180) 13 | f_y = f_x 14 | 15 | #Estimated camera matrix values. 16 | camera_matrix = np.float32([[f_x, 0.0, c_x], 17 | [0.0, f_y, c_y], 18 | [0.0, 0.0, 1.0]]) 19 | 20 | camera_distortion = np.float32([0.0, 0.0, 0.0, 0.0, 0.0]) 21 | 22 | #The dlib shape predictor returns 68 points, we are interested only in a few of those 23 | # TRACKED_POINTS = [17, 21, 22, 26, 36, 39, 42, 45, 31, 35, 48, 54, 57, 8] 24 | #wflw(98 landmark) trached points 25 | # TRACKED_POINTS = [33, 38, 50, 46, 60, 64, 68, 72, 55, 59, 76, 82, 85, 16] 26 | #X-Y-Z with X pointing forward and Y on the left and Z up. 27 | #The X-Y-Z coordinates used are like the standard 28 | # coordinates of ROS (robotic operative system) 29 | #OpenCV uses the reference usually used in computer vision: 30 | #X points to the right, Y down, Z to the front 31 | LEFT_EYEBROW_LEFT = [6.825897, 6.760612, 4.402142] 32 | LEFT_EYEBROW_RIGHT = [1.330353, 7.122144, 6.903745] 33 | RIGHT_EYEBROW_LEFT = [-1.330353, 7.122144, 6.903745] 34 | RIGHT_EYEBROW_RIGHT= [-6.825897, 6.760612, 4.402142] 35 | LEFT_EYE_LEFT = [5.311432, 5.485328, 3.987654] 36 | LEFT_EYE_RIGHT = [1.789930, 5.393625, 4.413414] 37 | RIGHT_EYE_LEFT = [-1.789930, 5.393625, 4.413414] 38 | RIGHT_EYE_RIGHT= [-5.311432, 5.485328, 3.987654] 39 | NOSE_LEFT = [2.005628, 1.409845, 6.165652] 40 | NOSE_RIGHT = [-2.005628, 1.409845, 6.165652] 41 | MOUTH_LEFT = [2.774015, -2.080775, 5.048531] 42 | MOUTH_RIGHT=[-2.774015, -2.080775, 5.048531] 43 | LOWER_LIP= [0.000000, -3.116408, 6.097667] 44 | CHIN = [0.000000, -7.415691, 4.070434] 45 | 46 | landmarks_3D = np.float32([LEFT_EYEBROW_LEFT, 47 | LEFT_EYEBROW_RIGHT, 48 | RIGHT_EYEBROW_LEFT, 49 | RIGHT_EYEBROW_RIGHT, 50 | LEFT_EYE_LEFT, 51 | LEFT_EYE_RIGHT, 52 | RIGHT_EYE_LEFT, 53 | RIGHT_EYE_RIGHT, 54 | NOSE_LEFT, 55 | NOSE_RIGHT, 56 | MOUTH_LEFT, 57 | MOUTH_RIGHT, 58 | LOWER_LIP, 59 | CHIN]) 60 | 61 | #Return the 2D position of our landmarks 62 | assert landmarks_2D is not None, 'landmarks_2D is None' 63 | landmarks_2D = np.asarray(landmarks_2D, dtype=np.float32).reshape(-1, 2) 64 | #Applying the PnP solver to find the 3D pose 65 | #of the head from the 2D position of the 66 | #landmarks. 67 | #retval - bool 68 | #rvec - Output rotation vector that, together with tvec, brings 69 | #points from the world coordinate system to the camera coordinate system. 70 | #tvec - Output translation vector. 
It is the position of the world origin (SELLION) in camera co-ords 71 | retval, rvec, tvec = cv2.solvePnP(landmarks_3D, 72 | landmarks_2D, 73 | camera_matrix, 74 | camera_distortion) 75 | 76 | #Get as input the rotational vector 77 | #Return a rotational matrix 78 | rmat, _ = cv2.Rodrigues(rvec) 79 | pose_mat = cv2.hconcat((rmat, tvec)) 80 | 81 | #euler_angles contain (pitch, yaw, roll) 82 | # euler_angles = cv2.DecomposeProjectionMatrix(projMatrix=rmat, cameraMatrix=self.camera_matrix, rotMatrix, transVect, rotMatrX=None, rotMatrY=None, rotMatrZ=None) 83 | _, _, _, _, _, _, euler_angles = cv2.decomposeProjectionMatrix(pose_mat) 84 | pitch, yaw, roll = map(lambda temp: temp[0], euler_angles) 85 | return pitch, yaw, roll 86 | 87 | # Calculates rotation matrix to euler angles 88 | # The result is the same as MATLAB except the order 89 | # of the euler angles ( x and z are swapped ). 90 | 91 | 92 | def rotationMatrixToEulerAngles(R) : 93 | #assert(isRotationMatrix(R)) 94 | #To prevent the Gimbal Lock it is possible to use 95 | #a threshold of 1e-6 for discrimination 96 | sy = math.sqrt(R[0,0] * R[0,0] + R[1,0] * R[1,0]) 97 | 98 | singular = sy < 1e-6 99 | if not singular : 100 | x = math.atan2(R[2,1] , R[2,2]) 101 | y = math.atan2(-R[2,0], sy) 102 | z = math.atan2(R[1,0], R[0,0]) 103 | else : 104 | x = math.atan2(-R[1,2], R[1,1]) 105 | y = math.atan2(-R[2,0], sy) 106 | z = 0 107 | return np.array([x, y, z]) -------------------------------------------------------------------------------- /utils/barlow_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BarlowTwinsLoss(torch.nn.Module): 5 | 6 | def __init__(self, lambda_param=5e-3): 7 | super(BarlowTwinsLoss, self).__init__() 8 | self.lambda_param = lambda_param 9 | 10 | def forward(self, z_a: torch.Tensor, z_b: torch.Tensor): 11 | # normalize repr. 
along the batch dimension 12 | z_a_norm = (z_a - z_a.mean(0)) / z_a.std(0) # NxD 13 | z_b_norm = (z_b - z_b.mean(0)) / z_b.std(0) # NxD 14 | 15 | N = z_a.size(0) 16 | D = z_a.size(1) 17 | 18 | # cross-correlation matrix 19 | c = torch.mm(z_a_norm.T, z_b_norm) / N # DxD 20 | # loss 21 | c_diff = (c - torch.eye(D, device="cuda")).pow(2) # DxD 22 | # multiply off-diagonal elems of c_diff by lambda 23 | c_diff[~torch.eye(D, dtype=bool)] *= self.lambda_param 24 | 25 | loss = c_diff.sum() 26 | 27 | return loss 28 | -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | import ipdb 2 | 3 | from utils import pickle_util, vec_util, path_util 4 | from utils.eva_emb_full import EmbEva 5 | 6 | 7 | def load_face_emb_dict(): 8 | face_emb_dict = pickle_util.read_pickle(path_util.look_up("./dataset/face_input.pkl")) 9 | vec_util.dict2unit_dict_inplace(face_emb_dict) 10 | # 2.trans key 11 | face_emb_dict2 = {} 12 | for key, v in face_emb_dict.items(): 13 | # 'A.J._Buckley/1.6/1zcIwhmdeo4/0000375.jpg' ==> 'A.J._Buckley/1zcIwhmdeo4/0000375.jpg' 14 | face_emb_dict2[key.replace("/1.6/", "/")] = v 15 | return face_emb_dict2 16 | 17 | 18 | def load_voice_emb_dict(): 19 | voice_emb_dict = pickle_util.read_pickle(path_util.look_up("./dataset/voice_input.pkl")) 20 | vec_util.dict2unit_dict_inplace(voice_emb_dict) 21 | name2voice_id = pickle_util.read_pickle("./dataset/info/name2voice_id.pkl") 22 | voiceid2name = {} 23 | for k, v in name2voice_id.items(): 24 | voiceid2name[v] = k 25 | 26 | voice_emb_dict2 = {} 27 | for k, v in voice_emb_dict.items(): 28 | # id11194/bdFSAep9GQk/00005.wav ==> Ty_Pennington/bdFSAep9GQk/00005.wav 29 | the_id = k.split("/")[0] 30 | name = voiceid2name[the_id] 31 | k2 = k.replace(the_id, name) 32 | voice_emb_dict2[k2] = v 33 | return voice_emb_dict2 34 | 35 | 36 | face_emb_dict = load_face_emb_dict() 37 | voice_emb_dict = load_voice_emb_dict() 38 | 39 | # project_name = "VFALBenchmark" 40 | # total_epoch = 100 41 | # batch_size = 256 42 | # early_stop = 5 43 | # batch_per_epoch = 500 44 | # eval_step = 150 45 | # save_folder = "" 46 | 47 | 48 | # 2.eval 49 | # emb_eva = EmbEva(voice_emb_dict, face_emb_dict) 50 | -------------------------------------------------------------------------------- /utils/deep_coral_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def CORALV2(source, target): 5 | d = source.data.shape[1] 6 | # source covariance 7 | xm = torch.mean(source, 1, keepdim=True) - source 8 | xc = torch.matmul(torch.transpose(xm, 0, 1), xm) 9 | # target covariance 10 | xmt = torch.mean(target, 1, keepdim=True) - target 11 | xct = torch.matmul(torch.transpose(xmt, 0, 1), xmt) 12 | 13 | # frobenius norm between source and target 14 | tmp = torch.sum(torch.mul((xc - xct), (xc - xct))) 15 | loss = tmp / (4 * d * d) 16 | 17 | return loss 18 | -------------------------------------------------------------------------------- /utils/deepcluster_util.py: -------------------------------------------------------------------------------- 1 | import mkl 2 | import collections 3 | 4 | mkl.get_max_threads() 5 | import faiss 6 | from utils import wb_util, distance_util 7 | import numpy as np 8 | import torch 9 | 10 | 11 | def do_k_means(matrix, ncentroids): 12 | niter = 20 13 | verbose = True 14 | d = matrix.shape[1] 15 | 16 | kmeans = faiss.Kmeans(d, 17 | ncentroids, 18 | niter=niter, 19 | verbose=verbose, 
20 | spherical=False, 21 | min_points_per_centroid=3, 22 | max_points_per_centroid=100000, 23 | gpu=False, 24 | ) 25 | 26 | kmeans.train(matrix) 27 | 28 | D, I = kmeans.index.search(matrix, 1) 29 | 30 | cluster_label = I.squeeze() 31 | similarity_array = [] 32 | for i in range(len(matrix)): 33 | sample_vec = matrix[i] 34 | sample_label = I[i][0] 35 | center_vec = kmeans.centroids[sample_label] 36 | similarity = distance_util.cosine_similarity(sample_vec, center_vec) 37 | similarity_array.append(similarity) 38 | similarity_array = np.array(similarity_array) 39 | 40 | sorted_similarity_array = similarity_array.copy() 41 | sorted_similarity_array.sort() 42 | 43 | return cluster_label, similarity_array 44 | 45 | 46 | def get_center_matrix(v_emb, f_emb, cluster_label, ncentroids): 47 | tmp_dict = collections.defaultdict(list) 48 | 49 | for v, f, label in zip(v_emb, f_emb, cluster_label): 50 | tmp_dict[label].append(v) 51 | tmp_dict[label].append(f) 52 | 53 | tmp_arr = [] 54 | for i in range(ncentroids): 55 | vec = np.mean(tmp_dict[i], axis=0) 56 | tmp_arr.append(vec) 57 | 58 | center_matrix = np.array(tmp_arr) 59 | return center_matrix 60 | 61 | 62 | def extract_embeddings(ordered_iter, model): 63 | model.eval() 64 | all_emb = [] 65 | all_emb_v = [] 66 | all_emb_f = [] 67 | all_keys = [] 68 | for data in ordered_iter: 69 | with torch.no_grad(): 70 | data = [i.cuda() for i in data] 71 | voice_data, face_data, label = data 72 | v_emb, f_emb = model(voice_data, face_data) 73 | # [v-f] 74 | the_emb = torch.cat([v_emb, f_emb], dim=1).detach().cpu().numpy() 75 | label_npy = label.squeeze().detach().cpu().numpy().tolist() 76 | 77 | for emb, label_int in zip(the_emb, label_npy): 78 | all_emb.append(emb) 79 | all_emb_v.append(emb[0:128]) 80 | all_emb_f.append(emb[128:]) 81 | all_keys.append(ordered_iter.dataset.train_movie_list[label_int]) 82 | model.train() 83 | return all_keys, all_emb, all_emb_v, all_emb_f 84 | 85 | 86 | def do_cluster(ordered_iter, ncentroids, model, input_emb_type="all"): 87 | all_keys, all_emb, all_emb_v, all_emb_f = extract_embeddings(ordered_iter, model) 88 | 89 | if input_emb_type == "v": 90 | input_emb = np.array(all_emb_v) 91 | elif input_emb_type == "f": 92 | input_emb = np.array(all_emb_f) 93 | elif input_emb_type == "all": 94 | input_emb = np.array(all_emb) 95 | else: 96 | raise Exception("wrong type") 97 | 98 | cluster_label, similarity_array = do_k_means(input_emb, ncentroids) 99 | 100 | movie2label = {} 101 | for label, key, sim in zip(cluster_label, all_keys, similarity_array): 102 | movie2label[key] = label 103 | 104 | center_vector = get_center_matrix(all_emb_v, all_emb_f, cluster_label, ncentroids) 105 | return movie2label, center_vector 106 | 107 | 108 | def do_cluster_v2(all_keys, all_emb, all_emb_v, all_emb_f, ncentroids, input_emb_type="all"): 109 | if input_emb_type == "v": 110 | input_emb = np.array(all_emb_v) 111 | elif input_emb_type == "f": 112 | input_emb = np.array(all_emb_f) 113 | elif input_emb_type == "all": 114 | input_emb = np.array(all_emb) 115 | else: 116 | raise Exception("wrong type") 117 | 118 | cluster_label, similarity_array = do_k_means(input_emb, ncentroids) 119 | 120 | movie2label = {} 121 | for label, key, sim in zip(cluster_label, all_keys, similarity_array): 122 | movie2label[key] = label 123 | 124 | center_vector = get_center_matrix(all_emb_v, all_emb_f, cluster_label, ncentroids) 125 | return movie2label, center_vector 126 | -------------------------------------------------------------------------------- /utils/distance_util.py: 
-------------------------------------------------------------------------------- 1 | import numpy 2 | import numpy as np 3 | import scipy.spatial 4 | 5 | 6 | def calc_inter_distance(embedding): 7 | matrix_dot = numpy.dot(embedding, numpy.transpose(embedding)) 8 | # (batch,batch) 9 | 10 | l2_norm_squired = numpy.diagonal(matrix_dot) 11 | 12 | distance_matrix_squired = numpy.expand_dims(l2_norm_squired, axis=0) + numpy.expand_dims(l2_norm_squired, 13 | axis=1) - 2.0 * matrix_dot 14 | 15 | distance_matrix = numpy.maximum(distance_matrix_squired, 0.0) 16 | distance_matrix = numpy.sqrt(distance_matrix) 17 | return distance_matrix 18 | 19 | 20 | def calc_matrix_distance(matrix_a, matrix_b): 21 | # matrix_a: (batch_a,dim) 22 | # matrix_b: (batch_b,dim) 23 | 24 | matrix_dot = numpy.dot(matrix_a, numpy.transpose(matrix_b)) 25 | # (batch_a,batch_b) 26 | 27 | a_square = numpy.sum(matrix_a * matrix_a, axis=1) 28 | # (batch_a) 29 | 30 | b_square = numpy.sum(matrix_b * matrix_b, axis=1) 31 | # (batch_b) 32 | 33 | a_square_2d = numpy.expand_dims(a_square, axis=1) 34 | # (batch_a,1) 35 | 36 | b_square_2d = numpy.expand_dims(b_square, axis=0) 37 | # (1,batch_b) 38 | 39 | distance_matrix_squired = a_square_2d - 2.0 * matrix_dot + b_square_2d 40 | 41 | distance_matrix = numpy.maximum(distance_matrix_squired, 0.0) 42 | distance_matrix = numpy.sqrt(distance_matrix) 43 | return distance_matrix 44 | 45 | 46 | def parallel_distance(a, b): 47 | a = numpy.array(a) 48 | b = numpy.array(b) 49 | 50 | assert len(a) == len(b) 51 | 52 | c = a - b 53 | return numpy.sqrt(numpy.sum(c * c, axis=1)) 54 | 55 | 56 | def parallel_distance_cosine_based_distance(a, b): 57 | assert len(a.shape) == 2 58 | assert a.shape == b.shape 59 | ab = np.sum(a * b, axis=1) 60 | # (batch_size,) 61 | 62 | a_norm = np.sqrt(np.sum(a * a, axis=1)) 63 | b_norm = np.sqrt(np.sum(b * b, axis=1)) 64 | cosine = ab / (a_norm * b_norm) 65 | 66 | dist = 1 - cosine 67 | # 0~2 68 | return dist 69 | 70 | 71 | def distance_of_2point(a, b): 72 | return parallel_distance([a], [b])[0] 73 | 74 | 75 | def cosine_similarity(v1, v2): 76 | return (1 - scipy.spatial.distance.cosine(v1, v2) + 1) / 2.0 77 | -------------------------------------------------------------------------------- /utils/dlib_util.py: -------------------------------------------------------------------------------- 1 | import dlib 2 | import skimage.draw 3 | import numpy as np 4 | import cv2 5 | import math 6 | # ======== Face alignment utilities 7 | from utils import faceBlendCommon 8 | 9 | face_detector = dlib.get_frontal_face_detector() 10 | 11 | root_path = "/ssd2/7_论文数据/23_reconstructing_faces_from_voices/dlib-models/" 12 | landmark68_predictor = dlib.shape_predictor(root_path + 'shape_predictor_68_face_landmarks.dat') 13 | landmark5_detector = dlib.shape_predictor(root_path + "shape_predictor_5_face_landmarks.dat") 14 | 15 | 16 | def crop_frames(origin_frames): 17 | croped_frames = [] 18 | for frame in origin_frames: 19 | try: 20 | landmark = get_landmark(frame) 21 | output = crop_image(frame, landmark) 22 | croped_frames.append(output) 23 | except Exception as e: 24 | print("error:", e) 25 | assert len(croped_frames) > 0, "no frames extracted" 26 | return croped_frames 27 | 28 | 29 | def get_landmark(img): 30 | ans = face_detector(img) 31 | assert len(ans) > 0, "no face detected" 32 | rect = ans[0] 33 | sp = landmark68_predictor(img, rect) 34 | landmarks = np.array([[p.x, p.y] for p in sp.parts()]) 35 | return landmarks 36 | 37 | 38 | def crop_image(img, landmarks): 39 | # this does not crop the image; it copies the face region outlined by the landmarks onto a black canvas 40 | outline =
landmarks[[*range(17), *range(26, 16, -1)]] 41 | # [[x,y],[x,y].... ] 42 | 43 | Y, X = skimage.draw.polygon(outline[:, 1], outline[:, 0]) 44 | cropped_img = np.zeros(img.shape, dtype=np.uint8) 45 | cropped_img[Y, X] = img[Y, X] 46 | return cropped_img 47 | 48 | 49 | # Use 5-point alignment to crop the face from the image, with rotation correction (no need to implement this by hand) 50 | def face_crop_and_alignment(image, out_height=224, out_width=224): 51 | points = faceBlendCommon.getLandmarks(face_detector, landmark5_detector, image) 52 | landmarks = np.array(points) 53 | assert len(landmarks) > 0 54 | 55 | # scale from [0, 255] to [0, 1] 56 | image = np.float32(image) / 255.0 57 | normalized_image, normalized_landmarks = faceBlendCommon.normalizeImagesAndLandmarks((out_height, out_width), image, landmarks) 58 | # scale back to [0, 255]: 59 | normalized_image = np.uint8(normalized_image * 255) 60 | return normalized_image 61 | 62 | 63 | # ======== Landmark stabilization: 64 | 65 | def crop_frames_with_stabled_landmark(origin_frames): 66 | landmarks = get_stabled_landmarks(origin_frames) 67 | ans = [] 68 | for frame, landmark in zip(origin_frames, landmarks): 69 | if landmark is not None: 70 | output = crop_image(frame, landmark) 71 | ans.append(output) 72 | assert len(ans) > 0, "no frames extracted" 73 | return ans 74 | 75 | 76 | def get_stabled_landmarks(frames): 77 | stable_points_list = [] 78 | eyeDistance = None 79 | for step in range(0, len(frames)): 80 | # 1. detect the face bounding box 81 | imDlib = cv2.cvtColor(frames[step], cv2.COLOR_BGR2RGB) 82 | faces = face_detector(imDlib, 0) 83 | 84 | if len(faces) == 0: 85 | print("no face detected") 86 | stable_points_list.append(None) 87 | continue 88 | 89 | # 2. landmark detection 90 | newRect = faces[0] 91 | landmark_point_list = landmark68_predictor(imDlib, newRect).parts() 92 | if eyeDistance is None: 93 | eyeDistance = interEyeDistance(landmark_point_list) 94 | 95 | landmark_float_matrix2d = np.array([(p.x, p.y) for p in landmark_point_list], dtype=np.float32) 96 | 97 | # 3. stabilize the points: 98 | if step == 0: 99 | # the first frame has no previous information to stabilize against 100 | stable_points_list.append(landmark_float_matrix2d) 101 | continue 102 | 103 | last_frame = to_grave(frames[step - 1]) 104 | this_frame = to_grave(frames[step]) 105 | last_stable_point = stable_points_list[step - 1] 106 | 107 | if last_stable_point is None: 108 | # no previous stable points, so stabilization is not possible for this frame: 109 | this_stable_points_matrix = landmark_float_matrix2d 110 | else: 111 | this_stable_points_matrix = calc_stable_points(last_frame, this_frame, last_stable_point, landmark_float_matrix2d, eyeDistance) 112 | 113 | stable_points_list.append(this_stable_points_matrix) 114 | return stable_points_list 115 | 116 | 117 | def interEyeDistance(predict): 118 | leftEyeLeftCorner = (predict[36].x, predict[36].y) 119 | rightEyeRightCorner = (predict[45].x, predict[45].y) 120 | distance = cv2.norm(np.array(rightEyeRightCorner) - np.array(leftEyeLeftCorner)) 121 | distance = int(distance) 122 | return distance 123 | 124 | 125 | def to_grave(frame): 126 | return cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) 127 | 128 | 129 | def calc_stable_points(imGrayPrev, imGray, landmarks_pre, landmarks_this, eyeDistance): 130 | # compute smoothing parameters 131 | if eyeDistance > 100: 132 | dotRadius = 3 133 | else: 134 | dotRadius = 2 135 | sigma = eyeDistance * eyeDistance / 400 136 | s = dotRadius * int(eyeDistance / 4) + 1 137 | lk_params = dict(winSize=(s, s), maxLevel=5, criteria=(cv2.TERM_CRITERIA_COUNT | cv2.TERM_CRITERIA_EPS, 20, 0.03)) 138 | 139 | # landmark positions predicted by optical flow 140 | points_optical_flow, _, _ = cv2.calcOpticalFlowPyrLK(imGrayPrev, imGray, landmarks_pre, landmarks_this, **lk_params) 141 | 142 | ans = [] 143 | for i in
range(len(landmarks_this)): 144 | point_now = landmarks_this[i] 145 | point_last = landmarks_pre[i] 146 | point_flo = points_optical_flow[i] 147 | 148 | # distance between the current detection and the previous (stabilized) point 149 | distance = cv2.norm(point_now - point_last) 150 | 151 | # the larger the distance, the less reliable the optical-flow result, so trust the current detection more (smaller weight -> larger 1-weight -> larger share for the current detection) 152 | weight = math.exp(-distance * distance / sigma) 153 | 154 | # weighted combination 155 | final_point_value = (1 - weight) * point_now + weight * point_flo 156 | 157 | ans.append(final_point_value) 158 | return np.array(ans, np.float32) 159 | -------------------------------------------------------------------------------- /utils/eva_emb_full.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from utils import distance_util, path_util 4 | from utils import map_evaluate 5 | from sklearn.metrics import roc_auc_score 6 | import scipy.spatial 7 | import numpy as np 8 | import collections 9 | from utils import pickle_util 10 | import time 11 | 12 | 13 | # class Timer: 14 | # def __init__(self): 15 | # self.last_time = None 16 | # 17 | # def 18 | 19 | 20 | class EmbEva: 21 | 22 | def __init__(self): 23 | pass 24 | 25 | def do_valid(self, model): 26 | obj = {"valid/auc": self.do_verification(model, "./dataset/evals/valid_verification.pkl")} 27 | return obj 28 | 29 | def do_full_test(self, model, gender_constraint=False): 30 | obj = {} 31 | # 1.verification 32 | t1 = time.time() 33 | obj["test/auc"] = self.do_verification(model, "./dataset/evals/test_verification.pkl") 34 | if gender_constraint: 35 | obj["test/auc_g"] = self.do_verification(model, "./dataset/evals/test_verification_g.pkl") 36 | t2 = time.time() 37 | 38 | # 2.retrieval 39 | obj["test/map_v2f"], obj["test/map_f2v"] = self.do_retrival(model, "./dataset/evals/test_retrieval.pkl") 40 | t3 = time.time() 41 | 42 | # 3.matching 43 | obj["test/ms_v2f"], obj["test/ms_f2v"] = self.do_matching(model, "./dataset/evals/test_matching.pkl") 44 | if gender_constraint: 45 | obj["test/ms_v2f_g"], obj["test/ms_f2v_g"] = self.do_matching(model, "./dataset/evals/test_matching_g.pkl") 46 | t4 = time.time() 47 | 48 | print("time spent: auc%.1f map%.1f matching%.1f" % (t2 - t1, t3 - t2, t4 - t3)) 49 | return obj 50 | 51 | def do_1_N_matching(self, model): 52 | data = pickle_util.read_pickle(path_util.look_up("./dataset/evals/test_matching_10.pkl")) 53 | v2emb, f2emb = self.to_emb_dict(model, data["jpg_set"], data["wav_set"]) 54 | key2emb = {**v2emb, **f2emb} 55 | ans = {} 56 | ans["v2f"] = handle_1_n(data["match_list"], is_v2f=True, key2emb=key2emb) 57 | ans["f2v"] = handle_1_n(data["match_list"], is_v2f=False, key2emb=key2emb) 58 | return ans 59 | 60 | def do_matching(self, model, pkl_path): 61 | data = pickle_util.read_pickle(pkl_path) 62 | v2emb, f2emb = self.to_emb_dict(model, data["jpg_set"], data["wav_set"]) 63 | ms_vf, ms_fv = calc_ms(data["match_list"], v2emb, f2emb) 64 | return ms_vf * 100, ms_fv * 100 65 | 66 | def do_verification(self, model, pkl_path): 67 | data = pickle_util.read_pickle(pkl_path) 68 | v2emb, f2emb = self.to_emb_dict(model, data["jpg_set"], data["wav_set"]) 69 | return calc_vrification(data["list"], v2emb, f2emb) * 100 70 | 71 | def do_retrival(self, model, pkl_path): 72 | data = pickle_util.read_pickle(pkl_path) 73 | v2emb, f2emb = self.to_emb_dict(model, data["jpg_set"], data["wav_set"]) 74 | map_vf, map_fv = calc_map_value(data["retrieval_lists"], v2emb, f2emb) 75 | return map_vf * 100, map_fv * 100 76 | 77 | def to_emb_dict(self, model, all_jpg_set,
all_wav_set): 78 | model.eval() 79 | batch_size = 512 80 | image_loader = DataLoader(DataSet(list(all_jpg_set)), batch_size=batch_size, shuffle=False, pin_memory=True) 81 | voice_loader = DataLoader(DataSet(list(all_wav_set)), batch_size=batch_size, shuffle=False, pin_memory=True) 82 | f2emb = get_path2emb(image_loader.dataset.data, model.face_encoder, image_loader) 83 | v2emb = get_path2emb(voice_loader.dataset.data, model.voice_encoder, voice_loader) 84 | model.train() 85 | return v2emb, f2emb 86 | 87 | 88 | def calc_ms(all_data, v2emb, f2emb): 89 | voice1_emb = [] 90 | voice2_emb = [] 91 | face1_emb = [] 92 | face2_emb = [] 93 | 94 | for name1, voice1, face1, name2, voice2, face2 in all_data: 95 | voice1_emb.append(v2emb[voice1]) 96 | voice2_emb.append(v2emb[voice2]) 97 | face1_emb.append(f2emb[face1]) 98 | face2_emb.append(f2emb[face2]) 99 | 100 | voice1_emb = np.array(voice1_emb) 101 | voice2_emb = np.array(voice2_emb) 102 | face1_emb = np.array(face1_emb) 103 | face2_emb = np.array(face2_emb) 104 | 105 | dist_vf1 = distance_util.parallel_distance_cosine_based_distance(voice1_emb, face1_emb) 106 | dist_vf2 = distance_util.parallel_distance_cosine_based_distance(voice1_emb, face2_emb) 107 | dist_fv1 = distance_util.parallel_distance_cosine_based_distance(face1_emb, voice1_emb) 108 | dist_fv2 = distance_util.parallel_distance_cosine_based_distance(face1_emb, voice2_emb) 109 | 110 | vf_result = dist_vf1 < dist_vf2 111 | fv_result = dist_fv1 < dist_fv2 112 | ms_vf = np.mean(vf_result) 113 | ms_fv = np.mean(fv_result) 114 | 115 | obj = { 116 | "dist_vf1": dist_vf1, 117 | "dist_vf2": dist_vf2, 118 | "dist_fv1": dist_fv1, 119 | "dist_fv2": dist_fv2, 120 | "test_data": all_data, # name1, voice1, face1, name2, voice2, face2 121 | "result_fv": fv_result, 122 | "result_vf": vf_result, 123 | "score_vf": ms_vf, 124 | "score_fv": ms_fv, 125 | } 126 | return ms_vf, ms_fv 127 | 128 | 129 | def calc_map_value(retrieval_lists, v2emb, f2emb): 130 | tmp_dic = collections.defaultdict(list) 131 | for arr in retrieval_lists: 132 | map_vf, map_fv = calc_map_recall_at_k(arr, v2emb, f2emb) 133 | tmp_dic["map_vf"].append(map_vf) 134 | tmp_dic["map_fv"].append(map_fv) 135 | map_fv = np.mean(tmp_dic["map_fv"]) 136 | map_vf = np.mean(tmp_dic["map_vf"]) 137 | return map_vf, map_fv 138 | 139 | 140 | # class DataSet(torch.utils.data.Dataset): 141 | # 142 | # def __init__(self, data): 143 | # self.data = data 144 | # 145 | # def __len__(self): 146 | # return len(self.data) 147 | # 148 | # def __getitem__(self, index): 149 | # short_path = self.data[index] 150 | # return input_loader.load_data(short_path), index 151 | from utils.config import face_emb_dict, voice_emb_dict 152 | 153 | 154 | class DataSet(torch.utils.data.Dataset): 155 | 156 | def __init__(self, data): 157 | self.data = data 158 | 159 | def __len__(self): 160 | return len(self.data) 161 | 162 | def __getitem__(self, index): 163 | short_path = self.data[index] 164 | if ".wav" in short_path: 165 | emb = voice_emb_dict[short_path] 166 | else: 167 | emb = face_emb_dict[short_path] 168 | return emb, index 169 | 170 | 171 | def handle_1_n(match_list, is_v2f, key2emb): 172 | tmp_dict = collections.defaultdict(list) 173 | for voices, faces in match_list: 174 | if is_v2f: 175 | prob = voices[0] 176 | gallery = faces 177 | else: 178 | prob = faces[0] 179 | gallery = voices 180 | 181 | # 1. to vector 182 | prob_vec = np.array([key2emb[prob]]) 183 | gallery_vec = np.array([key2emb[i] for i in gallery]) 184 | 185 | # 2. 
calc similarity 186 | distances = scipy.spatial.distance.cdist(prob_vec, gallery_vec, 'cosine') 187 | distances = distances.squeeze() 188 | assert len(distances) == len(gallery_vec) 189 | 190 | # 3. get results of 2~N matching 191 | for index in range(2, len(gallery) + 1): 192 | arr = distances[:index] 193 | is_correct = int(np.argmin(arr) == 0) 194 | tmp_dict[index].append(is_correct) 195 | 196 | for key, arr in tmp_dict.items(): 197 | tmp_dict[key] = np.mean(arr) 198 | return tmp_dict 199 | 200 | 201 | # 202 | 203 | def get_path2emb(all_path_list, encoder, loader): 204 | f2emb = {} 205 | for data, path_indexes in loader: 206 | emb_batch = encoder(data.cuda()).detach().cpu().numpy() 207 | path_indexes = path_indexes.detach().cpu().numpy() 208 | for p_index, emb in zip(path_indexes, emb_batch): 209 | the_path = all_path_list[p_index] 210 | f2emb[the_path] = emb 211 | 212 | return f2emb 213 | 214 | 215 | def cosine_similarity(a, b): 216 | assert len(a.shape) == 2 217 | assert a.shape == b.shape 218 | 219 | ab = np.sum(a * b, axis=1) 220 | # (batch_size,) 221 | 222 | a_norm = np.sqrt(np.sum(a * a, axis=1)) 223 | b_norm = np.sqrt(np.sum(b * b, axis=1)) 224 | cosine = ab / (a_norm * b_norm) 225 | # [-1,1] 226 | prob = (cosine + 1) / 2.0 227 | return prob 228 | 229 | 230 | def calc_vrification(the_list, v2emb, f2emb): 231 | voice_emb = np.array([v2emb[tup[0]] for tup in the_list]) 232 | face_emb = np.array([f2emb[tup[1]] for tup in the_list]) 233 | real_label = np.array([tup[2] for tup in the_list]) 234 | 235 | # AUC 236 | prob = cosine_similarity(voice_emb, face_emb) 237 | auc = roc_auc_score(real_label, prob) 238 | return auc 239 | 240 | 241 | def calc_map_recall_at_k(all_data, v2emb, f2emb): 242 | # 1.get embedding 243 | labels = [] 244 | v_emb_list = [] 245 | f_emb_list = [] 246 | for v, f, name in all_data: 247 | labels.append(name) 248 | v_emb_list.append(v2emb[v]) 249 | f_emb_list.append(f2emb[f]) 250 | 251 | v_emb_list = np.array(v_emb_list) 252 | f_emb_list = np.array(f_emb_list) 253 | 254 | # 2. 
calculate distance 255 | vf_dist = scipy.spatial.distance.cdist(v_emb_list, f_emb_list, 'cosine') 256 | fv_dist = vf_dist.T 257 | 258 | # 3.map value 259 | map_vf = map_evaluate.fx_calc_map_label_v2(vf_dist, labels) 260 | map_fv = map_evaluate.fx_calc_map_label_v2(fv_dist, labels) 261 | return map_vf, map_fv 262 | 263 | 264 | def calc_ms_f2v(all_data, v2emb, f2emb): 265 | voice1_emb = [] 266 | voice2_emb = [] 267 | face1_emb = [] 268 | 269 | for face1, voice1, voice2 in all_data: 270 | voice1_emb.append(v2emb[voice1]) 271 | voice2_emb.append(v2emb[voice2]) 272 | face1_emb.append(f2emb[face1]) 273 | 274 | voice1_emb = np.array(voice1_emb) 275 | voice2_emb = np.array(voice2_emb) 276 | face1_emb = np.array(face1_emb) 277 | 278 | dist_fv1 = distance_util.parallel_distance_cosine_based_distance(face1_emb, voice1_emb) 279 | dist_fv2 = distance_util.parallel_distance_cosine_based_distance(face1_emb, voice2_emb) 280 | 281 | fv_result = dist_fv1 < dist_fv2 282 | ms_fv = np.mean(fv_result) 283 | return ms_fv 284 | 285 | 286 | def calc_ms_v2f(all_data, v2emb, f2emb): 287 | voice1_emb = [] 288 | face1_emb = [] 289 | face2_emb = [] 290 | 291 | for voice1, face1, face2 in all_data: 292 | voice1_emb.append(v2emb[voice1]) 293 | face1_emb.append(f2emb[face1]) 294 | face2_emb.append(f2emb[face2]) 295 | 296 | voice1_emb = np.array(voice1_emb) 297 | face1_emb = np.array(face1_emb) 298 | face2_emb = np.array(face2_emb) 299 | 300 | dist_vf1 = distance_util.parallel_distance_cosine_based_distance(voice1_emb, face1_emb) 301 | dist_vf2 = distance_util.parallel_distance_cosine_based_distance(voice1_emb, face2_emb) 302 | 303 | vf_result = dist_vf1 < dist_vf2 304 | ms_vf = np.mean(vf_result) 305 | return ms_vf 306 | -------------------------------------------------------------------------------- /utils/eval_shortcut.py: -------------------------------------------------------------------------------- 1 | import os 2 | from utils import wb_util, model_util, pickle_util, my_git 3 | from utils import model_selector 4 | 5 | 6 | class Cut(): 7 | 8 | def __init__(self, emb_eva, model, args): 9 | self.modelSelector = model_selector.ModelSelector() 10 | self.emb_eva = emb_eva 11 | self.model = model 12 | self.args = args 13 | self.best_obj = None 14 | 15 | def eval_short_cut(self, test_threshold=80): 16 | emb_eva = self.emb_eva 17 | model = self.model 18 | modelSelector = self.modelSelector 19 | args = self.args 20 | 21 | # 1.do validation 22 | valid_obj = emb_eva.do_valid(model) 23 | modelSelector.log(valid_obj) 24 | indicator = "valid/auc" 25 | 26 | if valid_obj[indicator] < test_threshold: 27 | obj = valid_obj 28 | print("model too weak, skip test") 29 | elif modelSelector.is_best_model(indicator): 30 | # 2. 
do test 31 | test_obj = emb_eva.do_full_test(model, gender_constraint=True) 32 | obj = {**valid_obj, **test_obj} 33 | model_util.delete_last_saved_model() 34 | model_save_name = "auc[%.2f,%.2f]_ms[%.2f,%.2f]_map[%.2f,%.2f].pkl" % ( 35 | obj["valid/auc"], 36 | obj["test/auc"], 37 | obj["test/ms_v2f"], 38 | obj["test/ms_f2v"], 39 | obj["test/map_v2f"], 40 | obj["test/map_f2v"], 41 | ) 42 | model_save_path = os.path.join(args.model_save_folder, args.project, args.name, model_save_name) 43 | model_util.save_model(0, model, None, model_save_path) 44 | pickle_util.save_json(model_save_path + ".json", test_obj) 45 | self.best_obj = obj 46 | else: 47 | obj = valid_obj 48 | print("not best model") 49 | 50 | # 2.log 51 | wb_util.log(obj) 52 | print(obj) 53 | wb_util.init(args) 54 | # my_git.commit_v2(args) 55 | 56 | if modelSelector.should_stop(indicator, args.early_stop): 57 | print("early_stop") 58 | if len(model_util.history_array) > 0: 59 | print(model_util.history_array[-1]) 60 | # 上传best信息 61 | best_obj = {} 62 | for k, v in self.best_obj.items(): 63 | best_obj["best_" + k] = v 64 | wb_util.log(best_obj) 65 | return True 66 | return False 67 | -------------------------------------------------------------------------------- /utils/faceBlendCommon.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 BIG VISION LLC ALL RIGHTS RESERVED 2 | # 3 | # This code is made available to the students of 4 | # the online course titled "Computer Vision for Faces" 5 | # by Satya Mallick for personal non-commercial use. 6 | # 7 | # Sharing this code is strictly prohibited without written 8 | # permission from Big Vision LLC. 9 | # 10 | # For licensing and other inquiries, please email 11 | # spmallick@bigvisionllc.com 12 | # 13 | import cv2 14 | import dlib 15 | import numpy as np 16 | import math 17 | 18 | # Returns 8 points on the boundary of a rectangle 19 | def getEightBoundaryPoints(h, w): 20 | boundaryPts = [] 21 | boundaryPts.append((0,0)) 22 | boundaryPts.append((w/2, 0)) 23 | boundaryPts.append((w-1,0)) 24 | boundaryPts.append((w-1, h/2)) 25 | boundaryPts.append((w-1, h-1)) 26 | boundaryPts.append((w/2, h-1)) 27 | boundaryPts.append((0, h-1)) 28 | boundaryPts.append((0, h/2)) 29 | return np.array(boundaryPts, dtype=np.float) 30 | 31 | 32 | # Constrains points to be inside boundary 33 | def constrainPoint(p, w, h): 34 | p = (min(max(p[0], 0), w - 1), min(max(p[1], 0), h - 1)) 35 | return p 36 | 37 | # convert Dlib shape detector object to list of tuples 38 | def dlibLandmarksToPoints(shape): 39 | points = [] 40 | for p in shape.parts(): 41 | pt = (p.x, p.y) 42 | points.append(pt) 43 | return points 44 | 45 | # Compute similarity transform given two sets of two points. 46 | # OpenCV requires 3 pairs of corresponding points. 47 | # We are faking the third one. 
48 | def similarityTransform(inPoints, outPoints): 49 | s60 = math.sin(60*math.pi/180) 50 | c60 = math.cos(60*math.pi/180) 51 | 52 | inPts = np.copy(inPoints).tolist() 53 | outPts = np.copy(outPoints).tolist() 54 | 55 | # The third point is calculated so that the three points make an equilateral triangle 56 | xin = c60*(inPts[0][0] - inPts[1][0]) - s60*(inPts[0][1] - inPts[1][1]) + inPts[1][0] 57 | yin = s60*(inPts[0][0] - inPts[1][0]) + c60*(inPts[0][1] - inPts[1][1]) + inPts[1][1] 58 | 59 | inPts.append([np.int32(xin), np.int32(yin)]) 60 | 61 | xout = c60*(outPts[0][0] - outPts[1][0]) - s60*(outPts[0][1] - outPts[1][1]) + outPts[1][0] 62 | yout = s60*(outPts[0][0] - outPts[1][0]) + c60*(outPts[0][1] - outPts[1][1]) + outPts[1][1] 63 | 64 | outPts.append([np.int32(xout), np.int32(yout)]) 65 | 66 | # Now we can use estimateRigidTransform for calculating the similarity transform. 67 | tform = cv2.estimateAffinePartial2D(np.array([inPts]), np.array([outPts])) 68 | return tform[0] 69 | 70 | # Normalizes a facial image to a standard size given by outSize. 71 | # Normalization is done based on Dlib's landmark points passed as pointsIn 72 | # After normalization, left corner of the left eye is at (0.3 * w, h/3 ) 73 | # and right corner of the right eye is at ( 0.7 * w, h / 3) where w and h 74 | # are the width and height of outSize. 75 | def normalizeImagesAndLandmarks(outSize, imIn, pointsIn): 76 | h, w = outSize 77 | 78 | # Corners of the eye in input image 79 | if len(pointsIn) == 68: 80 | eyecornerSrc = [pointsIn[36], pointsIn[45]] 81 | elif len(pointsIn) == 5: 82 | eyecornerSrc = [pointsIn[2], pointsIn[0]] 83 | 84 | # Corners of the eye in normalized image 85 | eyecornerDst = [(np.int32(0.3 * w), np.int32(h/3)), 86 | (np.int32(0.7 * w), np.int32(h/3))] 87 | 88 | # Calculate similarity transform 89 | tform = similarityTransform(eyecornerSrc, eyecornerDst) 90 | imOut = np.zeros(imIn.shape, dtype=imIn.dtype) 91 | 92 | # Apply similarity transform to input image 93 | imOut = cv2.warpAffine(imIn, tform, (w, h)) 94 | 95 | # reshape pointsIn from numLandmarks x 2 to numLandmarks x 1 x 2 96 | points2 = np.reshape(pointsIn, (pointsIn.shape[0], 1, pointsIn.shape[1])) 97 | 98 | # Apply similarity transform to landmarks 99 | pointsOut = cv2.transform(points2, tform) 100 | 101 | # reshape pointsOut to numLandmarks x 2 102 | pointsOut = np.reshape(pointsOut, (pointsIn.shape[0], pointsIn.shape[1])) 103 | 104 | return imOut, pointsOut 105 | 106 | # find the point closest to an array of points 107 | # pointsArray is a Nx2 and point is 1x2 ndarray 108 | def findIndex(pointsArray, point): 109 | dist = np.linalg.norm(pointsArray-point, axis=1) 110 | minIndex = np.argmin(dist) 111 | return minIndex 112 | 113 | 114 | # Check if a point is inside a rectangle 115 | def rectContains(rect, point): 116 | if point[0] < rect[0]: 117 | return False 118 | elif point[1] < rect[1]: 119 | return False 120 | elif point[0] > rect[2]: 121 | return False 122 | elif point[1] > rect[3]: 123 | return False 124 | return True 125 | 126 | 127 | # Calculate Delaunay triangles for set of points 128 | # Returns the vector of indices of 3 points for each triangle 129 | def calculateDelaunayTriangles(rect, points): 130 | 131 | # Create an instance of Subdiv2D 132 | subdiv = cv2.Subdiv2D(rect) 133 | 134 | # Insert points into subdiv 135 | for p in points: 136 | subdiv.insert((p[0], p[1])) 137 | 138 | # Get Delaunay triangulation 139 | triangleList = subdiv.getTriangleList() 140 | 141 | # Find the indices of triangles in the points array 
142 | delaunayTri = [] 143 | 144 | for t in triangleList: 145 | # The triangle returned by getTriangleList is 146 | # a list of 6 coordinates of the 3 points in 147 | # x1, y1, x2, y2, x3, y3 format. 148 | # Store triangle as a list of three points 149 | pt = [] 150 | pt.append((t[0], t[1])) 151 | pt.append((t[2], t[3])) 152 | pt.append((t[4], t[5])) 153 | 154 | pt1 = (t[0], t[1]) 155 | pt2 = (t[2], t[3]) 156 | pt3 = (t[4], t[5]) 157 | 158 | if rectContains(rect, pt1) and rectContains(rect, pt2) and rectContains(rect, pt3): 159 | # Variable to store a triangle as indices from list of points 160 | ind = [] 161 | # Find the index of each vertex in the points list 162 | for j in range(0, 3): 163 | for k in range(0, len(points)): 164 | if(abs(pt[j][0] - points[k][0]) < 1.0 and abs(pt[j][1] - points[k][1]) < 1.0): 165 | ind.append(k) 166 | # Store triangulation as a list of indices 167 | if len(ind) == 3: 168 | delaunayTri.append((ind[0], ind[1], ind[2])) 169 | 170 | return delaunayTri 171 | 172 | # Apply affine transform calculated using srcTri and dstTri to src and 173 | # output an image of size. 174 | def applyAffineTransform(src, srcTri, dstTri, size): 175 | 176 | # Given a pair of triangles, find the affine transform. 177 | warpMat = cv2.getAffineTransform(np.float32(srcTri), np.float32(dstTri)) 178 | 179 | # Apply the Affine Transform just found to the src image 180 | dst = cv2.warpAffine(src, warpMat, (size[0], size[1]), None, 181 | flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT_101) 182 | 183 | return dst 184 | 185 | # Warps and alpha blends triangular regions from img1 and img2 to img 186 | def warpTriangle(img1, img2, t1, t2): 187 | # Find bounding rectangle for each triangle 188 | r1 = cv2.boundingRect(np.float32([t1])) 189 | r2 = cv2.boundingRect(np.float32([t2])) 190 | 191 | # Offset points by left top corner of the respective rectangles 192 | t1Rect = [] 193 | t2Rect = [] 194 | t2RectInt = [] 195 | 196 | for i in range(0, 3): 197 | t1Rect.append(((t1[i][0] - r1[0]), (t1[i][1] - r1[1]))) 198 | t2Rect.append(((t2[i][0] - r2[0]), (t2[i][1] - r2[1]))) 199 | t2RectInt.append(((t2[i][0] - r2[0]), (t2[i][1] - r2[1]))) 200 | 201 | # Get mask by filling triangle 202 | mask = np.zeros((r2[3], r2[2], 3), dtype=np.float32) 203 | cv2.fillConvexPoly(mask, np.int32(t2RectInt), (1.0, 1.0, 1.0), 16, 0) 204 | 205 | # Apply warpImage to small rectangular patches 206 | img1Rect = img1[r1[1]:r1[1] + r1[3], r1[0]:r1[0] + r1[2]] 207 | 208 | size = (r2[2], r2[3]) 209 | 210 | img2Rect = applyAffineTransform(img1Rect, t1Rect, t2Rect, size) 211 | 212 | img2Rect = img2Rect * mask 213 | 214 | # Copy triangular region of the rectangular patch to the output image 215 | img2[r2[1]:r2[1]+r2[3], r2[0]:r2[0]+r2[2]] = img2[r2[1]:r2[1]+r2[3], r2[0]:r2[0]+r2[2]] * ((1.0, 1.0, 1.0) - mask) 216 | img2[r2[1]:r2[1]+r2[3], r2[0]:r2[0]+r2[2]] = img2[r2[1]:r2[1]+r2[3], r2[0]:r2[0]+r2[2]] + img2Rect 217 | 218 | # detect facial landmarks in image 219 | def getLandmarks(faceDetector, landmarkDetector, im, FACE_DOWNSAMPLE_RATIO = 1): 220 | points = [] 221 | imSmall = cv2.resize(im,None, 222 | fx=1.0/FACE_DOWNSAMPLE_RATIO, 223 | fy=1.0/FACE_DOWNSAMPLE_RATIO, 224 | interpolation = cv2.INTER_LINEAR) 225 | 226 | faceRects = faceDetector(imSmall, 0) 227 | 228 | if len(faceRects) > 0: 229 | maxArea = 0 230 | maxRect = None 231 | # TODO: test on images with multiple faces 232 | for face in faceRects: 233 | if face.area() > maxArea: 234 | maxArea = face.area() 235 | maxRect = [face.left(), 236 | face.top(), 237 | face.right(), 
238 | face.bottom() 239 | ] 240 | 241 | rect = dlib.rectangle(*maxRect) 242 | scaledRect = dlib.rectangle(int(rect.left()*FACE_DOWNSAMPLE_RATIO), 243 | int(rect.top()*FACE_DOWNSAMPLE_RATIO), 244 | int(rect.right()*FACE_DOWNSAMPLE_RATIO), 245 | int(rect.bottom()*FACE_DOWNSAMPLE_RATIO)) 246 | 247 | landmarks = landmarkDetector(im, scaledRect) 248 | points = dlibLandmarksToPoints(landmarks) 249 | return points 250 | 251 | # Warps an image in a piecewise affine manner. 252 | # The warp is defined by the movement of landmark points specified by pointsIn 253 | # to a new location specified by pointsOut. The triangulation beween points is specified 254 | # by their indices in delaunayTri. 255 | def warpImage(imIn, pointsIn, pointsOut, delaunayTri): 256 | h, w, ch = imIn.shape 257 | # Output image 258 | imOut = np.zeros(imIn.shape, dtype=imIn.dtype) 259 | 260 | # Warp each input triangle to output triangle. 261 | # The triangulation is specified by delaunayTri 262 | for j in range(0, len(delaunayTri)): 263 | # Input and output points corresponding to jth triangle 264 | tin = [] 265 | tout = [] 266 | 267 | for k in range(0, 3): 268 | # Extract a vertex of input triangle 269 | pIn = pointsIn[delaunayTri[j][k]] 270 | # Make sure the vertex is inside the image. 271 | pIn = constrainPoint(pIn, w, h) 272 | 273 | # Extract a vertex of the output triangle 274 | pOut = pointsOut[delaunayTri[j][k]] 275 | # Make sure the vertex is inside the image. 276 | pOut = constrainPoint(pOut, w, h) 277 | 278 | # Push the input vertex into input triangle 279 | tin.append(pIn) 280 | # Push the output vertex into output triangle 281 | tout.append(pOut) 282 | 283 | # Warp pixels inside input triangle to output triangle. 284 | warpTriangle(imIn, imOut, tin, tout) 285 | return imOut 286 | -------------------------------------------------------------------------------- /utils/keops_kmeans.py: -------------------------------------------------------------------------------- 1 | """ 2 | ================================ 3 | K-means clustering - PyTorch API 4 | ================================ 5 | 6 | The :meth:`pykeops.torch.LazyTensor.argmin` reduction supported by KeOps :class:`pykeops.torch.LazyTensor` allows us 7 | to perform **bruteforce nearest neighbor search** with four lines of code. 8 | It can thus be used to implement a **large-scale** 9 | `K-means clustering `_, 10 | **without memory overflows**. 11 | 12 | .. note:: 13 | For large and high dimensional datasets, this script 14 | **outperforms its NumPy counterpart** 15 | as it avoids transfers between CPU (host) and GPU (device) memories. 
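A minimal usage sketch of the helpers defined below; the tensor sizes, ``K``, ``num_cluster_list`` and ``temperature`` are illustrative values rather than settings prescribed elsewhere in this repository::

    x = torch.randn(10000, 128).cuda()
    cl, centroids, dist = KMeans(x, K=256, Niter=10)
    results = run_kmeans(x, num_cluster_list=[256], Niter=10, temperature=0.2)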
16 | 17 | 18 | """ 19 | 20 | ######################################################################## 21 | # Setup 22 | # ----------------- 23 | # Standard imports: 24 | 25 | import time 26 | import torch 27 | import torch.nn as nn 28 | import torch.nn.functional as F 29 | import numpy as np 30 | from matplotlib import pyplot as plt 31 | from pykeops.torch import LazyTensor 32 | from IPython import embed 33 | import ipdb 34 | 35 | 36 | def run_kmeans(x, num_cluster_list, Niter, temperature, verbose=True): 37 | results = {'inst2cluster': [], 'centroids': [], 'Dist': [], 'density': []} 38 | 39 | for K in num_cluster_list: 40 | cl, centroids, Dist = KMeans(x, K, Niter, verbose) 41 | 42 | # sample-to-centroid distances for each cluster 43 | Dist = Dist.cpu() 44 | Dcluster = [[] for c in range(K)] 45 | for im, i in enumerate(cl): 46 | Dcluster[i].append(Dist[im][0]) 47 | 48 | # print('Dcluster', len(Dcluster), len(Dcluster[0]), Dcluster[0][0]) # k, points_per_cluster, dist 49 | 50 | # concentration estimation (phi) 51 | density = np.zeros(K) 52 | for i, dist in enumerate(Dcluster): 53 | if len(dist) > 1: 54 | d = (np.asarray(dist) ** 0.5).mean() / np.log(len(dist) + 10) 55 | density[i] = d 56 | # if cluster only has one point, use the max to estimate its concentration 57 | dmax = density.max() 58 | for i, dist in enumerate(Dcluster): 59 | if len(dist) <= 1: 60 | density[i] = dmax 61 | 62 | density = density.clip( 63 | np.percentile(density, 10), np.percentile(density, 90) 64 | ) # clamp extreme values for stability 65 | density = temperature * density / density.mean() # scale the mean to temperature 66 | 67 | # centroids = F.normalize(centroids, p=2, dim=1) 68 | density = torch.Tensor(density).cuda() 69 | 70 | results['inst2cluster'].append(cl) 71 | results['centroids'].append(centroids) 72 | results['Dist'].append(Dist) 73 | results['density'].append(density) 74 | 75 | return results 76 | 77 | 78 | ######################################################################## 79 | # Simple implementation of the K-means algorithm: 80 | 81 | 82 | def KMeans(x, K=10, Niter=10, verbose=True): 83 | """Implements Lloyd's algorithm for the Euclidean metric.""" 84 | 85 | start = time.time() 86 | N, D = x.shape # Number of samples, dimension of the ambient space 87 | 88 | centroids = x[:K, :].clone() # Simplistic initialization for the centroids 89 | 90 | x_i = LazyTensor(x.view(N, 1, D)) # (N, 1, D) samples 91 | c_j = LazyTensor(centroids.view(1, K, D)) # (1, K, D) centroids 92 | 93 | # K-means loop: 94 | # - x is the (N, D) point cloud, 95 | # - cl is the (N,) vector of class labels 96 | # - c is the (K, D) cloud of cluster centroids 97 | for i in range(Niter): 98 | # E step: assign points to the closest cluster ------------------------- 99 | D_ij = ((x_i - c_j) ** 2).sum(-1) # (N, K) symbolic squared distances 100 | cl = D_ij.argmin(dim=1).long().view(-1) # Points -> Nearest cluster 101 | Dist = D_ij.min(dim=1) 102 | 103 | # M step: update the centroids to the normalized cluster average: ------ 104 | # Compute the sum of points per cluster: 105 | centroids.zero_() 106 | centroids.scatter_add_(0, cl[:, None].repeat(1, D), x) 107 | 108 | # Divide by the number of points per cluster: 109 | Ncl = torch.bincount(cl, minlength=K).type_as(centroids).view(K, 1) 110 | centroids /= Ncl # in-place division to compute the average 111 | 112 | if verbose: # Fancy display ----------------------------------------------- 113 | torch.cuda.synchronize() 114 | end = time.time() 115 | print(f"K-means for the Euclidean metric 
with {N:,} points in dimension {D:,}, K = {K:,}:") 116 | print( 117 | "Timing for {} iterations: {:.5f}s = {} x {:.5f}s\n".format( 118 | Niter, end - start, Niter, (end - start) / Niter 119 | ) 120 | ) 121 | 122 | return cl, centroids, Dist 123 | 124 | 125 | def KMeans_cosine(x, K=10, Niter=10, verbose=True): 126 | """Implements Lloyd's algorithm for the Cosine similarity metric.""" 127 | 128 | start = time.time() 129 | N, D = x.shape # Number of samples, dimension of the ambient space 130 | 131 | c = x[:K, :].clone() # Simplistic initialization for the centroids 132 | # Normalize the centroids for the cosine similarity: 133 | c = torch.nn.functional.normalize(c, dim=1, p=2) 134 | 135 | x_i = LazyTensor(x.view(N, 1, D)) # (N, 1, D) samples 136 | c_j = LazyTensor(c.view(1, K, D)) # (1, K, D) centroids 137 | 138 | # K-means loop: 139 | # - x is the (N, D) point cloud, 140 | # - cl is the (N,) vector of class labels 141 | # - c is the (K, D) cloud of cluster centroids 142 | for i in range(Niter): 143 | # E step: assign points to the closest cluster ------------------------- 144 | S_ij = x_i | c_j # (N, K) symbolic Gram matrix of dot products 145 | cl = S_ij.argmax(dim=1).long().view(-1) # Points -> Nearest cluster 146 | 147 | # M step: update the centroids to the normalized cluster average: ------ 148 | # Compute the sum of points per cluster: 149 | c.zero_() 150 | c.scatter_add_(0, cl[:, None].repeat(1, D), x) 151 | 152 | # Normalize the centroids, in place: 153 | c[:] = torch.nn.functional.normalize(c, dim=1, p=2) 154 | 155 | if verbose: # Fancy display ----------------------------------------------- 156 | torch.cuda.synchronize() 157 | end = time.time() 158 | print(f"K-means for the cosine similarity with {N:,} points in dimension {D:,}, K = {K:,}:") 159 | print( 160 | "Timing for {} iterations: {:.5f}s = {} x {:.5f}s\n".format( 161 | Niter, end - start, Niter, (end - start) / Niter 162 | ) 163 | ) 164 | 165 | return cl, c 166 | -------------------------------------------------------------------------------- /utils/losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/my-yy/vfal-eva/c1ca050d22821bf60fcdca096429edb193df2ae6/utils/losses/__init__.py -------------------------------------------------------------------------------- /utils/losses/barlow_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import ipdb 4 | 5 | 6 | class BarlowTwinsLoss(torch.nn.Module): 7 | 8 | def __init__(self, lambda_param=5e-3): 9 | super(BarlowTwinsLoss, self).__init__() 10 | self.lambda_param = lambda_param 11 | 12 | def forward(self, z_a: torch.Tensor, z_b: torch.Tensor): 13 | # normalize repr. 
along the batch dimension 14 | z_a_norm = (z_a - z_a.mean(0)) / z_a.std(0) # NxD 15 | z_b_norm = (z_b - z_b.mean(0)) / z_b.std(0) # NxD 16 | 17 | N = z_a.size(0) 18 | D = z_a.size(1) 19 | 20 | # cross-correlation matrix 21 | c = torch.mm(z_a_norm.T, z_b_norm) / N # DxD 22 | # loss 23 | c_diff = (c - torch.eye(D, device="cuda")).pow(2) # DxD 24 | # multiply off-diagonal elems of c_diff by lambda 25 | c_diff[~torch.eye(D, dtype=bool)] *= self.lambda_param 26 | 27 | loss = c_diff.sum() 28 | 29 | return loss 30 | 31 | 32 | class MyBarlowTwinsLoss(nn.Module): 33 | def __init__(self, feature_dim, num_out_dim): 34 | super(MyBarlowTwinsLoss, self).__init__() 35 | self.in_feats = feature_dim 36 | self.W = torch.nn.Parameter(torch.randn(feature_dim, num_out_dim)) 37 | nn.init.xavier_normal_(self.W, gain=1) 38 | self.fun_barlow = BarlowTwinsLoss() 39 | 40 | def forward(self, v_emb, f_emb): 41 | v_out = torch.mm(v_emb, self.W) 42 | f_out = torch.mm(f_emb, self.W) 43 | loss = self.fun_barlow(v_out, f_out) 44 | return loss 45 | -------------------------------------------------------------------------------- /utils/losses/center_loss_eccv16.py: -------------------------------------------------------------------------------- 1 | # https://github.com/jxgu1016/MNIST_center_loss_pytorch 2 | # A Discriminative Feature Learning Approach for Deep Face Recognition,ECCV,2016 3 | # import ipdb 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd.function import Function 7 | 8 | 9 | class CenterLoss(nn.Module): 10 | def __init__(self, num_classes, feat_dim, size_average=True): 11 | super(CenterLoss, self).__init__() 12 | self.centers = nn.Parameter(torch.randn(num_classes, feat_dim)) 13 | self.centerlossfunc = CenterlossFunc.apply 14 | self.feat_dim = feat_dim 15 | self.size_average = size_average 16 | 17 | def forward(self, feat, label): 18 | batch_size = feat.size(0) 19 | feat = feat.view(batch_size, -1) 20 | # To check the dim of centers and features 21 | # ipdb.set_trace() 22 | if feat.size(1) != self.feat_dim: 23 | raise ValueError("Center's dim: {0} should be equal to input feature's dim: {1}".format(self.feat_dim, feat.size(1))) 24 | batch_size_tensor = feat.new_empty(1).fill_(batch_size if self.size_average else 1) 25 | loss = self.centerlossfunc(feat, label, self.centers, batch_size_tensor) 26 | return loss 27 | 28 | 29 | class CenterlossFunc(Function): 30 | @staticmethod 31 | def forward(ctx, feature, label, centers, batch_size): 32 | ctx.save_for_backward(feature, label, centers, batch_size) 33 | centers_batch = centers.index_select(0, label.long()) 34 | return (feature - centers_batch).pow(2).sum() / 2.0 / batch_size 35 | 36 | @staticmethod 37 | def backward(ctx, grad_output): 38 | feature, label, centers, batch_size = ctx.saved_tensors 39 | centers_batch = centers.index_select(0, label.long()) 40 | diff = centers_batch - feature 41 | # init every iteration 42 | counts = centers.new_ones(centers.size(0)) 43 | ones = centers.new_ones(label.size(0)) 44 | grad_centers = centers.new_zeros(centers.size()) 45 | 46 | counts = counts.scatter_add_(0, label.long(), ones) 47 | grad_centers.scatter_add_(0, label.unsqueeze(1).expand(feature.size()).long(), diff) 48 | grad_centers = grad_centers / counts.view(-1, 1) 49 | return - grad_output * diff / batch_size, None, grad_centers / batch_size, None 50 | 51 | 52 | def main(test_cuda=False): 53 | print('-' * 80) 54 | device = torch.device("cuda" if test_cuda else "cpu") 55 | ct = CenterLoss(10, 2, size_average=True).to(device) 56 | y = torch.Tensor([0, 
0, 2, 1]).to(device) 57 | feat = torch.zeros(4, 2).to(device).requires_grad_() 58 | print(list(ct.parameters())) 59 | print(ct.centers.grad) 60 | out = ct(feat, y) 61 | print(out.item()) 62 | out.backward() 63 | print(ct.centers.grad) 64 | print(feat.grad) 65 | 66 | 67 | if __name__ == '__main__': 68 | torch.manual_seed(999) 69 | main(test_cuda=False) 70 | if torch.cuda.is_available(): 71 | main(test_cuda=True) 72 | 73 | # centerloss = CenterLoss(num_classes=10, feat_dim=2).to(device) 74 | -------------------------------------------------------------------------------- /utils/losses/center_loss_learnableW_L2dist.py: -------------------------------------------------------------------------------- 1 | import ipdb 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | # Plain center loss; the matrix of class centers is a learnable parameter 7 | class CenterLoss(nn.Module): 8 | 9 | def __init__(self, num_classes, feat_dim, use_gpu): 10 | super(CenterLoss, self).__init__() 11 | self.num_classes = num_classes 12 | self.feat_dim = feat_dim 13 | self.use_gpu = use_gpu 14 | 15 | if self.use_gpu: 16 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda()) 17 | else: 18 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim)) 19 | 20 | def forward(self, x, labels): 21 | """ 22 | Args: 23 | x: feature matrix with shape (batch_size, feat_dim). 24 | labels: ground truth labels with shape (batch_size). 25 | """ 26 | batch_size = x.size(0) 27 | 28 | # 1. Compute the squared Euclidean distances to the class centers 29 | # 1) first compute a^2 + b^2 30 | A = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) 31 | B = torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t() 32 | distmat = A + B 33 | 34 | # 2) then subtract 2ab 35 | distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2) 36 | 37 | # 2. Pick out the distance at each sample's ground-truth label: 38 | classes = torch.arange(self.num_classes).long() 39 | if self.use_gpu: 40 | classes = classes.cuda() 41 | 42 | labels = labels.unsqueeze(1).expand(batch_size, self.num_classes) 43 | mask = labels.eq(classes.expand(batch_size, self.num_classes)) 44 | dist = distmat * mask.float() 45 | 46 | loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size 47 | 48 | return loss 49 | 50 | 51 | if __name__ == "__main__": 52 | from utils import seed_util 53 | 54 | seed_util.set_seed(10086) 55 | batch_size = 3 56 | feature_dim = 128 57 | num_class = 4 58 | 59 | embedding = torch.rand([batch_size, feature_dim]) 60 | label = torch.LongTensor([0, 1, 3]) 61 | 62 | fun = CenterLoss(num_class, feature_dim, use_gpu=False) 63 | loss = fun(embedding, label) 64 | print(loss) 65 | -------------------------------------------------------------------------------- /utils/losses/cmpc_loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | 7 | class IR_CMPC(nn.Module): 8 | def __init__(self, temperature, delta, ka, R): 9 | super(IR_CMPC, self).__init__() 10 | self.inst_inst_criterion = InstInstCLR(temperature=temperature).cuda() 11 | self.inst_proto_criterion = InstProtoCLR(temperature=temperature).cuda() 12 | self.delta = delta 13 | self.ka = ka 14 | self.R = R 15 | 16 | def forward(self, audio_emb, frame_emb, audio_cluster_result, frame_cluster_result, video_index): 17 | bs = audio_emb.size(0) 18 | 19 | loss_v2f_ii = self.inst_inst_criterion(audio_emb, frame_emb) 20 | loss_f2v_ii = self.inst_inst_criterion(frame_emb, audio_emb) 21 | 22 | loss_v2f_ip = 
self.inst_proto_criterion(audio_emb, frame_cluster_result, video_index) 23 | loss_f2v_ip = self.inst_proto_criterion(frame_emb, audio_cluster_result, video_index) 24 | 25 | # print(loss_v2f_ii.shape, loss_f2v_ii.shape, loss_v2f_ip.shape, loss_f2v_ip.shape) 26 | w = torch.ones(bs).cuda() 27 | if frame_cluster_result is not None: 28 | features_audio = F.normalize(audio_emb, dim=1) 29 | features_frame = F.normalize(frame_emb, dim=1) 30 | audio_frame_matrix = torch.mm(features_audio, features_frame.transpose(0, 1)).detach().cpu().numpy() 31 | 32 | inst2cluster_matrix = np.zeros((bs, bs)) 33 | 34 | for i in range(self.R): 35 | inst2cluster_voice = audio_cluster_result['inst2cluster'][i][video_index] 36 | inst2cluster_face = frame_cluster_result['inst2cluster'][i][video_index] 37 | inst2cluster_matrix += ( 38 | torch.mm( 39 | audio_cluster_result['centroids'][i][inst2cluster_voice], 40 | frame_cluster_result['centroids'][i][inst2cluster_face].transpose(0, 1), 41 | ) 42 | .cpu() 43 | .numpy() 44 | ) 45 | 46 | rho = audio_frame_matrix - inst2cluster_matrix 47 | rho = rho.diagonal() 48 | 49 | sorted_rho = np.sort(rho) 50 | argsort_rho = np.argsort(rho) 51 | 52 | mu = rho.mean() + self.delta * rho.std() 53 | sigma = rho.std() * self.ka ** (1 / 2) 54 | 55 | y = (1 / (np.sqrt(2 * np.pi) * sigma)) * np.exp(-0.5 * (1 / sigma * (sorted_rho - mu)) ** 2) 56 | y = y.cumsum() 57 | y /= y[-1] 58 | 59 | w = y[np.argsort(argsort_rho)] 60 | w = torch.tensor(w).cuda() 61 | 62 | loss_v2f = torch.sum(w * (loss_v2f_ii + loss_v2f_ip)) / torch.sum(w) 63 | loss_f2v = torch.sum(w * (loss_f2v_ii + loss_f2v_ip)) / torch.sum(w) 64 | 65 | loss = loss_v2f + loss_f2v 66 | 67 | return loss 68 | 69 | 70 | class InstInstCLR(nn.Module): 71 | def __init__(self, temperature): 72 | super(InstInstCLR, self).__init__() 73 | self.ce = nn.CrossEntropyLoss(reduction='none') 74 | self.temperature = temperature 75 | 76 | def forward(self, anchor, pos): 77 | batch_size = anchor.size(0) 78 | 79 | anchor = F.normalize(anchor) # (bs, out_dim) 80 | 81 | pos = F.normalize(pos) # (bs, out_dim) 82 | 83 | similarity_matrix = torch.matmul(anchor, pos.T) # (bs, bs) 84 | # mask the main diagonal for positives 85 | mask = torch.eye(batch_size, dtype=torch.bool) # (bs, bs) 86 | 87 | assert similarity_matrix.shape == mask.shape 88 | 89 | # select and combine multiple positives 90 | positives = similarity_matrix[mask].view(batch_size, -1) # (bs, 1) 91 | 92 | # select only the negatives the negatives 93 | 94 | negatives = similarity_matrix[~mask].view(batch_size, -1) # (bs, bs-1) 95 | # combine pos and neg 96 | logits = torch.cat([positives, negatives], dim=1) # (bs, bs) 97 | 98 | labels = torch.zeros(batch_size, dtype=torch.long).cuda() # (bs) 99 | 100 | logits = logits / self.temperature 101 | 102 | loss = self.ce(logits, labels) 103 | return loss 104 | 105 | 106 | class InstProtoCLR(nn.Module): 107 | def __init__(self, temperature): 108 | super(InstProtoCLR, self).__init__() 109 | self.ce = nn.CrossEntropyLoss(reduction='none') 110 | self.temperature = temperature 111 | 112 | def forward(self, anchor, cluster_result=None, index=None): 113 | batch_size = anchor.size(0) 114 | 115 | loss_proto = torch.zeros(batch_size).cuda() 116 | if cluster_result is None: 117 | return loss_proto 118 | 119 | anchor = F.normalize(anchor) # (bs, out_dim) 120 | 121 | for n, (inst2cluster, prototypes, density) in enumerate(zip(cluster_result['inst2cluster'], cluster_result['centroids'], cluster_result['density'])): 122 | prototypes = F.normalize(prototypes) 123 | 124 | # get 
positive prototypes 125 | 126 | pos_proto_id = inst2cluster[index] 127 | pos_prototypes = prototypes[pos_proto_id] 128 | 129 | # embed() 130 | proto_similarity_matrix = torch.matmul(anchor, pos_prototypes.T) # [bs, dim]x[dim, bs]->[bs, bs] 131 | # mask the main diagonal for positives 132 | mask = torch.eye(batch_size, dtype=torch.bool) # (bs, bs) 133 | assert proto_similarity_matrix.shape == mask.shape 134 | 135 | # select and combine multiple positives 136 | proto_positives = proto_similarity_matrix[mask].view(batch_size, -1) # (bs, 1) 137 | # select only the negatives 138 | proto_negatives = proto_similarity_matrix[~mask].view(batch_size, -1) # (bs, bs-1) 139 | # combine pos and neg 140 | proto_logits = torch.cat([proto_positives, proto_negatives], dim=1) # (bs, bs) 141 | # targets for prototype assignment 142 | proto_labels = torch.zeros(batch_size, dtype=torch.long).cuda() # (bs) 143 | 144 | # scaling temperatures for the selected prototypes 145 | temp_proto = density[pos_proto_id] 146 | proto_logits /= temp_proto 147 | loss_proto += self.ce(proto_logits, proto_labels) 148 | 149 | # average loss across all sets of prototypes 150 | loss_proto /= len(cluster_result) 151 | 152 | return loss_proto 153 | -------------------------------------------------------------------------------- /utils/losses/fop_loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | class OrthogonalProjectionLoss(nn.Module): 7 | def __init__(self): 8 | super(OrthogonalProjectionLoss, self).__init__() 9 | self.device = torch.device('cuda') 10 | 11 | def forward(self, features, labels=None): 12 | features = F.normalize(features, p=2, dim=1) 13 | 14 | labels = labels[:, None] 15 | 16 | mask = torch.eq(labels, labels.t()).bool().to(self.device) 17 | eye = torch.eye(mask.shape[0], mask.shape[1]).bool().to(self.device) 18 | 19 | mask_pos = mask.masked_fill(eye, 0).float() 20 | mask_neg = (~mask).float() 21 | dot_prod = torch.matmul(features, features.t()) 22 | 23 | pos_pairs_mean = (mask_pos * dot_prod).sum() / (mask_pos.sum() + 1e-6) 24 | neg_pairs_mean = torch.abs(mask_neg * dot_prod).sum() / (mask_neg.sum() + 1e-6) 25 | 26 | loss = (1.0 - pos_pairs_mean) + (0.7 * neg_pairs_mean) 27 | 28 | return loss 29 | -------------------------------------------------------------------------------- /utils/losses/my_pml_infonce_v2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import ipdb 3 | import numpy 4 | 5 | 6 | # Goal: only enlarge the AP / AN pair sets, without adding too many extra anchors 7 | 8 | def my_info_nce(anchor, pos_neg, l_anchor, l_pos_neg, temperature=0.07, reduction="mean"): 9 | # 1. Enumerate all feasible pair combinations 10 | a1, p, a2, n = get_all_pairs_indices(l_anchor, l_pos_neg) 11 | # (a1, p): (row index, column index) of each anchor-positive pair in sim_matrix 12 | # (a2, n): (row index, column index) of each anchor-negative pair in sim_matrix 13 | 14 | # 2. Compute the cosine similarity matrix 15 | sim_matrix = cosine_similarity(anchor, pos_neg) 16 | 17 | # 3. Gather the AP and AN similarity values 18 | pos_pairs = sim_matrix[a1, p] 19 | # [total_AP]; this count acts as the effective batch size 20 | neg_pairs = sim_matrix[a2, n] 21 | # [total_AN] 22 | 23 | # 4. Scale both sets of similarities by the temperature 24 | pos_pairs = pos_pairs.unsqueeze(1) / temperature 25 | neg_pairs = neg_pairs / temperature 26 | 27 | # 5. Build the negative set for each (a, p) pair 28 | 29 | # 5.1 First assume every negative is usable, which gives a matrix of size [total_AP, total_AN] 30 | neg_pairs_2D = neg_pairs.repeat(len(pos_pairs), 1) 31 | 32 | # 5.2 Build a mask that is 1 where the (a, p) and (a, n) pairs share the same anchor, 0 elsewhere 33 | mask = a1.unsqueeze(1) == a2.unsqueeze(0) 34 | # ipdb.set_trace() 35 | # 
[total_AP,total_AN] 36 | 37 | # 5.3 Set the scores at the masked-out (zero) positions to -inf so those pairs are ignored 38 | neg_pairs_2D[mask == 0] = torch.finfo(torch.float32).min 39 | 40 | # 6. With the numerator and denominator in place, compute -log(exp(pos) / sum(exp(...))) 41 | loss = neg_log_exp(pos_pairs, neg_pairs_2D) 42 | # ipdb.set_trace() 43 | if reduction == "mean": 44 | return loss.mean() 45 | return loss, a1 46 | 47 | 48 | def neg_log_exp(pos_pairs, neg_pairs): 49 | # ipdb.set_trace() 50 | max_val = torch.max(pos_pairs, torch.max(neg_pairs, dim=1, keepdim=True)[0]).detach() 51 | 52 | numerator = torch.exp(pos_pairs - max_val).squeeze(1) 53 | 54 | denominator = torch.sum(torch.exp(neg_pairs - max_val), dim=1) + numerator 55 | 56 | p = numerator / denominator 57 | 58 | loss = - torch.log(p + torch.finfo(torch.float32).tiny) 59 | return loss 60 | 61 | 62 | def get_all_pairs_indices(labels_anchor, labels_ref): 63 | # Generate every AP / AN index combination from the labels 64 | labels1 = labels_anchor.unsqueeze(1) 65 | labels2 = labels_ref.unsqueeze(0) 66 | matches = (labels1 == labels2).byte() 67 | # [4,10] 68 | 69 | diffs = matches ^ 1 70 | # matches.fill_diagonal_(0) # deliberately commented out here 71 | a1_idx, p_idx = torch.where(matches) 72 | a2_idx, n_idx = torch.where(diffs) 73 | 74 | return a1_idx, p_idx, a2_idx, n_idx 75 | 76 | 77 | def cosine_similarity(a, pn): 78 | a = torch.nn.functional.normalize(a) 79 | pn = torch.nn.functional.normalize(pn) 80 | similarity_matrix = torch.matmul(a, pn.T) 81 | return similarity_matrix 82 | 83 | 84 | class InfoNCE(torch.nn.Module): 85 | def __init__(self, temperature, reduction): 86 | super(InfoNCE, self).__init__() 87 | self.temperature = temperature 88 | self.reduction = reduction 89 | 90 | def forward(self, emb_anchor, emb_pn, label_anchor, label_pn): 91 | return my_info_nce(emb_anchor, emb_pn, label_anchor, label_pn, self.temperature, self.reduction) 92 | 93 | 94 | def set_seed(seed): 95 | random.seed(seed) 96 | numpy.random.seed(seed) 97 | 98 | torch.manual_seed(seed) 99 | torch.cuda.manual_seed(seed) 100 | 101 | torch.backends.cudnn.deterministic = True 102 | torch.backends.cudnn.benchmark = False 103 | # https://pytorch.org/docs/stable/notes/randomness.html 104 | 105 | 106 | if __name__ == "__main__": 107 | import random 108 | import numpy 109 | import torch 110 | 111 | numpy.set_printoptions(linewidth=180, precision=5, suppress=True) 112 | torch.set_printoptions(linewidth=180, sci_mode=False) 113 | set_seed(10086) 114 | emb_dim = 2 115 | embedding_anchor = torch.rand([10, emb_dim]) 116 | embedding_pn = torch.rand([20, emb_dim]) 117 | l1 = torch.LongTensor([1, 2, 1, 3, 0, 1, 2, 1, 2, 0]) 118 | l2 = torch.LongTensor([1, 2, 1, 3, 0, 1, 2, 1, 2, 0, 1, 2, 1, 3, 0, 1, 2, 1, 2, 0]) 119 | 120 | print(my_info_nce(embedding_anchor, embedding_pn, l1, l2)) 121 | -------------------------------------------------------------------------------- /utils/losses/softmax_loss.py: -------------------------------------------------------------------------------- 1 | # Cross-entropy (softmax) loss with a learnable classification matrix W 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class SoftmaxLoss(nn.Module): 7 | def __init__(self, feature_dim, num_class, reduction="mean"): 8 | super(SoftmaxLoss, self).__init__() 9 | self.in_feats = feature_dim 10 | self.W = torch.nn.Parameter(torch.randn(feature_dim, num_class)) 11 | self.cel = nn.CrossEntropyLoss(reduction=reduction) 12 | nn.init.xavier_normal_(self.W, gain=1) 13 | 14 | def forward(self, embedding, labels): 15 | assert embedding.size()[0] == labels.size()[0] 16 | assert embedding.size()[1] == self.in_feats 17 | logits = torch.mm(embedding, self.W) 18 | loss = self.cel(logits, labels) 19 | return loss 20 | 21 | 
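# Illustrative wiring for this loss (a sketch only: `encoder`, `voice`, `face`, `id_label`
# and `num_train_identities` are placeholder names; the call pattern mirrors how
# works/11_SS_DIM_VFMR_Barlow.py feeds concatenated voice/face embeddings to SoftmaxLoss):
#
#   criterion = SoftmaxLoss(feature_dim=128, num_class=num_train_identities).cuda()
#   v_emb, f_emb = encoder(voice, face)                # each of shape (batch, 128)
#   cat_emb = torch.cat([v_emb, f_emb], dim=0)         # (2 * batch, 128)
#   cat_id = torch.cat([id_label, id_label], dim=0)    # (2 * batch,)
#   loss_id = criterion(cat_emb, cat_id)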
22 | if __name__ == "__main__": 23 | feature_dim = 128 24 | num_class = 10 25 | batch_size = 3 26 | 27 | from utils import seed_util 28 | 29 | seed_util.set_seed(10086) 30 | embedding = torch.randn(batch_size, feature_dim) 31 | label = torch.randint(0, num_class, (batch_size,), dtype=torch.long) 32 | 33 | criteria = SoftmaxLoss(feature_dim, num_class).cuda() 34 | loss = criteria(embedding.cuda(), label.cuda()) 35 | print(loss) 36 | -------------------------------------------------------------------------------- /utils/losses/triplet_hq1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import ipdb 3 | 4 | def l2_distance(v, f): 5 | # [batch,emb] -> [batch,1] 6 | return torch.sqrt((v - f).pow(2).sum(dim=1, keepdim=True) + 1e-12) 7 | 8 | 9 | def triplet_loss(v_origin, f_origin, label, alpha=0.6, beta=0.2, lamd=0.1): 10 | # Normalize to unit vectors 11 | v = torch.nn.functional.normalize(v_origin, p=2, dim=1) 12 | f = torch.nn.functional.normalize(f_origin, p=2, dim=1) 13 | 14 | label = label.squeeze() 15 | # [batch] 16 | 17 | # Compute pairwise Euclidean distances 18 | l2_dist_matrix = torch.cdist(v, f) 19 | dis_ap = torch.diagonal(l2_dist_matrix) 20 | 21 | # Find the hardest negative for each anchor: 22 | mask = (label.unsqueeze(1) == label.unsqueeze(0)).byte() 23 | # entries that share a label are set to 1 24 | MAX_VAL = l2_dist_matrix.max() 25 | tmp_l2_dist_matrix = mask * MAX_VAL + l2_dist_matrix 26 | 27 | # then take the minimum of each row: 28 | dis_hardest_an, indices = tmp_l2_dist_matrix.min(dim=1) 29 | 30 | # ipdb.set_trace() 31 | # First loss term 32 | loss1 = torch.nn.functional.relu(alpha + dis_ap - dis_hardest_an) 33 | 34 | # Second loss term (currently disabled) 35 | # emb_hardest = f[indices] 36 | # dis_pn = l2_distance(f, emb_hardest) 37 | # loss2 = torch.nn.functional.relu(beta - dis_pn) 38 | 39 | # loss = loss1 + lamd * loss2 40 | loss = loss1 41 | return loss.mean() 42 | 43 | 44 | if __name__ == "__main__": 45 | from utils import seed_util 46 | 47 | seed_util.set_seed(1) 48 | import ipdb 49 | 50 | batch_size = 5 51 | v = torch.rand(batch_size, 10).cuda() 52 | f = torch.rand(batch_size, 10).cuda() 53 | label = torch.LongTensor([0, 0, 1, 2, 3]).cuda() 54 | triplet_loss(v, f, label) 55 | -------------------------------------------------------------------------------- /utils/losses/triplet_lafv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import ipdb 4 | 5 | 6 | def triplet_loss(v, f_p, f_n): 7 | d_ap = l2_distance(v, f_p) 8 | d_an = l2_distance(v, f_n) 9 | cated = torch.cat([d_ap, d_an], dim=1) 10 | probs = torch.nn.functional.softmax(cated, dim=1) 11 | batch_size = len(v) 12 | target = torch.FloatTensor([[0, 1]] * batch_size).cuda() 13 | loss = l2_distance(probs, target) 14 | loss_mean = loss.mean() 15 | return loss_mean 16 | 17 | 18 | def l2_distance(v, f): 19 | # [batch,emb] -> [batch,1] 20 | return torch.sqrt(torch.sum((v - f) ** 2, dim=1, keepdim=True) + 1e-12) 21 | 22 | 23 | if __name__ == "__main__": 24 | from utils import seed_util 25 | 26 | seed_util.set_seed(1) 27 | import ipdb 28 | 29 | v = torch.rand(4, 10).cuda() 30 | fp = torch.rand(4, 10).cuda() 31 | fn = torch.rand(4, 10).cuda() 32 | triplet_loss(v, fp, fn) 33 | -------------------------------------------------------------------------------- /utils/losses/unsup_nce.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | 7 | class InfoNCE(nn.Module): 8 | def __init__(self, temperature, 
reduction="mean"): 9 | super(InfoNCE, self).__init__() 10 | self.cel = nn.CrossEntropyLoss(reduction=reduction) 11 | self.temperature = temperature 12 | 13 | # The two batches are aligned row by row: the i-th anchor and the i-th positive belong together 14 | def forward(self, anchor, positive, need_logits=False): 15 | batch_size = anchor.size(0) 16 | # 1. Normalize to unit vectors 17 | anchor = F.normalize(anchor) 18 | positive = F.normalize(positive) 19 | 20 | # 2. Compute the similarity matrix, which serves as the logits 21 | similarity_matrix = torch.matmul(anchor, positive.T) # (bs, bs) 22 | 23 | # 3. The diagonal entries are the classification targets 24 | logits = similarity_matrix / self.temperature 25 | labels = torch.LongTensor([i for i in range(batch_size)]).cuda() 26 | loss = self.cel(logits, labels) 27 | if need_logits: 28 | return loss, logits 29 | return loss 30 | -------------------------------------------------------------------------------- /utils/losses/wen_explicit_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | # Loss function of equation (10) of paper "Seeking the Shape of Sound: An Adaptive Framework for Learning Voice-Face Association" 7 | 8 | # From https://github.com/KID-7391/seeking-the-shape-of-sound models/backbones.py 9 | 10 | def cross_logit(x, v): 11 | # x and v are two sets of embeddings (e.g. voice and face); row i of each belongs to the same person 12 | dist = l2dist(F.normalize(x).unsqueeze(0), v.unsqueeze(1)) 13 | # [batch,batch] 14 | # By default x and v are not unit vectors; F.normalize turns x into unit vectors 15 | 16 | one_hot = torch.zeros(dist.size()).to(x.device) 17 | # all-zero matrix, [batch,batch] 18 | 19 | one_hot.scatter_(1, torch.arange(len(x)).view(-1, 1).long().to(x.device), 1) 20 | # set the diagonal to 1 21 | 22 | pos = (one_hot * dist).sum(-1, keepdim=True) 23 | # extract the diagonal values: the distance between the same person's voice and face 24 | 25 | logit = (1.0 - one_hot) * (dist - pos) 26 | # how much larger the "different-person" voice-face distance is than the "same-person" one 27 | 28 | loss = torch.log(1 + torch.exp(logit).sum(-1) + 3.4) 29 | 30 | return loss 31 | 32 | 33 | def l2dist(a, b): 34 | # L2 distance 35 | dist = (a * b).sum(-1) 36 | return dist 37 | -------------------------------------------------------------------------------- /utils/losses/wen_reweight.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import copy 3 | 4 | import ipdb 5 | import numpy as np 6 | 7 | 8 | def not_zero_count(the_dict): 9 | return len([v for v in the_dict.values() if not np.isclose(v, 0.0)]) 10 | 11 | 12 | # old_weights: 13 | # 1) empty: init the dict; the bottom (easiest) 30% get weight 1.0, the others 0.0 14 | # 2) not empty: set the weights of k (22) identities with old_weight == 0 to 1.0; decay the others by alpha (0.99) 15 | def update_weight(old_weights, hardness, init_ratio=0.3, k=22, alpha=0.99): 16 | # 1.sorted by hardness 17 | # [hard,....,easy] 18 | tup_list = [(k, v) for k, v in hardness.items()] 19 | tup_list.sort(key=lambda x: x[-1], reverse=True) 20 | sorted_identities_list = [tup[0] for tup in tup_list] 21 | 22 | if len(old_weights) == 0: # init 23 | print("init weight dict") 24 | new_weights = {} 25 | rare_point = int(len(sorted_identities_list) * (1 - init_ratio)) 26 | for i in range(len(sorted_identities_list)): 27 | if i > rare_point: 28 | new_weights[i] = 1 29 | else: 30 | new_weights[i] = 0 31 | return new_weights 32 | 33 | # 2.find identities with zero weights: 34 | # [hard,....,easy] 35 | zero_weight_identities = [i for i in sorted_identities_list if np.isclose(old_weights[i], 0.0)] 36 | 37 | # 3.find the easiest 22 identities: 38 | if len(zero_weight_identities) > k: 39 | zero_weight_identities_bottom = zero_weight_identities[int(len(zero_weight_identities) - k):] 
40 | assert len(zero_weight_identities_bottom) == k 41 | else: 42 | zero_weight_identities_bottom = zero_weight_identities 43 | zero_weight_identities_bottom = set(zero_weight_identities_bottom) 44 | 45 | # 4.assign new weights: 46 | new_weights = copy.deepcopy(old_weights) 47 | for k, v in new_weights.items(): 48 | if k in zero_weight_identities_bottom: 49 | new_v = 1.0 50 | else: 51 | new_v = v * alpha 52 | new_weights[k] = new_v 53 | return new_weights 54 | 55 | 56 | def update_hardness(old_hardness, label_list, loss_list, beta=0.9): 57 | new_hardness = copy.deepcopy(old_hardness) 58 | for label, loss in zip(label_list, loss_list): 59 | old_h = old_hardness.get(label, loss) 60 | new_h = beta * old_h + (1 - beta) * loss 61 | new_hardness[label] = new_h 62 | return new_hardness 63 | -------------------------------------------------------------------------------- /utils/map_evaluate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy 3 | import scipy.spatial 4 | 5 | 6 | def cos_dist(query_matrix, result_matrix): 7 | return scipy.spatial.distance.cdist(query_matrix, result_matrix, 'cosine') 8 | 9 | 10 | def fx_calc_map_label(query_matrix, result_matrix, labels, k=0, dist_method='COS'): 11 | if dist_method == 'L2': 12 | dist = scipy.spatial.distance.cdist(query_matrix, result_matrix, 'euclidean') 13 | elif dist_method == 'COS': 14 | dist = scipy.spatial.distance.cdist(query_matrix, result_matrix, 'cosine') 15 | ord = dist.argsort() 16 | numcases = dist.shape[0] 17 | if k == 0: 18 | k = numcases 19 | res = [] 20 | 21 | for i in range(numcases): 22 | order = ord[i] 23 | p = 0.0 24 | r = 0.0 25 | for j in range(k): 26 | if labels[i] == labels[order[j]]: 27 | r += 1 28 | p += (r / (j + 1)) 29 | if r > 0: 30 | res += [p / r] 31 | else: 32 | res += [0] 33 | 34 | return np.mean(res) 35 | 36 | 37 | def fx_calc_map_label_v2(dist, label, k=0): 38 | ord = dist.argsort() 39 | numcases = dist.shape[0] 40 | if k == 0: 41 | k = numcases 42 | res = [] 43 | 44 | for i in range(numcases): 45 | order = ord[i] 46 | p = 0.0 47 | r = 0.0 48 | for j in range(k): 49 | if label[i] == label[order[j]]: 50 | r += 1 51 | p += (r / (j + 1)) 52 | if r > 0: 53 | res += [p / r] 54 | else: 55 | res += [0] 56 | 57 | return np.mean(res) 58 | 59 | 60 | def fx_calc_map_label_v3(ord, label, k=0): 61 | numcases = ord.shape[0] 62 | if k == 0: 63 | k = numcases 64 | res = [] 65 | 66 | for i in range(numcases): 67 | order = ord[i] 68 | p = 0.0 69 | r = 0.0 70 | for j in range(k): 71 | if label[i] == label[order[j]]: 72 | r += 1 73 | p += (r / (j + 1)) 74 | if r > 0: 75 | res += [p / r] 76 | else: 77 | res += [0] 78 | 79 | return np.mean(res) 80 | 81 | 82 | if __name__ == "__main__": 83 | pass 84 | -------------------------------------------------------------------------------- /utils/model_selector.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import numpy as np 3 | 4 | 5 | class ModelSelector: 6 | 7 | def __init__(self): 8 | self.history = collections.defaultdict(list) 9 | 10 | def log(self, the_dict): 11 | for key, value in the_dict.items(): 12 | self.history[key].append(value) 13 | best_info = {} 14 | for key in self.history: 15 | # valid/ms_fv 16 | # => best-valid/valid_ms_fv 17 | best_info["best-" + key] = max(self.history[key]) 18 | return best_info 19 | 20 | def is_best_model(self, indicator): 21 | assert indicator in self.history 22 | arr = self.history[indicator] 23 | return 
np.argmax(arr) == (len(arr) - 1) 24 | 25 | def should_stop(self, indicator, early_stop=10): 26 | arr = self.history[indicator] 27 | if len(arr) - 1 - np.argmax(arr) >= early_stop: 28 | return True 29 | return False 30 | 31 | 32 | def get_best_step_info(self, indictor, print_it=True): 33 | index = np.argmax(self.history[indictor]) 34 | ans = {} 35 | if print_it: 36 | for key in self.history: 37 | v = self.history[key][index] 38 | print("%s\t%.4f" % (key, v)) 39 | ans[key] = v 40 | return ans 41 | -------------------------------------------------------------------------------- /utils/model_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import json 4 | import sys 5 | from utils import pickle_util 6 | 7 | history_array = [] 8 | 9 | 10 | def save_model(epoch, model, optimizer, file_save_path): 11 | dirpath = os.path.abspath(os.path.join(file_save_path, os.pardir)) 12 | if not os.path.exists(dirpath): 13 | print("mkdir:", dirpath) 14 | os.makedirs(dirpath) 15 | 16 | opti = None 17 | if optimizer is not None: 18 | opti = optimizer.state_dict() 19 | 20 | torch.save(obj={ 21 | 'epoch': epoch, 22 | 'model': model.state_dict(), 23 | 'optimizer': opti, 24 | }, f=file_save_path) 25 | 26 | history_array.append(file_save_path) 27 | 28 | 29 | def delete_last_saved_model(): 30 | if len(history_array) == 0: 31 | return 32 | last_path = history_array.pop() 33 | if os.path.exists(last_path): 34 | os.remove(last_path) 35 | print("delete model:", last_path) 36 | 37 | if os.path.exists(last_path + ".json"): 38 | os.remove(last_path + ".json") 39 | 40 | 41 | def load_model(resume_path, model, optimizer=None, strict=True): 42 | checkpoint = torch.load(resume_path) 43 | start_epoch = checkpoint['epoch'] + 1 44 | model.load_state_dict(checkpoint['model'], strict=strict) 45 | if optimizer is not None: 46 | optimizer.load_state_dict(checkpoint['optimizer']) 47 | print("checkpoint loaded!") 48 | return start_epoch 49 | 50 | 51 | def save_model_v2(model, args, model_save_name): 52 | model_save_path = os.path.join(args.model_save_folder, args.project, args.name, model_save_name) 53 | save_model(0, model, None, model_save_path) 54 | print("save:", model_save_path) 55 | 56 | 57 | def save_project_info(args): 58 | run_info = { 59 | "cmd_str": ' '.join(sys.argv[1:]), 60 | "args": vars(args), 61 | } 62 | 63 | name = "run_info.json" 64 | folder = os.path.join(args.model_save_folder, args.project, args.name) 65 | if not os.path.exists(folder): 66 | os.makedirs(folder) 67 | 68 | json_file_path = os.path.join(folder, name) 69 | with open(json_file_path, "w") as f: 70 | json.dump(run_info, f) 71 | 72 | print("save_project_info:", json_file_path) 73 | 74 | 75 | def get_pkl_json(folder): 76 | names = [i for i in os.listdir(folder) if ".pkl.json" in i] 77 | assert len(names) == 1 78 | json_path = os.path.join(folder, names[0]) 79 | obj = pickle_util.read_json(json_path) 80 | return obj 81 | 82 | 83 | # ======================================= 分析 84 | import numpy as np 85 | import os 86 | import json 87 | 88 | 89 | # 不同seed,结果聚合为一个 90 | # name_start_with: 通过名字来过滤目标文件夹 91 | # 附加了std 92 | def result_ensemble(project_folder, name_start_with, format="mean_std"): 93 | tmp_arr = [] 94 | for run_name in os.listdir(project_folder): 95 | if run_name.startswith(name_start_with): # 排除不对的 96 | obj = find_json(os.path.join(project_folder, run_name)) 97 | tmp_arr.append(obj) 98 | 99 | assert len(tmp_arr) > 1 100 | new_obj = { 101 | "name": name_start_with, 102 | } 103 | 
| for key in tmp_arr[0].keys(): 104 | vs = [o[key] for o in tmp_arr] 105 | mean = np.mean(vs) 106 | std = np.std(vs) 107 | if format == "mean_std": 108 | txt = "%.1f±%.1f" % (mean, std) 109 | else: 110 | txt = "%.1f" % (mean) 111 | new_obj[key] = txt 112 | return new_obj 113 | 114 | 115 | def find_json(the_path): 116 | files = os.listdir(the_path) 117 | files = [f for f in files if f.endswith('.json')] 118 | assert len(files) == 1 119 | with open(os.path.join(the_path, files[0]), 'r') as f: 120 | json_data = json.load(f) 121 | return json_data 122 | 123 | 124 | def array_objs_merge_to_single_obj(array): 125 | new_obj = {} 126 | for key in array[0].keys(): 127 | new_obj[key] = float("%.2f" % (np.mean([o[key] for o in array]) * 100)) 128 | return new_obj 129 | -------------------------------------------------------------------------------- /utils/my_git.py: -------------------------------------------------------------------------------- 1 | # Nov 27 2 | # Nov 28: added get_git_info 3 | 4 | import git 5 | # pip install gitpython 6 | 7 | is_commited = False 8 | 9 | 10 | def commit_v2(args): 11 | import datetime 12 | date_str = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') 13 | content = "%s %s %s" % (args.project, args.name, date_str) 14 | commit(content) 15 | # No need to update the wandb config here: wandb is usually not initialized yet at this point, so updating would have no effect. 16 | 17 | 18 | def commit(content): 19 | global is_commited 20 | if is_commited: 21 | # skip if a commit has already been made once 22 | return 23 | 24 | do_real_commit(content) 25 | is_commited = True 26 | return get_git_info() 27 | 28 | 29 | def do_real_commit(content): 30 | repo = git.Repo(search_parent_directories=True) 31 | try: 32 | g = repo.git 33 | g.add("--all") 34 | res = g.commit("-m " + content) 35 | print(res) 36 | except Exception as e: 37 | print("nothing to commit") 38 | 39 | 40 | def get_git_info(): 41 | repo = git.Repo(search_parent_directories=True) 42 | sha = repo.head.object.hexsha 43 | branch = str(repo.active_branch) 44 | return {"branch": branch, "git_id": sha} 45 | -------------------------------------------------------------------------------- /utils/my_parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import ast 3 | 4 | 5 | class MyParser(): 6 | 7 | def __init__(self, epoch, batch_size, worker=0, seed=2526, 8 | max_hour=100, early_stop=5, lr=1e-4, 9 | model_save_folder=None): 10 | super(MyParser, self).__init__() 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--seed", default=seed, type=int) 13 | parser.add_argument("--worker", default=worker, type=int) 14 | parser.add_argument("--epoch", default=epoch, type=int) 15 | parser.add_argument("--batch_size", default=batch_size, type=int) 16 | parser.add_argument("--max_hour", default=max_hour, type=int) 17 | parser.add_argument("--early_stop", default=early_stop, type=int) 18 | parser.add_argument("--lr", default=lr, type=float) 19 | parser.add_argument("--model_save_folder", default=model_save_folder, type=str) 20 | self.core_parser = parser 21 | 22 | def use_wb(self, project, name, dryrun=False): 23 | self.project = project 24 | self.name = name 25 | self.dryrun = dryrun 26 | parser = self.core_parser 27 | parser.add_argument("--project", default=self.project, type=str) 28 | parser.add_argument("--name", default=self.name, type=str) 29 | parser.add_argument("--dryrun", default=self.dryrun, type=ast.literal_eval) 30 | 31 | def custom(self, the_dict): 32 | parser = self.core_parser 33 | for key in the_dict: 34 | value = the_dict[key] 35 | if type(value) == str or value is None: 36 
| parser.add_argument("--" + key, default=value, type=str) 37 | elif type(value) == int: 38 | parser.add_argument("--" + key, default=value, type=int) 39 | elif type(value) == float: 40 | parser.add_argument("--" + key, default=value, type=float) 41 | elif type(value) == bool: 42 | parser.add_argument("--" + key, default=value, type=ast.literal_eval) 43 | else: 44 | raise Exception("unsupported type: " + str(type(value))) 45 | 46 | def parse(self): 47 | args = parse_it(self.core_parser) 48 | return args 49 | 50 | def show(self): 51 | the_dic = vars(self.parse()) 52 | keys = list(the_dic.keys()) 53 | keys.sort() 54 | for key in keys: 55 | print(key, ":", the_dic[key]) 56 | 57 | 58 | def parse_it(parser): 59 | args = parser.parse_args() 60 | return args 61 | 62 | 63 | if __name__ == "__main__": 64 | pass 65 | -------------------------------------------------------------------------------- /utils/my_softmax_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class MySoftmaxLoss(nn.Module): 6 | def __init__(self, feature_dim, num_class): 7 | super(MySoftmaxLoss, self).__init__() 8 | self.in_feats = feature_dim 9 | self.W = torch.nn.Parameter(torch.randn(feature_dim, num_class)) 10 | self.cel = nn.CrossEntropyLoss() 11 | nn.init.xavier_normal_(self.W, gain=1) 12 | 13 | def forward(self, embedding, labels): 14 | assert embedding.size()[0] == labels.size()[0] 15 | assert embedding.size()[1] == self.in_feats 16 | logits = torch.mm(embedding, self.W) 17 | loss = self.cel(logits, labels) 18 | return loss 19 | -------------------------------------------------------------------------------- /utils/pair_selection_util.py: -------------------------------------------------------------------------------- 1 | import ipdb 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def contrastive_loss(f_emb, v_emb, margin, tau): 7 | # 1. 
positive pair distances 8 | loss_pos = parallel_l2_distance(f_emb, v_emb).mean() 9 | 10 | # 2.negative pairs 11 | negative_tuple_list = negative_pair_selection(v_emb, f_emb, tau) 12 | f_emb2 = f_emb[[tup[0] for tup in negative_tuple_list]] 13 | v_emb2 = v_emb[[tup[1] for tup in negative_tuple_list]] 14 | dist = parallel_l2_distance(f_emb2, v_emb2) 15 | loss_neg = torch.nn.functional.relu(margin - dist).mean() 16 | loss = loss_pos + loss_neg 17 | return loss 18 | 19 | 20 | def negative_pair_selection(f_emb, v_emb, tau): 21 | # calculate pair-wise L2 distances: 22 | f_emb_npy = f_emb.detach().cpu().numpy() 23 | v_emb_npy = v_emb.detach().cpu().numpy() 24 | f2v_distance = pairwise_l2_distance(f_emb_npy, v_emb_npy) 25 | 26 | # create (anchor,neg) pairs: 27 | batch_size = len(v_emb) 28 | pair_list = [] 29 | 30 | for i in range(batch_size): 31 | # anchor-negative distance list: 32 | distance_list = f2v_distance[i] 33 | distance_sorted_list = np.argsort(distance_list)[::-1] 34 | # [big_distance_index,.....,small_distance_index] 35 | for j in range(int(tau * batch_size)): 36 | neg_idx = distance_sorted_list[j] 37 | if neg_idx == i: # found the positive sample, exit early 38 | break 39 | pair_list.append([i, neg_idx]) 40 | return pair_list 41 | 42 | 43 | def parallel_l2_distance(matrix_a, matrix_b): 44 | # matrix_a: (batch_a,dim) 45 | # matrix_b: (batch_b,dim) 46 | # output: (batch) 47 | 48 | c = matrix_a - matrix_b 49 | return torch.sqrt(torch.sum(c * c, dim=1)) 50 | 51 | 52 | def pairwise_l2_distance(matrix_a, matrix_b): 53 | # matrix_a: (batch_a,dim) 54 | # matrix_b: (batch_b,dim) 55 | 56 | matrix_dot = np.dot(matrix_a, np.transpose(matrix_b)) 57 | # (batch_a,batch_b) 58 | 59 | a_square = np.sum(matrix_a * matrix_a, axis=1) 60 | # (batch_a) 61 | 62 | b_square = np.sum(matrix_b * matrix_b, axis=1) 63 | # (batch_b) 64 | 65 | a_square_2d = np.expand_dims(a_square, axis=1) 66 | # (batch_a,1) 67 | 68 | b_square_2d = np.expand_dims(b_square, axis=0) 69 | # (1,batch_b) 70 | 71 | distance_matrix_squared = a_square_2d - 2.0 * matrix_dot + b_square_2d 72 | 73 | distance_matrix = np.maximum(distance_matrix_squared, 0.0) 74 | distance_matrix = np.sqrt(distance_matrix) 75 | return distance_matrix 76 | -------------------------------------------------------------------------------- /utils/path_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def look_up(path): 5 | if os.path.exists(path): 6 | return path 7 | 8 | upper = "." 
+ path 9 | if os.path.exists(upper): 10 | print("switch", path, "==>", upper) 11 | return upper 12 | 13 | return path 14 | 15 | 16 | import pathlib 17 | def mk_parent_dir_if_necessary(img_save_path): 18 | folder = pathlib.Path(img_save_path).parent 19 | if not os.path.exists(folder): 20 | os.makedirs(folder) 21 | -------------------------------------------------------------------------------- /utils/pickle_util.py: -------------------------------------------------------------------------------- 1 | import _pickle as pickle # python3 2 | import time 3 | import json 4 | 5 | 6 | def read_pickle(filepath): 7 | f = open(filepath, 'rb') 8 | word2mfccs = pickle.load(f) 9 | f.close() 10 | return word2mfccs 11 | 12 | 13 | def save_pickle(save_path, save_data): 14 | f = open(save_path, 'wb') 15 | pickle.dump(save_data, f) 16 | f.close() 17 | 18 | 19 | def read_json(filepath): 20 | with open(filepath) as f: 21 | obj = json.load(f) 22 | return obj 23 | 24 | 25 | def save_json(save_path, obj): 26 | with open(save_path, 'w') as f: 27 | json.dump(obj, f) 28 | -------------------------------------------------------------------------------- /utils/sample_util.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | 4 | 5 | def random_element(array, need_index=False): 6 | length = len(array) 7 | assert length > 0, length 8 | rand_index = random.randint(0, length - 1) 9 | if need_index: 10 | return array[rand_index], rand_index 11 | else: 12 | return array[rand_index] 13 | 14 | 15 | def random_elements(array, number): 16 | return np.random.choice(array, number, replace=False) 17 | -------------------------------------------------------------------------------- /utils/seed_util.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy 3 | import torch 4 | 5 | 6 | def set_seed(seed): 7 | random.seed(seed) 8 | numpy.random.seed(seed) 9 | 10 | 11 | torch.manual_seed(seed) 12 | torch.cuda.manual_seed(seed) 13 | 14 | torch.backends.cudnn.deterministic = True 15 | torch.backends.cudnn.benchmark = False 16 | # https://pytorch.org/docs/stable/notes/randomness.html 17 | -------------------------------------------------------------------------------- /utils/vec_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def get_vec_length(vec): 5 | if type(vec) == list: 6 | vec = np.array(vec) 7 | return np.sqrt(np.sum(vec * vec)) 8 | 9 | 10 | def dict2unit_dict_inplace(the_dict): 11 | for key in the_dict: 12 | vec = the_dict[key] 13 | the_len = get_vec_length(vec) 14 | the_dict[key] = vec / the_len 15 | return the_dict 16 | 17 | 18 | def assert_is_unit_tensor(tensor): 19 | npy = tensor.detach().cpu().numpy() 20 | length = get_vec_length(npy) 21 | assert np.isclose(length, 1.0) 22 | 23 | 24 | def assert_is_unit_vec(npy): 25 | length = get_vec_length(npy) 26 | assert np.isclose(length, 1.0) 27 | 28 | 29 | def assert_dict_unit_vector(the_dic): 30 | for key in the_dic: 31 | v = the_dic[key] 32 | the_len = get_vec_length(v) 33 | assert np.isclose(the_len, 1.0) 34 | break 35 | 36 | 37 | def get_vec_dim_in_dict(the_dic): 38 | for key in the_dic: 39 | v = the_dic[key] 40 | return len(v) 41 | 42 | 43 | def to_unit_vector(vector): 44 | return vector / get_vec_length(vector) 45 | 46 | 47 | def norm_batch_vector(matix): 48 | # matix = np.array([ 49 | # [3, 4], 50 | # [1, 1] 51 | # ]) 52 | vec_length = np.linalg.norm(matix, axis=1, 
keepdims=True) 53 | out = matix / vec_length 54 | # [[0.6 , 0.8 ], 55 | # [0.707, 0.707]] 56 | return out 57 | -------------------------------------------------------------------------------- /utils/wb_util.py: -------------------------------------------------------------------------------- 1 | # used for late-initialize wandb in case of crash before a full evaluation (which add an idle log on the monitoring panel) 2 | import collections 3 | import wandb 4 | import os 5 | from utils import pickle_util 6 | 7 | cache = collections.defaultdict(list) 8 | 9 | is_inited = False 10 | 11 | 12 | def save(the_path): 13 | if is_inited: 14 | wandb.save(the_path) 15 | else: 16 | cache["save"].append(the_path) 17 | 18 | 19 | def update_config(obj): 20 | if is_inited: 21 | wandb.config.update(obj) 22 | else: 23 | cache["config"].append(obj) 24 | 25 | 26 | def log(obj): 27 | if is_inited: 28 | wandb.log(obj) 29 | else: 30 | cache["log"].append(obj) 31 | 32 | 33 | def init(args): 34 | init_core(args.project, args.name, args.dryrun) 35 | 36 | 37 | def init_core(project, name, dryrun): 38 | global is_inited 39 | if is_inited: 40 | return 41 | is_inited = True 42 | 43 | if dryrun: 44 | os.environ['WANDB_MODE'] = 'dryrun' 45 | wandb.log = do_nothing 46 | wandb.save = do_nothing 47 | wandb.watch = do_nothing 48 | wandb.config = {} 49 | print("wb dryrun mode") 50 | return 51 | 52 | init_based_on_config_file(project, name) 53 | 54 | 55 | def init_based_on_config_file(project, name, config_path=".wb_config.json"): 56 | assert os.path.exists(config_path) 57 | json_dict = pickle_util.read_json(config_path) 58 | 59 | # use self-hosted wb server 60 | key = "WANDB_BASE_URL" 61 | if key in json_dict: 62 | os.environ[key] = json_dict[key] 63 | 64 | # login 65 | wandb.login(key=json_dict["WB_KEY"]) 66 | wandb.init(project=project, name=name) 67 | print("wandb inited") 68 | 69 | # supplement config and logs 70 | for obj in cache["config"]: 71 | wandb.config.update(obj) 72 | 73 | for log in cache["log"]: 74 | wandb.log(log) 75 | 76 | for the_path in cache["save"]: 77 | wandb.save(the_path) 78 | 79 | 80 | def do_nothing(v): 81 | pass 82 | -------------------------------------------------------------------------------- /utils/worker_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy 4 | 5 | 6 | def worker_init_fn(worker_id): 7 | pytorch_seed = torch.utils.data.get_worker_info().seed 8 | seed = pytorch_seed % (2 ** 32 - 1) 9 | random.seed(seed) 10 | numpy.random.seed(seed) 11 | print("worker:%d,pytorch_seed:%d" % (worker_id, seed)) 12 | -------------------------------------------------------------------------------- /works/10_DE_HQ3.py: -------------------------------------------------------------------------------- 1 | import ipdb 2 | 3 | from utils import my_parser, seed_util, wb_util, pair_selection_util, eva_emb_full 4 | import os 5 | from models.de_hq3_model import Model 6 | import torch 7 | from loaders import v4_de_hq3_loader 8 | from utils.eval_shortcut import Cut 9 | from utils.eva_emb_full import EmbEva 10 | 11 | 12 | def do_step(epoch, step, data): 13 | optimizer.zero_grad() 14 | data = [i.cuda() for i in data] 15 | voice_data, face_data, id_label = data 16 | loss, info = model(voice_data, face_data, id_label, step) 17 | loss.backward() 18 | optimizer.step() 19 | return loss.item(), {} 20 | 21 | 22 | def train(): 23 | step = 0 24 | model.train() 25 | 26 | for epo in range(args.epoch): 27 | wb_util.log({"train/epoch": epo}) 28 
| for data in train_iter: 29 | loss, info = do_step(epo, step, data) 30 | step += 1 31 | if step % 50 == 0: 32 | obj = { 33 | "train/step": step, 34 | "train/loss": loss, 35 | } 36 | obj = {**obj, **info} 37 | print(obj) 38 | wb_util.log(obj) 39 | 40 | if step > 0 and step % args.eval_step == 0: 41 | if eval_cut.eval_short_cut(): 42 | return 43 | 44 | 45 | if __name__ == "__main__": 46 | parser = my_parser.MyParser(epoch=100, batch_size=256, model_save_folder="./outputs/", early_stop=10) 47 | parser.custom({ 48 | "batch_per_epoch": 500, 49 | "eval_step": 250, 50 | "ratio_orth": 1.0, 51 | "ratio_rec": 1.0, 52 | }) 53 | parser.use_wb("VFALBenchmark", "DE_HQ3") 54 | args = parser.parse() 55 | seed_util.set_seed(args.seed) 56 | train_iter = de_hq3_loader.get_iter(args.batch_size, args.batch_per_epoch * args.batch_size) 57 | 58 | # model 59 | num_user = len(train_iter.dataset.train_names) 60 | model = Model(num_user, args).cuda() 61 | model_params = model.parameters() 62 | 63 | optimizer = torch.optim.Adam(model_params, lr=args.lr) 64 | emb_eva = EmbEva() 65 | eval_cut = Cut(emb_eva, model, args) 66 | train() 67 | 68 | 69 | # ts python 10_DE_HQ3.py --project=测试DE_HQ3 --name=def 70 | # ts python 10_DE_HQ3.py --project=测试DE_HQ3 --name=ratio_rec_0 --ratio_rec=0 71 | # ts python 10_DE_HQ3.py --project=测试DE_HQ3 --name=ratio_orth_0 --ratio_orth=0 72 | # ts python 10_DE_HQ3.py --project=测试DE_HQ3 --name=only_id --ratio_orth=0 --ratio_rec=0 -------------------------------------------------------------------------------- /works/11_SS_DIM_VFMR_Barlow.py: -------------------------------------------------------------------------------- 1 | from utils import my_parser, seed_util, wb_util, pair_selection_util, eva_emb_full 2 | import os 3 | from models.my_model import Encoder 4 | import torch 5 | from loaders import v1_sup_id_loader, v2_unsup_loader 6 | from utils.losses.softmax_loss import SoftmaxLoss 7 | from utils.losses import triplet_hq1, center_loss_learnableW_L2dist, center_loss_eccv16 8 | from utils.eval_shortcut import Cut 9 | from utils.eva_emb_full import EmbEva 10 | from pytorch_metric_learning import losses 11 | from utils.losses import barlow_loss 12 | 13 | 14 | def do_step(epoch, step, data): 15 | optimizer.zero_grad() 16 | data = [i.cuda() for i in data] 17 | voice_data, face_data, id_label, gender_label = data 18 | v_emb, f_emb = model(voice_data, face_data) 19 | 20 | cat_emb = torch.cat([v_emb, f_emb], dim=0) 21 | cat_id = torch.cat([id_label, id_label], dim=0).squeeze() 22 | 23 | if args.name.startswith("SSNet"): 24 | loss_center = fun_center_loss(cat_emb, cat_id) 25 | loss_id = fun_id_classifier(cat_emb, cat_id) 26 | loss = loss_center + loss_id 27 | elif args.name.startswith("DIMNet"): 28 | cat_gender_label = torch.cat([gender_label, gender_label], dim=0).squeeze() 29 | loss_id = fun_id_classifier(cat_emb, cat_id) 30 | loss_gender = fun_id_classifier(cat_emb, cat_gender_label) 31 | loss = loss_gender + loss_id 32 | elif args.name.startswith("LAFV") or args.name.startswith("VFMR"): 33 | loss = fun_loss_metric(cat_emb, cat_id) 34 | elif args.name.startswith("SL-Barlow"): 35 | loss = fun_barlow(v_emb, f_emb) 36 | 37 | loss.backward() 38 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) 39 | optimizer.step() 40 | return loss.item(), {} 41 | 42 | 43 | def train(): 44 | step = 0 45 | model.train() 46 | 47 | for epo in range(args.epoch): 48 | wb_util.log({"train/epoch": epo}) 49 | for data in train_iter: 50 | loss, info = do_step(epo, step, data) 51 | step += 1 52 | if step % 50 == 
0: 53 | obj = { 54 | "train/step": step, 55 | "train/loss": loss, 56 | } 57 | obj = {**obj, **info} 58 | print(obj) 59 | wb_util.log(obj) 60 | 61 | if step > 0 and step % args.eval_step == 0: 62 | if eval_cut.eval_short_cut(): 63 | return 64 | 65 | 66 | if __name__ == "__main__": 67 | parser = my_parser.MyParser(epoch=100, batch_size=256, model_save_folder="./outputs/", early_stop=10) 68 | parser.custom({ 69 | "batch_per_epoch": 500, 70 | "eval_step": 100, 71 | "margin": 0.6, 72 | "clip": 1.0 73 | }) 74 | parser.use_wb("VFALBenchmark", "HQ1_") 75 | args = parser.parse() 76 | seed_util.set_seed(args.seed) 77 | train_iter = v1_sup_id_loader.get_iter(args.batch_size, args.batch_per_epoch * args.batch_size) 78 | 79 | # model 80 | model = Encoder().cuda() 81 | model_params = model.parameters() 82 | 83 | # loss 84 | num_class = len(train_iter.dataset.train_names) 85 | if args.name.startswith("SSNet"): 86 | # only reaches about 75%, temporarily abandoned 87 | # http://wbz.huacishu.com/cgy/VFALBenchmark/runs/0r6ippc4?workspace=user-chenguangyu 88 | fun_center_loss = center_loss_eccv16.CenterLoss(num_classes=num_class, feat_dim=128).cuda() 89 | fun_id_classifier = SoftmaxLoss(128, num_class=num_class).cuda() 90 | model_params = list(model.parameters()) + list(fun_id_classifier.parameters()) + list(fun_center_loss.parameters()) 91 | elif args.name.startswith("DIMNet"): 92 | fun_id_classifier = SoftmaxLoss(128, num_class=num_class).cuda() 93 | fun_gender_classifier = SoftmaxLoss(128, num_class=2).cuda() 94 | model_params = list(model.parameters()) + list(fun_gender_classifier.parameters()) + list(fun_id_classifier.parameters()) 95 | elif args.name.startswith("VFMR"): 96 | fun_loss_metric = losses.LiftedStructureLoss(neg_margin=1, pos_margin=0) 97 | model_params = list(model.parameters()) + list(fun_loss_metric.parameters()) 98 | elif args.name.startswith("SL-Barlow"): 99 | fun_barlow = barlow_loss.BarlowTwinsLoss() 100 | else: 101 | raise Exception("Unsupported name:", args.name) 102 | 103 | optimizer = torch.optim.Adam(model_params, lr=args.lr) 104 | emb_eva = EmbEva() 105 | eval_cut = Cut(emb_eva, model, args) 106 | train() 107 | -------------------------------------------------------------------------------- /works/1_pins.py: -------------------------------------------------------------------------------- 1 | from utils import my_parser, seed_util, wb_util, pair_selection_util, eva_emb_full 2 | import os 3 | from models.my_model import Encoder 4 | import torch 5 | from loaders import v2_unsup_loader 6 | from utils.eval_shortcut import Cut 7 | from utils.eva_emb_full import EmbEva 8 | 9 | 10 | def do_step(epoch, step, data): 11 | optimizer.zero_grad() 12 | data = [i.cuda() for i in data] 13 | voice_data, face_data, _ = data 14 | v_emb, f_emb = model(voice_data, face_data) 15 | loss = pair_selection_util.contrastive_loss(f_emb, v_emb, args.margin, tau_value) 16 | loss.backward() 17 | optimizer.step() 18 | info = { 19 | "train/tau_value": tau_value, 20 | } 21 | return loss.item(), info 22 | 23 | 24 | def train(): 25 | step = 0 26 | model.train() 27 | 28 | for epo in range(args.epoch): 29 | wb_util.log({"train/epoch": epo}) 30 | for data in train_iter: 31 | loss, info = do_step(epo, step, data) 32 | step += 1 33 | if step % 50 == 0: 34 | obj = { 35 | "train/step": step, 36 | "train/loss": loss, 37 | } 38 | obj = {**obj, **info} 39 | print(obj) 40 | wb_util.log(obj) 41 | 42 | if step > 0 and step % args.eval_step == 0: 43 | if eval_cut.eval_short_cut(): 44 | return 45 | 46 | global tau_value 47 | if step > 0 and step % 500 == 0 and
tau_value < 0.8: 48 | tau_value = tau_value + 0.1 49 | print("Update tau:", tau_value) 50 | 51 | 52 | if __name__ == "__main__": 53 | parser = my_parser.MyParser(epoch=100, batch_size=256, model_save_folder="./outputs/", early_stop=5) 54 | parser.custom({ 55 | "batch_per_epoch": 500, 56 | "eval_step": 100, 57 | "margin": 0.6, 58 | "tau": 0.3, 59 | }) 60 | parser.use_wb("VFALBenchmark", "Pins") 61 | args = parser.parse() 62 | seed_util.set_seed(args.seed) 63 | train_iter = v2_unsup_loader.get_iter(args.batch_size, args.batch_per_epoch * args.batch_size) 64 | 65 | tau_value = args.tau 66 | 67 | # model 68 | model = Encoder().cuda() 69 | model_params = model.parameters() 70 | 71 | optimizer = torch.optim.Adam(model_params, lr=args.lr) 72 | emb_eva = EmbEva() 73 | eval_cut = Cut(emb_eva, model, args) 74 | train() 75 | -------------------------------------------------------------------------------- /works/2_FV-CME.py: -------------------------------------------------------------------------------- 1 | from utils import my_parser, seed_util, wb_util, pair_selection_util, eva_emb_full 2 | import os 3 | from models.my_model import Encoder 4 | import torch 5 | from loaders import v1_sup_id_loader 6 | from utils.losses import unsup_nce 7 | from utils.eval_shortcut import Cut 8 | from utils.eva_emb_full import EmbEva 9 | 10 | 11 | def do_step(epoch, step, data): 12 | optimizer.zero_grad() 13 | data = [i.cuda() for i in data] 14 | voice_data, face_data, id_label, gender_label = data 15 | v_emb, f_emb = model(voice_data, face_data) 16 | loss = npair_loss(f_emb, v_emb) 17 | loss.backward() 18 | optimizer.step() 19 | return loss.item(), {} 20 | 21 | 22 | def train(): 23 | step = 0 24 | model.train() 25 | 26 | for epo in range(args.epoch): 27 | wb_util.log({"train/epoch": epo}) 28 | for data in train_iter: 29 | loss, info = do_step(epo, step, data) 30 | step += 1 31 | if step % 50 == 0: 32 | obj = { 33 | "train/step": step, 34 | "train/loss": loss, 35 | } 36 | obj = {**obj, **info} 37 | print(obj) 38 | wb_util.log(obj) 39 | 40 | if step > 0 and step % args.eval_step == 0: 41 | if eval_cut.eval_short_cut(): 42 | return 43 | 44 | 45 | if __name__ == "__main__": 46 | parser = my_parser.MyParser(epoch=100, batch_size=256, model_save_folder="./outputs/", early_stop=5) 47 | parser.custom({ 48 | "batch_per_epoch": 500, 49 | "eval_step": 100, 50 | "k": -1 # meaningless (not used) 51 | }) 52 | parser.use_wb("VFALBenchmark", "FV-CME") 53 | args = parser.parse() 54 | seed_util.set_seed(args.seed) 55 | train_iter = v1_sup_id_loader.get_iter(args.batch_size, args.batch_per_epoch * args.batch_size) 56 | 57 | # model 58 | model = Encoder().cuda() 59 | model_params = model.parameters() 60 | 61 | npair_loss = unsup_nce.InfoNCE(temperature=1.0) 62 | 63 | optimizer = torch.optim.Adam(model_params, lr=args.lr) 64 | emb_eva = EmbEva() 65 | eval_cut = Cut(emb_eva, model, args) 66 | train() 67 | -------------------------------------------------------------------------------- /works/3_LAFV.py: -------------------------------------------------------------------------------- 1 | from utils import my_parser, seed_util, wb_util, pair_selection_util, eva_emb_full 2 | import os 3 | from models.my_model import Encoder 4 | import torch 5 | from loaders import v3_triplet_loader 6 | from utils.losses import triplet_lafv 7 | from utils.eval_shortcut import Cut 8 | from utils.eva_emb_full import EmbEva 9 | 10 | 11 | def check_nan(model): 12 | # print the weight layers that contain NaN values 13 | for name, param in model.named_parameters(): 14 | if torch.isnan(param).any(): 15 | print(f"{name}
contains NaN values.") 16 | 17 | 18 | def print_max_val(model): 19 | # 打印出出权重层中最大数值与最小数值 20 | max_vas = [] 21 | min_vals = [] 22 | for name, param in model.named_parameters(): 23 | max_vas.append(param.max().item()) 24 | min_vals.append(param.min().item()) 25 | print(max(max_vas), min(min_vals)) 26 | 27 | 28 | def print_max_grad(model): 29 | for name, param in model.named_parameters(): 30 | if 'weight' in name: 31 | grad = param.grad 32 | if grad is not None: 33 | print(name, torch.max(grad).item()) 34 | 35 | 36 | def do_step(epoch, step, data): 37 | optimizer.zero_grad() 38 | data = [i.cuda() for i in data] 39 | voice_data, face_data_pos, _, face_data_neg = data 40 | v_emb, f_emb_pos = model(voice_data, face_data_pos) 41 | f_emb_neg = model.face_encoder(face_data_neg) 42 | loss = triplet_lafv.triplet_loss(v_emb, f_emb_pos, f_emb_neg) 43 | loss.backward() 44 | optimizer.step() 45 | return loss.item(), {} 46 | 47 | 48 | def train(): 49 | step = 0 50 | model.train() 51 | 52 | for epo in range(args.epoch): 53 | wb_util.log({"train/epoch": epo}) 54 | for data in train_iter: 55 | loss, info = do_step(epo, step, data) 56 | step += 1 57 | if step % 50 == 0: 58 | obj = { 59 | "train/step": step, 60 | "train/loss": loss, 61 | } 62 | obj = {**obj, **info} 63 | print(obj) 64 | wb_util.log(obj) 65 | 66 | if step > 0 and step % args.eval_step == 0: 67 | if eval_cut.eval_short_cut(): 68 | return 69 | 70 | 71 | if __name__ == "__main__": 72 | parser = my_parser.MyParser(epoch=100, batch_size=256, model_save_folder="./outputs/", early_stop=5, worker=4) 73 | parser.custom({ 74 | "batch_per_epoch": 500, 75 | "eval_step": 100, 76 | }) 77 | parser.use_wb("VFALBenchmark", "LAFV") 78 | args = parser.parse() 79 | seed_util.set_seed(args.seed) 80 | train_iter = v3_triplet_loader.get_iter(args.batch_size, args.batch_per_epoch * args.batch_size) 81 | 82 | # model 83 | model = Encoder().cuda() 84 | model_params = model.parameters() 85 | 86 | optimizer = torch.optim.Adam(model_params, lr=args.lr) 87 | emb_eva = EmbEva() 88 | eval_cut = Cut(emb_eva, model, args) 89 | train() 90 | -------------------------------------------------------------------------------- /works/5_Wen.py: -------------------------------------------------------------------------------- 1 | import ipdb 2 | 3 | from utils import my_parser, seed_util, wb_util, pair_selection_util, eva_emb_full, pickle_util, path_util 4 | import os 5 | from models.my_model import Encoder 6 | import torch 7 | from loaders import v1_sup_id_loader 8 | from utils.losses.wen_reweight import * 9 | from utils.losses import wen_explicit_loss 10 | from utils.losses.softmax_loss import SoftmaxLoss 11 | from utils.eval_shortcut import Cut 12 | from utils.eva_emb_full import EmbEva 13 | 14 | 15 | class ModelWrapper: 16 | 17 | def __init__(self, train_iter, warmup_step=500): 18 | self.warmup_step = warmup_step 19 | 20 | model = Encoder().cuda() 21 | 22 | # loss 23 | num_class = len(train_iter.dataset.train_names) 24 | self.fun_id_classifier = SoftmaxLoss(128, num_class=num_class, reduction="none").cuda() 25 | 26 | # optimizer 27 | model_params = list(model.parameters()) + list(self.fun_id_classifier.parameters()) 28 | # self.optimizer = torch.optim.SGD(model_params, lr=1e-2) 29 | # self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=1, gamma=0.1, last_epoch=-1) 30 | self.optimizer = torch.optim.Adam(model_params, lr=args.lr) 31 | self.scheduler = None 32 | 33 | self.num_class = num_class 34 | self.model = model 35 | 36 | # forever iter 37 | def 
cycle(dataloader): 38 | while True: 39 | for data in dataloader: 40 | yield data 41 | 42 | self.train_iter = iter(cycle(train_iter)) 43 | 44 | def train_step(self, id2weights=None): 45 | self.optimizer.zero_grad() 46 | # data 47 | data = next(self.train_iter) 48 | data = [i.cuda() for i in data] 49 | voice_data, face_data, id_label, _ = data 50 | # to emb 51 | v_emb, f_emb = self.model(voice_data, face_data) 52 | 53 | # loss 54 | loss_metric = 0.5 * wen_explicit_loss.cross_logit(v_emb, f_emb) + 0.5 * wen_explicit_loss.cross_logit(f_emb, v_emb) 55 | loss_ID_cls = self.fun_id_classifier(v_emb, id_label) + self.fun_id_classifier(f_emb, id_label) 56 | 57 | # ReWeight 58 | if id2weights is not None and len(id2weights) > 0: 59 | weights = [id2weights[key] for key in id_label.squeeze().detach().cpu().numpy()] 60 | weights = torch.FloatTensor(weights).cuda() 61 | loss = (weights * loss_ID_cls).mean() + (weights * loss_metric).mean() 62 | else: 63 | loss = loss_ID_cls.mean() + loss_metric.mean() 64 | 65 | loss.backward() 66 | self.optimizer.step() 67 | 68 | info = { 69 | "loss_item": loss.item(), 70 | "label_list": id_label.detach().cpu().numpy(), 71 | "loss_list": loss_ID_cls.detach().cpu().numpy(), # use not-weighted ID-Cls loss, Equation(7) of the original paper. 72 | } 73 | return info 74 | 75 | def generate_weights(self): 76 | identitiy_count = self.num_class 77 | total_step = 0 78 | 79 | # =====================================================stage1: 80 | print("Stage1: Warmup...") 81 | for i in range(self.warmup_step): 82 | info = self.train_step() 83 | if i % 25 == 0: 84 | obj = { 85 | "train_pre/loss": info["loss_item"], 86 | "train_pre/step": i 87 | } 88 | wb_util.log(obj) 89 | print(obj) 90 | total_step += 1 91 | 92 | # =====================================================stage2: 93 | print("Stage2: Calculate Weights....") 94 | t = 0 95 | id2weights = {} 96 | hardness = {} 97 | while not_zero_count(id2weights) < 0.9 * identitiy_count: 98 | info = self.train_step(id2weights) 99 | hardness = update_hardness(hardness, info["label_list"], info["loss_list"]) 100 | 101 | total_step += 1 102 | t += 1 103 | if t % 100 == 0: 104 | ratio = not_zero_count(id2weights) / identitiy_count 105 | 106 | if self.scheduler is not None: 107 | cur_lr = self.scheduler.get_last_lr()[0] 108 | else: 109 | cur_lr = -1 110 | 111 | obj = { 112 | "train_pre/step": total_step, 113 | "train_pre/lr": cur_lr, 114 | "train_pre/loss": info["loss_item"], 115 | "train_pre/non_zero_weight_ratio": ratio, 116 | } 117 | wb_util.log(obj) 118 | print(obj) 119 | id2weights = update_weight(id2weights, hardness) 120 | 121 | if total_step in [2000, 3000] and self.scheduler is not None: 122 | self.scheduler.step() 123 | 124 | print("Stage2 End, total step:%d" % (total_step)) 125 | return id2weights 126 | 127 | 128 | def train(id2weights): 129 | step = 0 130 | model_wrapper.model.train() 131 | for i in range(10 * 10000): 132 | info = model_wrapper.train_step(id2weights) 133 | info = {"train/loss": info["loss_item"]} 134 | 135 | step += 1 136 | if step % 50 == 0: 137 | obj = { 138 | "train/step": step, 139 | "train/loss": info["train/loss"], 140 | } 141 | print(obj) 142 | wb_util.log(obj) 143 | 144 | if step > 0 and step % args.eval_step == 0: 145 | if eval_cut.eval_short_cut(): 146 | return 147 | model_wrapper.model.train() 148 | 149 | 150 | def load_wens_official(): 151 | name2weights = {} 152 | dataset = train_iter.dataset 153 | lines = [l.strip() for l in open("./dataset/info/works/wen_weights.txt").readlines() if l.strip()] 154 | for 
line in lines: 155 | k, v = line.split(" ") 156 | name2weights[k] = float(v) 157 | 158 | id2weights = {} 159 | for name, id in dataset.name2id.items(): 160 | v = name2weights.get(name, 1.0) 161 | id2weights[id] = v 162 | 163 | for k, v in id2weights.items(): 164 | print(k, v) 165 | return id2weights 166 | 167 | 168 | if __name__ == "__main__": 169 | # python 5_Wen.py --mode=load_file --load_weight_path="./outputs/VFALBenchmark/use_weight/id2weights.json" --name=use_weight --dryrun=False 170 | parser = my_parser.MyParser(epoch=100, batch_size=256, model_save_folder="./outputs/", early_stop=5) 171 | parser.custom({ 172 | "batch_per_epoch": 500, 173 | "eval_step": 100, 174 | "mode": "load_official", # load_official, calc_weight, load_file 175 | "load_weight_path": "" 176 | }) 177 | parser.use_wb("9.16_Wen", "run1") 178 | args = parser.parse() 179 | seed_util.set_seed(args.seed) 180 | wb_util.init(args) 181 | wb_util.save(__file__) 182 | 183 | # train 184 | train_iter = v1_sup_id_loader.get_iter(args.batch_size, args.batch_per_epoch * args.batch_size) 185 | 186 | if args.mode == "calc_weight": 187 | # calc identity weights 188 | train_iter2 = v1_sup_id_loader.get_iter(args.batch_size, args.batch_per_epoch * args.batch_size) 189 | model_wrapper = ModelWrapper(train_iter2, warmup_step=500) 190 | id2weights = model_wrapper.generate_weights() 191 | id2weights_save_path = os.path.join(args.model_save_folder, args.project, args.name, "id2weights.json") 192 | path_util.mk_parent_dir_if_necessary(id2weights_save_path) 193 | pickle_util.save_json(id2weights_save_path, id2weights) 194 | print("id2weights:", id2weights) 195 | elif args.mode == "load_file": 196 | print("load weights") 197 | tmp = pickle_util.read_json(args.load_weight_path) 198 | id2weights = {} 199 | for k, v in tmp.items(): 200 | id2weights[int(k)] = v 201 | elif args.mode == "load_official": 202 | id2weights = load_wens_official() 203 | else: 204 | print("not using weights") 205 | id2weights = {} 206 | 207 | model_wrapper = ModelWrapper(train_iter) 208 | emb_eva = EmbEva() 209 | eval_cut = Cut(emb_eva, model_wrapper.model, args) 210 | train(id2weights) 211 | -------------------------------------------------------------------------------- /works/6_FOP.py: -------------------------------------------------------------------------------- 1 | from utils import my_parser, seed_util, wb_util, pair_selection_util, eva_emb_full 2 | import os 3 | from models.my_model import Encoder 4 | from models import fop_model 5 | import torch 6 | from loaders import v1_sup_id_loader 7 | from utils.losses.softmax_loss import SoftmaxLoss 8 | from utils.losses import fop_loss 9 | from utils.eval_shortcut import Cut 10 | from utils.eva_emb_full import EmbEva 11 | 12 | 13 | def do_step(epoch, step, data): 14 | optimizer.zero_grad() 15 | data = [i.cuda() for i in data] 16 | voice_data, face_data, id_label, _ = data 17 | v_emb, f_emb, fusion_emb = model(voice_data, face_data) 18 | 19 | if args.use_fusion_block: 20 | loss_id = fun_id_classifier(fusion_emb, id_label) 21 | loss_fop = fun_fop_loss(fusion_emb, id_label) 22 | else: 23 | cat_emb = torch.cat([v_emb, f_emb], dim=0) 24 | cat_id = torch.cat([id_label, id_label], dim=0).squeeze() 25 | loss_id = fun_id_classifier(cat_emb, cat_id) 26 | loss_fop = fun_fop_loss(cat_emb, cat_id) 27 | 28 | loss = loss_id + loss_fop 29 | loss.backward() 30 | optimizer.step() 31 | return loss.item(), {} 32 | 33 | 34 | def train(): 35 | step = 0 36 | model.train() 37 | 38 | for epo in range(args.epoch): 39 | wb_util.log({"train/epoch": epo})
40 | for data in train_iter: 41 | loss, info = do_step(epo, step, data) 42 | step += 1 43 | if step % 50 == 0: 44 | obj = { 45 | "train/step": step, 46 | "train/loss": loss, 47 | } 48 | obj = {**obj, **info} 49 | print(obj) 50 | wb_util.log(obj) 51 | 52 | if step > 0 and step % args.eval_step == 0: 53 | if eval_cut.eval_short_cut(): 54 | return 55 | 56 | 57 | if __name__ == "__main__": 58 | parser = my_parser.MyParser(epoch=100, batch_size=256, model_save_folder="./outputs/", early_stop=6) 59 | parser.custom({ 60 | "batch_per_epoch": 500, 61 | "eval_step": 200, 62 | "use_fusion_block": False, 63 | }) 64 | parser.use_wb("VFALBenchmark", "FOP") 65 | args = parser.parse() 66 | seed_util.set_seed(args.seed) 67 | wb_util.save(__file__) 68 | wb_util.save(fop_model.__file__) 69 | train_iter = v1_sup_id_loader.get_iter(args.batch_size, args.batch_per_epoch * args.batch_size) 70 | 71 | # model 72 | model = fop_model.FopModel().cuda() 73 | 74 | # loss 75 | num_class = len(train_iter.dataset.train_names) 76 | fun_id_classifier = SoftmaxLoss(128, num_class=num_class).cuda() 77 | fun_fop_loss = fop_loss.OrthogonalProjectionLoss() 78 | 79 | # optimizer 80 | model_params = list(model.parameters()) + list(fun_id_classifier.parameters()) 81 | optimizer = torch.optim.Adam(model_params, lr=args.lr) 82 | emb_eva = EmbEva() 83 | eval_cut = Cut(emb_eva, model, args) 84 | train() 85 | -------------------------------------------------------------------------------- /works/7_CMPC.py: -------------------------------------------------------------------------------- 1 | from utils import my_parser, seed_util, wb_util, pair_selection_util, eva_emb_full 2 | import os 3 | from models.my_model import Encoder 4 | import torch 5 | from loaders import v2_unsup_loader 6 | from utils.losses import cmpc_loss 7 | from utils.eval_shortcut import Cut 8 | from utils.eva_emb_full import EmbEva 9 | from torch.nn import functional as F 10 | # from utils import keops_kmeans 11 | 12 | 13 | class Memory(): 14 | 15 | def __init__(self, vec_number, embedding_dim, momentum=0.5): 16 | super(Memory, self).__init__() 17 | self.momentum = momentum 18 | memo = torch.randn(vec_number, embedding_dim).cuda() 19 | memo = F.normalize(memo, p=2, dim=1) 20 | self.memo = memo 21 | 22 | def update(self, emb, label): 23 | with torch.no_grad(): 24 | emb = F.normalize(emb, p=2, dim=1) 25 | old_cache = self.memo.index_select(0, label) 26 | old_cache.mul_(self.momentum) 27 | old_cache.add_(torch.mul(emb, 1 - self.momentum)) 28 | new_cache = F.normalize(old_cache, p=2, dim=1) 29 | self.memo.index_copy_(0, label, new_cache) 30 | 31 | 32 | def do_step(data, v_cluster_result, f_cluster_result): 33 | optimizer.zero_grad() 34 | data = [i.cuda() for i in data] 35 | voice_data, face_data, movie_id = data 36 | v_emb, f_emb = model(voice_data, face_data) 37 | v_memory.update(v_emb, movie_id) 38 | f_memory.update(f_emb, movie_id) 39 | 40 | loss = loss_fun(v_emb, f_emb, v_cluster_result, f_cluster_result, movie_id) 41 | loss.backward() 42 | optimizer.step() 43 | return loss.item(), {} 44 | 45 | 46 | def train(): 47 | step = 0 48 | model.train() 49 | f_cluster_result = None 50 | v_cluster_result = None 51 | 52 | for epo in range(args.epoch): 53 | wb_util.log({"train/epoch": epo}) 54 | 55 | for data in train_iter: 56 | loss, info = do_step(data, v_cluster_result, f_cluster_result) 57 | step += 1 58 | if step % 50 == 0: 59 | obj = { 60 | "train/step": step, 61 | "train/loss": loss, 62 | } 63 | obj = {**obj, **info} 64 | print(obj) 65 | wb_util.log(obj) 66 | 67 | if step >
0 and step % args.eval_step == 0: 68 | if eval_cut.eval_short_cut(): 69 | return 70 | 71 | # do cluster 72 | # model.eval() 73 | # num_cluster = [50, 1000, 1500] 74 | # f_cluster_result = keops_kmeans.run_kmeans(f_memory.memo, num_cluster, Niter=20, temperature=0.2, verbose=True) 75 | # v_cluster_result = keops_kmeans.run_kmeans(v_memory.memo, num_cluster, Niter=20, temperature=0.2, verbose=True) 76 | # model.train() 77 | 78 | 79 | if __name__ == "__main__": 80 | parser = my_parser.MyParser(epoch=100, batch_size=256, model_save_folder="./outputs/", early_stop=6) 81 | parser.custom({ 82 | "warmup_iter": 50, 83 | "batch_per_epoch": 500, 84 | "eval_step": 200, 85 | "margin": 0.6, 86 | "use_fusion_block": False, 87 | "temperature": 0.03, 88 | }) 89 | parser.use_wb("9.17CMPC", "run1") 90 | args = parser.parse() 91 | seed_util.set_seed(args.seed) 92 | wb_util.save(__file__) 93 | train_iter = v2_unsup_loader.get_iter(args.batch_size, args.batch_per_epoch * args.batch_size) 94 | 95 | # model 96 | len_train_movies = len(train_iter.dataset.train_movies) 97 | v_memory = Memory(len_train_movies, 128) 98 | f_memory = Memory(len_train_movies, 128) 99 | model = Encoder().cuda() 100 | 101 | # loss 102 | loss_fun = cmpc_loss.IR_CMPC(args.temperature, delta=-1, ka=0.1, R=3).cuda() 103 | 104 | # optimizer 105 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 106 | emb_eva = EmbEva() 107 | eval_cut = Cut(emb_eva, model, args) 108 | train() 109 | -------------------------------------------------------------------------------- /works/8_CAE.py: -------------------------------------------------------------------------------- 1 | from utils import my_parser, seed_util, wb_util, pair_selection_util, eva_emb_full 2 | import os 3 | from models.cae_model import CAE 4 | import torch 5 | from loaders import v2_unsup_loader 6 | from utils.eval_shortcut import Cut 7 | from utils.eva_emb_full import EmbEva 8 | 9 | 10 | def do_step(epoch, step, data): 11 | optimizer.zero_grad() 12 | data = [i.cuda() for i in data] 13 | voice_data, face_data, _ = data 14 | loss_emb, loss_dec = model(voice_data, face_data) 15 | loss = loss_emb + loss_dec 16 | loss.backward() 17 | optimizer.step() 18 | return loss.item(), {} 19 | 20 | 21 | def train(): 22 | step = 0 23 | model.train() 24 | 25 | for epo in range(args.epoch): 26 | wb_util.log({"train/epoch": epo}) 27 | for data in train_iter: 28 | loss, info = do_step(epo, step, data) 29 | step += 1 30 | if step % 50 == 0: 31 | obj = { 32 | "train/step": step, 33 | "train/loss": loss, 34 | } 35 | obj = {**obj, **info} 36 | print(obj) 37 | wb_util.log(obj) 38 | 39 | if step > 0 and step % args.eval_step == 0: 40 | if eval_cut.eval_short_cut(): 41 | return 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = my_parser.MyParser(epoch=100, batch_size=256, model_save_folder="./outputs/", early_stop=5) 46 | parser.custom({ 47 | "batch_per_epoch": 500, 48 | "eval_step": 100, 49 | }) 50 | parser.use_wb("VFALBenchmark", "CAE") 51 | args = parser.parse() 52 | seed_util.set_seed(args.seed) 53 | train_iter = v2_unsup_loader.get_iter(args.batch_size, args.batch_per_epoch * args.batch_size) 54 | 55 | # model 56 | model = CAE().cuda() 57 | model_params = model.parameters() 58 | 59 | optimizer = torch.optim.Adam(model_params, lr=args.lr) 60 | emb_eva = EmbEva() 61 | eval_cut = Cut(emb_eva, model, args) 62 | train() 63 | -------------------------------------------------------------------------------- /works/9_SL.py: -------------------------------------------------------------------------------- 
1 | from utils import my_parser, seed_util, wb_util, deepcluster_util, pickle_util 2 | from utils.eval_shortcut import Cut 3 | from models import my_model 4 | import torch 5 | from loaders import v5_voxceleb_cluster_ordered_loader 6 | from loaders import v6_voxceleb_loader_for_deepcluster 7 | from pytorch_metric_learning import losses 8 | from utils import model_util 9 | from utils.eva_emb_full import EmbEva 10 | import os 11 | from utils.config import face_emb_dict, voice_emb_dict 12 | import ipdb 13 | import tqdm 14 | 15 | 16 | def do_step(epoch, step, data): 17 | optimizer.zero_grad() 18 | data = [i.cuda() for i in data] 19 | voice_data, face_data, label = data 20 | v_emb, f_emb = model(voice_data, face_data) 21 | emb = torch.cat([v_emb, f_emb], dim=0) 22 | label2 = torch.cat([label, label], dim=0).squeeze() 23 | 24 | if args.ratio_mse > 0: 25 | loss_mse = fun_loss_mse(v_emb, f_emb) * args.ratio_mse 26 | else: 27 | loss_mse = 0 28 | 29 | loss = fun_loss_metric(emb, label2) + loss_mse 30 | loss.backward() 31 | optimizer.step() 32 | info = { 33 | } 34 | return loss.item(), info 35 | 36 | 37 | def get_ratio(loss, total_loss): 38 | if type(loss) == torch.Tensor: 39 | loss = loss.item() 40 | return loss / total_loss.item() 41 | 42 | 43 | def train(): 44 | step = 0 45 | model.train() 46 | 47 | for epo in range(args.epoch): 48 | wb_util.log({"train/epoch": epo}) 49 | # do cluster 50 | all_keys, all_emb, all_emb_v, all_emb_f = v5_voxceleb_cluster_ordered_loader.extract_embeddings(face_emb_dict, voice_emb_dict, model) 51 | movie2label, _ = deepcluster_util.do_cluster_v2(all_keys, all_emb, all_emb_v, all_emb_f, args.ncentroids, input_emb_type=args.cluster_type) 52 | # create dataset 53 | train_iter = v6_voxceleb_loader_for_deepcluster.get_iter(args.batch_size, 54 | args.batch_per_epoch * args.batch_size, 55 | face_emb_dict, 56 | voice_emb_dict, 57 | movie2label) 58 | 59 | for data in train_iter: 60 | loss, info = do_step(epo, step, data) 61 | step += 1 62 | if step % 50 == 0: 63 | obj = { 64 | "train/step": step, 65 | "train/loss": loss, 66 | } 67 | obj = {**obj, **info} 68 | print(obj) 69 | wb_util.log(obj) 70 | 71 | if step % args.eval_step == 0: 72 | if eval_cut.eval_short_cut(): 73 | return 74 | 75 | 76 | if __name__ == "__main__": 77 | parser = my_parser.MyParser(epoch=100, batch_size=256, model_save_folder="./outputs/", early_stop=10) 78 | parser.custom({ 79 | "ncentroids": 1000, 80 | "batch_per_epoch": 500, 81 | "eval_step": 200, 82 | 83 | "ratio_mse": 0.0, 84 | 85 | "mts_alpha": 2.0, 86 | "mts_beta": 50.0, 87 | "mts_base": 1.0, 88 | 89 | "load_model": "", 90 | 91 | "cluster_type": "all", 92 | }) 93 | parser.use_wb("sl_project", "SL") 94 | args = parser.parse() 95 | seed_util.set_seed(args.seed) 96 | assert args.cluster_type in ["v", "f", "all"] 97 | 98 | 99 | # 1.model: 100 | model = my_model.Encoder().cuda() 101 | if args.load_model is not None and os.path.exists(args.load_model): 102 | model_util.load_model(args.load_model, model, strict=True) 103 | 104 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 105 | 106 | # 3.loss 107 | fun_loss_metric = losses.MultiSimilarityLoss(alpha=args.mts_alpha, beta=args.mts_beta, base=args.mts_base) 108 | 109 | fun_loss_mse = torch.nn.MSELoss() 110 | emb_eva = EmbEva() 111 | eval_cut = Cut(emb_eva, model, args) 112 | 113 | train() 114 | 115 | 116 | -------------------------------------------------------------------------------- /works_loss_cmp/0_loss_compare.py: 
-------------------------------------------------------------------------------- 1 | from utils import my_parser, seed_util, wb_util, pair_selection_util, eva_emb_full 2 | import os 3 | from models.my_model import Encoder 4 | import torch 5 | from loaders import v1_sup_id_loader, v2_unsup_loader 6 | from utils.losses import triplet_hq1, center_loss_learnableW_L2dist, center_loss_eccv16 7 | from utils.eval_shortcut import Cut 8 | from utils import eva_emb_full 9 | from utils.losses import fop_loss, my_pml_infonce_v2 10 | from utils.losses import barlow_loss 11 | from utils.losses import unsup_nce 12 | from utils.losses.softmax_loss import SoftmaxLoss 13 | from pytorch_metric_learning import losses 14 | 15 | 16 | def do_step(epoch, step, data): 17 | optimizer.zero_grad() 18 | data = [i.cuda() for i in data] 19 | if len(data) == 4: 20 | voice_data, face_data, id_label, _ = data 21 | else: 22 | assert len(data) == 3 23 | voice_data, face_data, _ = data 24 | 25 | v_emb, f_emb = model(voice_data, face_data) 26 | 27 | if args.loss == "barlow": 28 | loss = loss_fun(v_emb, f_emb) 29 | elif args.loss == "unsup_nce": 30 | loss = loss_fun(v_emb, f_emb) + loss_fun(f_emb, v_emb) 31 | elif args.loss == "hq_triplet": 32 | loss = triplet_hq1.triplet_loss(v_emb, f_emb, id_label) 33 | elif args.loss == "sup_nce_v2": 34 | loss = loss_fun(v_emb, f_emb, id_label, id_label) + loss_fun(f_emb, v_emb, id_label, id_label) 35 | else: 36 | cat_emb = torch.cat([v_emb, f_emb], dim=0) 37 | cat_id = torch.cat([id_label, id_label], dim=0).squeeze() 38 | loss = loss_fun(cat_emb, cat_id) 39 | 40 | loss.backward() 41 | # torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) 42 | optimizer.step() 43 | return loss.item(), {} 44 | 45 | 46 | def train(): 47 | step = 0 48 | model.train() 49 | 50 | for epo in range(args.epoch): 51 | wb_util.log({"train/epoch": epo}) 52 | for data in train_iter: 53 | loss, info = do_step(epo, step, data) 54 | step += 1 55 | if step % 50 == 0: 56 | obj = { 57 | "train/step": step, 58 | "train/loss": loss, 59 | } 60 | obj = {**obj, **info} 61 | print(obj) 62 | wb_util.log(obj) 63 | 64 | if step > 0 and step % args.eval_step == 0: 65 | if eval_cut.eval_short_cut(): 66 | return 67 | 68 | 69 | if __name__ == "__main__": 70 | parser = my_parser.MyParser(epoch=100, batch_size=256, model_save_folder="./outputs/", early_stop=10) 71 | parser.custom({ 72 | "batch_per_epoch": 500, 73 | "eval_step": 100, 74 | "loss": "sup_nce", 75 | "contrastive_margin": 0.5, 76 | "triplet_margin": 1.0, 77 | "infoNCE_temperature": 0.07, 78 | "mts_alpha": 2.0, 79 | "mts_beta": 50.0, 80 | "mts_base": 1.0, 81 | }) 82 | parser.use_wb("VFALBenchmark", "LossComp") 83 | args = parser.parse() 84 | seed_util.set_seed(args.seed) 85 | # data 86 | if args.loss in ["unsup_nce", "barlow"]: 87 | # use unsupervised loader for unsupvised loss function! 
88 | train_iter = v2_unsup_loader.get_iter(args.batch_size, args.batch_per_epoch * args.batch_size) 89 | else: 90 | train_iter = v1_sup_id_loader.get_iter(args.batch_size, args.batch_per_epoch * args.batch_size) 91 | num_class = len(train_iter.dataset.train_names) 92 | 93 | # model 94 | model = Encoder().cuda() 95 | 96 | # loss 97 | if args.loss == "sup_nce_pml": 98 | loss_fun = losses.NTXentLoss(temperature=args.infoNCE_temperature) 99 | elif args.loss == "sup_nce_v2": 100 | loss_fun = my_pml_infonce_v2.InfoNCE(temperature=args.infoNCE_temperature, reduction="mean") 101 | elif args.loss == "unsup_nce": 102 | loss_fun = unsup_nce.InfoNCE(args.infoNCE_temperature) 103 | elif args.loss == "pml_triplet_loss": 104 | loss_fun = losses.TripletMarginLoss() 105 | elif args.loss == "NCALoss": 106 | loss_fun = losses.NCALoss(softmax_scale=1) 107 | elif args.loss == "ProxyNCA": 108 | loss_fun = losses.ProxyNCALoss(num_classes=num_class, embedding_size=128).to(torch.device('cuda')) 109 | elif args.loss == "MultiSimilarity": 110 | loss_fun = losses.MultiSimilarityLoss(alpha=args.mts_alpha, beta=args.mts_beta, base=args.mts_base) 111 | elif args.loss == "LiftedStructure": 112 | loss_fun = losses.LiftedStructureLoss(neg_margin=1, pos_margin=0) 113 | elif args.loss == "barlow": 114 | loss_fun = barlow_loss.BarlowTwinsLoss() 115 | elif args.loss == "softmax": 116 | loss_fun = SoftmaxLoss(128, num_class=num_class).cuda() 117 | elif args.loss == "fop": 118 | loss_fun = fop_loss.OrthogonalProjectionLoss() 119 | else: 120 | raise Exception("wrong loss function:" + args.loss) 121 | 122 | if loss_fun is not None: 123 | model_params = list(model.parameters()) + list(loss_fun.parameters()) 124 | else: 125 | model_params = model.parameters() 126 | 127 | optimizer = torch.optim.Adam(model_params, lr=args.lr) 128 | emb_eva = eva_emb_full.EmbEva() 129 | eval_cut = Cut(emb_eva, model, args) 130 | train() 131 | -------------------------------------------------------------------------------- /works_loss_cmp/1_contrastive_loss.py: -------------------------------------------------------------------------------- 1 | from utils import my_parser, seed_util, wb_util 2 | from models.my_model import Encoder 3 | import torch 4 | from loaders import v7_sup_id_contrastive_loader 5 | from utils.eval_shortcut import Cut 6 | from utils import eva_emb_full 7 | 8 | 9 | def do_step(epoch, step, data): 10 | optimizer.zero_grad() 11 | data = [i.cuda() for i in data] 12 | voice_data, face_data, is_same_person = data 13 | v_emb, f_emb = model(voice_data, face_data) 14 | loss = loss_fun(v_emb, f_emb, is_same_person) 15 | loss.backward() 16 | optimizer.step() 17 | return loss.item(), {} 18 | 19 | 20 | def train(): 21 | step = 0 22 | model.train() 23 | 24 | for epo in range(args.epoch): 25 | wb_util.log({"train/epoch": epo}) 26 | for data in train_iter: 27 | loss, info = do_step(epo, step, data) 28 | step += 1 29 | if step % 50 == 0: 30 | obj = { 31 | "train/step": step, 32 | "train/loss": loss, 33 | } 34 | obj = {**obj, **info} 35 | print(obj) 36 | wb_util.log(obj) 37 | 38 | if step > 0 and step % args.eval_step == 0: 39 | if eval_cut.eval_short_cut(test_threshold=0): 40 | return 41 | 42 | 43 | if __name__ == "__main__": 44 | parser = my_parser.MyParser(epoch=100, batch_size=256, model_save_folder="./outputs/", early_stop=10) 45 | parser.custom({ 46 | "batch_per_epoch": 500, 47 | "eval_step": 100, 48 | "margin": 0.5, 49 | "loss": "" 50 | }) 51 | parser.use_wb("VFALBenchmark", "contrastive loss") 52 | args = parser.parse() 53 | 
seed_util.set_seed(args.seed) 54 | train_iter = v7_sup_id_contrastive_loader.get_iter(args.batch_size, args.batch_per_epoch * args.batch_size) 55 | num_class = len(train_iter.dataset.train_names) 56 | 57 | # model 58 | model = Encoder().cuda() 59 | 60 | loss_fun = torch.nn.CosineEmbeddingLoss(margin=args.margin, reduction="mean") 61 | 62 | model_params = model.parameters() 63 | optimizer = torch.optim.Adam(model_params, lr=args.lr) 64 | emb_eva = eva_emb_full.EmbEva() 65 | eval_cut = Cut(emb_eva, model, args) 66 | train() 67 | -------------------------------------------------------------------------------- /works_loss_cmp/2_triplet_loss.py: -------------------------------------------------------------------------------- 1 | from utils import my_parser, seed_util, wb_util 2 | from models.my_model import Encoder 3 | import torch 4 | from loaders import v3_triplet_loader 5 | from utils.eval_shortcut import Cut 6 | from utils import eva_emb_full 7 | 8 | 9 | def do_step(epoch, step, data): 10 | optimizer.zero_grad() 11 | data = [i.cuda() for i in data] 12 | v1, f1, v2, f2 = data 13 | v_emb1, f_emb1 = model(v1, f1) 14 | v_emb2, f_emb2 = model(v2, f2) 15 | loss = loss_fun(v_emb1, f_emb1, f_emb2) + loss_fun(f_emb1, v_emb1, v_emb2) 16 | loss.backward() 17 | optimizer.step() 18 | return loss.item(), {} 19 | 20 | 21 | def train(): 22 | step = 0 23 | model.train() 24 | 25 | for epo in range(args.epoch): 26 | wb_util.log({"train/epoch": epo}) 27 | for data in train_iter: 28 | loss, info = do_step(epo, step, data) 29 | step += 1 30 | if step % 50 == 0: 31 | obj = { 32 | "train/step": step, 33 | "train/loss": loss, 34 | } 35 | obj = {**obj, **info} 36 | print(obj) 37 | wb_util.log(obj) 38 | 39 | if step > 0 and step % args.eval_step == 0: 40 | if eval_cut.eval_short_cut(): 41 | return 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = my_parser.MyParser(epoch=100, batch_size=256, model_save_folder="./outputs/", early_stop=10) 46 | parser.custom({ 47 | "batch_per_epoch": 500, 48 | "eval_step": 100, 49 | "margin": 1.0, 50 | "loss": "" 51 | }) 52 | parser.use_wb("VFALBenchmark", "triplet loss") 53 | args = parser.parse() 54 | seed_util.set_seed(args.seed) 55 | train_iter = v3_triplet_loader.get_iter(args.batch_size, args.batch_per_epoch * args.batch_size) 56 | 57 | # model 58 | model = Encoder().cuda() 59 | 60 | loss_fun = torch.nn.TripletMarginLoss(margin=args.margin, p=2) 61 | 62 | model_params = model.parameters() 63 | optimizer = torch.optim.Adam(model_params, lr=args.lr) 64 | emb_eva = eva_emb_full.EmbEva() 65 | eval_cut = Cut(emb_eva, model, args) 66 | train() 67 | --------------------------------------------------------------------------------