├── models ├── __init__.py ├── my_model.py └── incep.py ├── utils ├── __init__.py ├── path_util.py ├── worker_util.py ├── seed_util.py ├── sample_util.py ├── deep_coral_loss.py ├── pickle_util.py ├── my_softmax_loss.py ├── barlow_loss.py ├── vec_util.py ├── model_selector.py ├── eval_shortcut.py ├── wb_util.py ├── map_evaluate.py ├── distance_util.py ├── model_util.py ├── my_parser.py ├── deepcluster_util.py └── eva_emb_full.py ├── configs ├── __init__.py └── config.py ├── loaders ├── __init__.py ├── img_loader.py ├── voice_loader.py ├── voxceleb_loader_for_cae.py ├── voxceleb_loader_for_deepcluster.py └── voxceleb_cluster_ordered_loader.py ├── scripts ├── 1_extract_face_emb.py └── 2_exract_voice_emb.py ├── baseline ├── 3_barlow.py ├── 1_ccae.py └── 2_deepcluster.py ├── README.md ├── .gitignore └── sl.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /loaders/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/path_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def look_up(path): 5 | if os.path.exists(path): 6 | return path 7 | 8 | upper = "." + path 9 | if os.path.exists(upper): 10 | print("switch", path, "==>", upper) 11 | return upper 12 | 13 | return path 14 | -------------------------------------------------------------------------------- /utils/worker_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy 4 | 5 | 6 | def worker_init_fn(worker_id): 7 | pytorch_seed = torch.utils.data.get_worker_info().seed 8 | seed = pytorch_seed % (2 ** 32 - 1) 9 | random.seed(seed) 10 | numpy.random.seed(seed) 11 | print("worker:%d,pytorch_seed:%d" % (worker_id, seed)) 12 | -------------------------------------------------------------------------------- /utils/seed_util.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy 3 | import torch 4 | 5 | 6 | def set_seed(seed): 7 | random.seed(seed) 8 | numpy.random.seed(seed) 9 | 10 | 11 | torch.manual_seed(seed) 12 | torch.cuda.manual_seed(seed) 13 | 14 | torch.backends.cudnn.deterministic = True 15 | torch.backends.cudnn.benchmark = False 16 | # https://pytorch.org/docs/stable/notes/randomness.html 17 | -------------------------------------------------------------------------------- /utils/sample_util.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | 4 | 5 | def random_element(array, need_index=False): 6 | length = len(array) 7 | assert length > 0, length 8 | rand_index = random.randint(0, length - 1) 9 | if need_index: 10 | return array[rand_index], rand_index 11 | else: 12 | return array[rand_index] 13 | 14 | 15 | def random_elements(array, number): 16 | return np.random.choice(array, number, replace=False) 17 | 
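# Illustrative usage of the two sampling helpers above (not part of the original
# file; the movie keys below are made up for demonstration):
if __name__ == "__main__":
    movies = ["id10001/a", "id10002/b", "id10003/c"]
    print(random_element(movies))                    # one random key
    print(random_element(movies, need_index=True))   # (key, index) tuple
    print(random_elements(movies, 2))                # numpy array of 2 distinct keys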
-------------------------------------------------------------------------------- /configs/config.py: -------------------------------------------------------------------------------- 1 | from utils import pickle_util, vec_util, path_util 2 | from utils.eva_emb_full import EmbEva 3 | 4 | model_save_folder = "./outputs/" 5 | 6 | # 1. data input 7 | face_emb_dict = pickle_util.read_pickle(path_util.look_up("./dataset/voxceleb/face_input.pkl")) 8 | voice_emb_dict = pickle_util.read_pickle(path_util.look_up("./dataset/voxceleb/voice_input.pkl")) 9 | vec_util.dict2unit_dict_inplace(face_emb_dict) 10 | vec_util.dict2unit_dict_inplace(voice_emb_dict) 11 | 12 | # 2.eval 13 | emb_eva = EmbEva(voice_emb_dict, face_emb_dict) 14 | -------------------------------------------------------------------------------- /utils/deep_coral_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def CORALV2(source, target): 5 | d = source.data.shape[1] 6 | # source covariance 7 | xm = torch.mean(source, 1, keepdim=True) - source 8 | xc = torch.matmul(torch.transpose(xm, 0, 1), xm) 9 | # target covariance 10 | xmt = torch.mean(target, 1, keepdim=True) - target 11 | xct = torch.matmul(torch.transpose(xmt, 0, 1), xmt) 12 | 13 | # frobenius norm between source and target 14 | tmp = torch.sum(torch.mul((xc - xct), (xc - xct))) 15 | loss = tmp / (4 * d * d) 16 | 17 | return loss 18 | -------------------------------------------------------------------------------- /utils/pickle_util.py: -------------------------------------------------------------------------------- 1 | import _pickle as pickle # python3 2 | import time 3 | import json 4 | 5 | 6 | def read_pickle(filepath): 7 | f = open(filepath, 'rb') 8 | word2mfccs = pickle.load(f) 9 | f.close() 10 | return word2mfccs 11 | 12 | 13 | def save_pickle(save_path, save_data): 14 | f = open(save_path, 'wb') 15 | pickle.dump(save_data, f) 16 | f.close() 17 | 18 | 19 | def read_json(filepath): 20 | with open(filepath) as f: 21 | obj = json.load(f) 22 | return obj 23 | 24 | 25 | def save_json(save_path, obj): 26 | with open(save_path, 'w') as f: 27 | json.dump(obj, f) 28 | -------------------------------------------------------------------------------- /utils/my_softmax_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class MySoftmaxLoss(nn.Module): 6 | def __init__(self, feature_dim, num_class): 7 | super(MySoftmaxLoss, self).__init__() 8 | self.in_feats = feature_dim 9 | self.W = torch.nn.Parameter(torch.randn(feature_dim, num_class)) 10 | self.cel = nn.CrossEntropyLoss() 11 | nn.init.xavier_normal_(self.W, gain=1) 12 | 13 | def forward(self, embedding, labels): 14 | assert embedding.size()[0] == labels.size()[0] 15 | assert embedding.size()[1] == self.in_feats 16 | logits = torch.mm(embedding, self.W) 17 | loss = self.cel(logits, labels) 18 | return loss 19 | -------------------------------------------------------------------------------- /loaders/img_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | import torchvision 4 | 5 | 6 | class Dataset(torch.utils.data.Dataset): 7 | 8 | def __init__(self, all_image_files): 9 | self.all_image_files = all_image_files 10 | resize_size = 128 11 | self.transform_fn = torchvision.transforms.Compose([ 12 | torchvision.transforms.Resize(size=(resize_size, resize_size)), 13 | 
torchvision.transforms.ToTensor() 14 | ]) 15 | 16 | def __len__(self): 17 | return len(self.all_image_files) 18 | 19 | def __getitem__(self, index): 20 | file_path = self.all_image_files[index] 21 | img_PIL = Image.open(file_path) 22 | if img_PIL.mode != "RGB": 23 | img_PIL = img_PIL.convert("RGB") 24 | 25 | data = self.transform_fn(img_PIL) 26 | assert data.shape == (3, 128, 128), file_path 27 | return data, index 28 | -------------------------------------------------------------------------------- /utils/barlow_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BarlowTwinsLoss(torch.nn.Module): 5 | 6 | def __init__(self, lambda_param=5e-3): 7 | super(BarlowTwinsLoss, self).__init__() 8 | self.lambda_param = lambda_param 9 | 10 | def forward(self, z_a: torch.Tensor, z_b: torch.Tensor): 11 | # normalize repr. along the batch dimension 12 | z_a_norm = (z_a - z_a.mean(0)) / z_a.std(0) # NxD 13 | z_b_norm = (z_b - z_b.mean(0)) / z_b.std(0) # NxD 14 | 15 | N = z_a.size(0) 16 | D = z_a.size(1) 17 | 18 | # cross-correlation matrix 19 | c = torch.mm(z_a_norm.T, z_b_norm) / N # DxD 20 | # loss 21 | c_diff = (c - torch.eye(D, device="cuda")).pow(2) # DxD 22 | # multiply off-diagonal elems of c_diff by lambda 23 | c_diff[~torch.eye(D, dtype=bool)] *= self.lambda_param 24 | 25 | loss = c_diff.sum() 26 | 27 | return loss 28 | -------------------------------------------------------------------------------- /utils/vec_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def get_vec_length(vec): 5 | if type(vec) == list: 6 | vec = np.array(vec) 7 | return np.sqrt(np.sum(vec * vec)) 8 | 9 | 10 | def dict2unit_dict_inplace(the_dict): 11 | for key in the_dict: 12 | vec = the_dict[key] 13 | the_len = get_vec_length(vec) 14 | the_dict[key] = vec / the_len 15 | return the_dict 16 | 17 | 18 | def assert_is_unit_tensor(tensor): 19 | npy = tensor.detach().cpu().numpy() 20 | length = get_vec_length(npy) 21 | assert np.isclose(length, 1.0) 22 | 23 | 24 | def assert_dict_unit_vector(the_dic): 25 | for key in the_dic: 26 | v = the_dic[key] 27 | the_len = get_vec_length(v) 28 | assert np.isclose(the_len, 1.0) 29 | break 30 | 31 | 32 | def get_vec_dim_in_dict(the_dic): 33 | for key in the_dic: 34 | v = the_dic[key] 35 | return len(v) 36 | 37 | 38 | def to_unit_vector(vector): 39 | return vector / get_vec_length(vector) 40 | 41 | 42 | def norm_batch_vector(matix): 43 | # matix = np.array([ 44 | # [3, 4], 45 | # [1, 1] 46 | # ]) 47 | vec_length = np.linalg.norm(matix, axis=1, keepdims=True) 48 | out = matix / vec_length 49 | # [[0.6 , 0.8 ], 50 | # [0.707, 0.707]] 51 | return out 52 | -------------------------------------------------------------------------------- /utils/model_selector.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import numpy as np 3 | 4 | 5 | class ModelSelector: 6 | 7 | def __init__(self): 8 | self.history = collections.defaultdict(list) 9 | 10 | def log(self, the_dict): 11 | for key, value in the_dict.items(): 12 | self.history[key].append(value) 13 | best_info = {} 14 | for key in self.history: 15 | # valid/ms_fv 16 | # => best-valid/valid_ms_fv 17 | best_info["best-" + key] = max(self.history[key]) 18 | return best_info 19 | 20 | def is_best_model(self, indicator): 21 | assert indicator in self.history 22 | arr = self.history[indicator] 23 | return np.argmax(arr) == (len(arr) - 1) 24 | 25 | 
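# Early-stopping check: returns True once the best value of `indicator` has not
# improved for at least `early_stop` consecutive logged evaluations.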
def should_stop(self, indicator, early_stop=10): 26 | arr = self.history[indicator] 27 | if len(arr) - 1 - np.argmax(arr) >= early_stop: 28 | return True 29 | return False 30 | 31 | 32 | def get_best_step_info(self, indictor, print_it=True): 33 | index = np.argmax(self.history[indictor]) 34 | ans = {} 35 | if print_it: 36 | for key in self.history: 37 | v = self.history[key][index] 38 | print("%s\t%.4f" % (key, v)) 39 | ans[key] = v 40 | return ans 41 | -------------------------------------------------------------------------------- /loaders/voice_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | import torchaudio 4 | import numpy as np 5 | 6 | 7 | class Dataset(torch.utils.data.Dataset): 8 | 9 | def __init__(self, all_wavs): 10 | self.all_wavs = all_wavs 11 | 12 | def __len__(self): 13 | return len(self.all_wavs) 14 | 15 | def __getitem__(self, index): 16 | return { 17 | "key": self.all_wavs[index], 18 | "data": torchaudio.load(self.all_wavs[index])[0] 19 | } 20 | 21 | 22 | def collate_fn(item_list): 23 | data_list = [i['data'] for i in item_list] 24 | the_lengths = np.array([i.shape[-1] for i in data_list]) 25 | max_len = np.max(the_lengths) 26 | len_ratio = the_lengths / max_len 27 | 28 | batch_size = len(item_list) 29 | output = torch.zeros([batch_size, max_len]) 30 | for i in range(batch_size): 31 | cur = data_list[i] 32 | cur_len = data_list[i].shape[-1] 33 | output[i, :cur_len] = cur.squeeze() 34 | 35 | len_ratio = torch.FloatTensor(len_ratio) 36 | keys = [i['key'] for i in item_list] 37 | return output, len_ratio, keys 38 | 39 | 40 | def get_loader(num_workers, batch_size, all_wavs): 41 | loader = DataLoader(Dataset(all_wavs), 42 | num_workers=num_workers, batch_size=batch_size, 43 | shuffle=False, pin_memory=True, collate_fn=collate_fn) 44 | return loader 45 | 46 | 47 | if __name__ == "__main__": 48 | pass 49 | -------------------------------------------------------------------------------- /loaders/voxceleb_loader_for_cae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils import pickle_util, sample_util, worker_util 3 | from torch.utils.data import DataLoader 4 | 5 | 6 | def get_iter(batch_size, full_length, name2face_emb, name2voice_emb): 7 | train_iter = DataLoader(DataSet(name2face_emb, name2voice_emb, full_length), 8 | batch_size=batch_size, shuffle=False, pin_memory=True, worker_init_fn=worker_util.worker_init_fn) 9 | return train_iter 10 | 11 | 12 | class DataSet(torch.utils.data.Dataset): 13 | 14 | def __init__(self, name2face_emb, name2voice_emb, full_length): 15 | self.train_movie_list = pickle_util.read_pickle("./dataset/voxceleb/cluster/train_movie_list.pkl") 16 | self.movie2wav_path = pickle_util.read_pickle("./dataset/voxceleb/cluster/movie2wav_path.pkl") 17 | self.movie2jpg_path = pickle_util.read_pickle("./dataset/voxceleb/cluster/movie2jpg_path.pkl") 18 | 19 | self.full_length = full_length 20 | self.name2face_emb = name2face_emb 21 | self.name2voice_emb = name2voice_emb 22 | 23 | def __len__(self): 24 | return self.full_length 25 | 26 | def __getitem__(self, index): 27 | video = sample_util.random_element(self.train_movie_list) 28 | img = sample_util.random_element(self.movie2jpg_path[video]) 29 | wav = sample_util.random_element(self.movie2wav_path[video]) 30 | wav, img = self.to_tensor([wav, img]) 31 | return wav, img 32 | 33 | def to_tensor(self, path_arr): 34 | ans = [] 35 | for path in 
path_arr: 36 | if ".wav" in path: 37 | emb = self.name2voice_emb[path] 38 | else: 39 | emb = self.name2face_emb[path] 40 | emb = torch.FloatTensor(emb) 41 | ans.append(emb) 42 | return ans 43 | -------------------------------------------------------------------------------- /scripts/1_extract_face_emb.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from torch.utils.data import DataLoader 4 | from models import incep 5 | from loaders.img_loader import Dataset 6 | from utils import pickle_util 7 | import glob 8 | 9 | the_dict = {} 10 | 11 | 12 | def handle_emb_batch(all_data, batch_emb, indexies): 13 | batch_emb = batch_emb.detach().cpu().numpy().squeeze() 14 | assert len(batch_emb.shape) == 2 15 | indexies = indexies.detach().cpu().numpy().tolist() 16 | for idx, emb in zip(indexies, batch_emb): 17 | filepath = all_data[idx] 18 | the_dict[filepath] = emb 19 | 20 | 21 | def fun(num_workers, all_img_data, batch_size): 22 | start_time = time.time() 23 | the_iter = DataLoader(Dataset(all_img_data), num_workers=num_workers, batch_size=batch_size, shuffle=False, 24 | pin_memory=True) 25 | all_data = the_iter.dataset.all_image_files 26 | 27 | total_batch = int(len(all_data) / batch_size) + 1 28 | counter = 0 29 | with torch.no_grad(): 30 | for image_tensor, indexies in the_iter: 31 | counter += 1 32 | emb_vec = model(image_tensor.cuda()) 33 | handle_emb_batch(all_data, emb_vec, indexies) 34 | time_cost_h = (time.time() - start_time) / 3600.0 35 | progress = (counter + 1) / total_batch 36 | full_time = time_cost_h / progress 37 | print(counter, progress, "full:", full_time) 38 | 39 | 40 | if __name__ == '__main__': 41 | # 1.load model 42 | model = incep.InceptionResnetV1(pretrained="vggface2", classify=True) 43 | model.cuda() 44 | model.eval() 45 | 46 | # 2.get all img list 47 | all_jpgs = glob.glob("/your_path/*.jpg") 48 | 49 | # 3.processing 50 | fun(8, all_jpgs, batch_size=2048) 51 | 52 | # 4.save 53 | pickle_util.save_pickle("face_emb.pkl", the_dict) 54 | -------------------------------------------------------------------------------- /utils/eval_shortcut.py: -------------------------------------------------------------------------------- 1 | import os 2 | from utils import wb_util, model_util, pickle_util 3 | from utils import model_selector 4 | 5 | 6 | class Cut(): 7 | 8 | def __init__(self, emb_eva, model, args): 9 | self.modelSelector = model_selector.ModelSelector() 10 | self.emb_eva = emb_eva 11 | self.model = model 12 | self.args = args 13 | 14 | def eval_short_cut(self): 15 | emb_eva = self.emb_eva 16 | model = self.model 17 | modelSelector = self.modelSelector 18 | args = self.args 19 | 20 | # 1.do test 21 | valid_obj = emb_eva.do_valid(model) 22 | test_obj = emb_eva.do_full_test(model) 23 | obj = {**valid_obj, **test_obj} 24 | 25 | # 2.log 26 | wb_util.log(obj) 27 | modelSelector.log(obj) 28 | print(obj) 29 | 30 | # 3.init wandb 31 | wb_util.init(args) 32 | 33 | indicator = "valid/auc" 34 | if modelSelector.is_best_model(indicator): 35 | model_util.delete_last_saved_model() 36 | model_save_name = "auc[%.2f,%.2f]_ms[%.2f,%.2f]_map[%.2f,%.2f].pkl" % ( 37 | obj["valid/auc"] * 100, 38 | obj["test/auc"] * 100, 39 | obj["test/ms_v2f"] * 100, 40 | obj["test/ms_f2v"] * 100, 41 | obj["test/map_v2f"] * 100, 42 | obj["test/map_f2v"] * 100, 43 | ) 44 | model_save_path = os.path.join(args.model_save_folder, args.project, args.name, model_save_name) 45 | model_util.save_model(0, model, None, model_save_path) 46 | 
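# also store this checkpoint's test metrics as "<checkpoint>.pkl.json";
# model_util.get_pkl_json() can read them back from the run folder later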
pickle_util.save_json(model_save_path + ".json", test_obj) 47 | else: 48 | print("not best model") 49 | 50 | if modelSelector.should_stop(indicator, args.early_stop): 51 | print("early_stop!") 52 | print(model_util.history_array[-1]) 53 | return True 54 | return False 55 | -------------------------------------------------------------------------------- /utils/wb_util.py: -------------------------------------------------------------------------------- 1 | # used for late-initialize wandb in case of crash before a full evaluation (which add an idle log on the monitoring panel) 2 | import wandb 3 | import os 4 | from utils import pickle_util 5 | 6 | history_logs = [] 7 | history_configs = [] 8 | is_inited = False 9 | 10 | 11 | def update_config(obj): 12 | history_configs.append(obj) 13 | if not is_inited: 14 | print("wb_util temporarily cache config") 15 | return 16 | wandb.config.update(obj) 17 | 18 | 19 | def log(obj): 20 | history_logs.append(obj) 21 | if not is_inited: 22 | print("wb_util temporarily cache log") 23 | return 24 | wandb.log(obj) 25 | 26 | 27 | def init(args): 28 | init_core(args.project, args.name, args.dryrun) 29 | 30 | 31 | def init_core(project, name, dryrun): 32 | global is_inited 33 | if is_inited: 34 | return 35 | is_inited = True 36 | 37 | if dryrun: 38 | os.environ['WANDB_MODE'] = 'dryrun' 39 | wandb.log = do_nothing 40 | wandb.save = do_nothing 41 | wandb.watch = do_nothing 42 | wandb.config = {} 43 | print("wb dryrun mode") 44 | return 45 | 46 | # read ./configs/wb_config.json 47 | filepath = "./configs/wb_config.json" 48 | assert os.path.exists(filepath), "do not have wandb config file" 49 | 50 | # assert have WB_KEY 51 | json_dict = pickle_util.read_json("./configs/wb_config.json") 52 | assert "WB_KEY" in json_dict, "wb_config.json do not have WB_KEY" 53 | WB_KEY = json_dict["WB_KEY"] 54 | 55 | # use self-hosted wb server 56 | if "WB_SERVER_URL" in json_dict: 57 | os.environ["WANDB_BASE_URL"] = json_dict["WB_SERVER_URL"] 58 | 59 | # login 60 | wandb.login(key=WB_KEY) 61 | wandb.init(project=project, name=name) 62 | print("wandb inited") 63 | 64 | # supplement config and logs 65 | for obj in history_configs: 66 | wandb.config.update(obj) 67 | 68 | for log in history_logs: 69 | wandb.log(log) 70 | 71 | 72 | def do_nothing(v): 73 | pass 74 | -------------------------------------------------------------------------------- /utils/map_evaluate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy 3 | import scipy.spatial 4 | 5 | 6 | def cos_dist(query_matrix, result_matrix): 7 | return scipy.spatial.distance.cdist(query_matrix, result_matrix, 'cosine') 8 | 9 | 10 | def fx_calc_map_label(query_matrix, result_matrix, labels, k=0, dist_method='COS'): 11 | if dist_method == 'L2': 12 | dist = scipy.spatial.distance.cdist(query_matrix, result_matrix, 'euclidean') 13 | elif dist_method == 'COS': 14 | dist = scipy.spatial.distance.cdist(query_matrix, result_matrix, 'cosine') 15 | ord = dist.argsort() 16 | numcases = dist.shape[0] 17 | if k == 0: 18 | k = numcases 19 | res = [] 20 | 21 | for i in range(numcases): 22 | order = ord[i] 23 | p = 0.0 24 | r = 0.0 25 | for j in range(k): 26 | if labels[i] == labels[order[j]]: 27 | r += 1 28 | p += (r / (j + 1)) 29 | if r > 0: 30 | res += [p / r] 31 | else: 32 | res += [0] 33 | 34 | return np.mean(res) 35 | 36 | 37 | def fx_calc_map_label_v2(dist, label, k=0): 38 | ord = dist.argsort() 39 | numcases = dist.shape[0] 40 | if k == 0: 41 | k = numcases 42 | res = [] 
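# mAP computation: for each query row, walk its ranked result list; every time an
# item with the same label is hit at rank j, accumulate precision@j (r / (j + 1));
# the per-query average precision is p / r, and the mean over all queries is the mAP.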
43 | 44 | for i in range(numcases): 45 | order = ord[i] 46 | p = 0.0 47 | r = 0.0 48 | for j in range(k): 49 | if label[i] == label[order[j]]: 50 | r += 1 51 | p += (r / (j + 1)) 52 | if r > 0: 53 | res += [p / r] 54 | else: 55 | res += [0] 56 | 57 | return np.mean(res) 58 | 59 | 60 | def fx_calc_map_label_v3(ord, label, k=0): 61 | numcases = ord.shape[0] 62 | if k == 0: 63 | k = numcases 64 | res = [] 65 | 66 | for i in range(numcases): 67 | order = ord[i] 68 | p = 0.0 69 | r = 0.0 70 | for j in range(k): 71 | if label[i] == label[order[j]]: 72 | r += 1 73 | p += (r / (j + 1)) 74 | if r > 0: 75 | res += [p / r] 76 | else: 77 | res += [0] 78 | 79 | return np.mean(res) 80 | 81 | 82 | if __name__ == "__main__": 83 | pass 84 | -------------------------------------------------------------------------------- /utils/distance_util.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import numpy as np 3 | import scipy.spatial 4 | 5 | 6 | def calc_inter_distance(embedding): 7 | matrix_dot = numpy.dot(embedding, numpy.transpose(embedding)) 8 | # (batch,batch) 9 | 10 | l2_norm_squired = numpy.diagonal(matrix_dot) 11 | 12 | distance_matrix_squired = numpy.expand_dims(l2_norm_squired, axis=0) + numpy.expand_dims(l2_norm_squired, 13 | axis=1) - 2.0 * matrix_dot 14 | 15 | distance_matrix = numpy.maximum(distance_matrix_squired, 0.0) 16 | distance_matrix = numpy.sqrt(distance_matrix) 17 | return distance_matrix 18 | 19 | 20 | def calc_matrix_distance(matrix_a, matrix_b): 21 | # matrix_a: (batch_a,dim) 22 | # matrix_b: (batch_b,dim) 23 | 24 | matrix_dot = numpy.dot(matrix_a, numpy.transpose(matrix_b)) 25 | # (batch_a,batch_b) 26 | 27 | a_square = numpy.sum(matrix_a * matrix_a, axis=1) 28 | # (batch_a) 29 | 30 | b_square = numpy.sum(matrix_b * matrix_b, axis=1) 31 | # (batch_b) 32 | 33 | a_square_2d = numpy.expand_dims(a_square, axis=1) 34 | # (1,batch_a) 35 | 36 | b_square_2d = numpy.expand_dims(b_square, axis=0) 37 | # (batch_b,1) 38 | 39 | distance_matrix_squired = a_square_2d - 2.0 * matrix_dot + b_square_2d 40 | 41 | distance_matrix = numpy.maximum(distance_matrix_squired, 0.0) 42 | distance_matrix = numpy.sqrt(distance_matrix) 43 | return distance_matrix 44 | 45 | 46 | def parallel_distance(a, b): 47 | a = numpy.array(a) 48 | b = numpy.array(b) 49 | 50 | assert len(a) == len(b) 51 | 52 | c = a - b 53 | return numpy.sqrt(numpy.sum(c * c, axis=1)) 54 | 55 | 56 | def parallel_distance_cosine_based_distance(a, b): 57 | assert len(a.shape) == 2 58 | assert a.shape == b.shape 59 | ab = np.sum(a * b, axis=1) 60 | # (batch_size,) 61 | 62 | a_norm = np.sqrt(np.sum(a * a, axis=1)) 63 | b_norm = np.sqrt(np.sum(b * b, axis=1)) 64 | cosine = ab / (a_norm * b_norm) 65 | 66 | dist = 1 - cosine 67 | # 0~2 68 | return dist 69 | 70 | 71 | def distance_of_2point(a, b): 72 | return parallel_distance([a], [b])[0] 73 | 74 | 75 | def cosine_similarity(v1, v2): 76 | return (1 - scipy.spatial.distance.cosine(v1, v2) + 1) / 2.0 77 | -------------------------------------------------------------------------------- /baseline/3_barlow.py: -------------------------------------------------------------------------------- 1 | # add python path 2 | import sys 3 | 4 | sys.path.append("/".join(sys.path[0].split("/")[:-1])) 5 | 6 | from configs import config 7 | from utils import my_parser, seed_util, wb_util, model_util 8 | from utils.eval_shortcut import Cut 9 | from models import my_model 10 | import torch 11 | from loaders import voxceleb_loader_for_cae 12 | from utils import 
barlow_loss 13 | import os 14 | 15 | 16 | def do_step(epoch, step, data): 17 | optimizer.zero_grad() 18 | data = [i.cuda() for i in data] 19 | voice_data, face_data = data 20 | v_emb, f_emb = model(voice_data, face_data) 21 | loss = fun_barlow(v_emb, f_emb) 22 | loss.backward() 23 | optimizer.step() 24 | return loss.item(), {} 25 | 26 | 27 | def train(): 28 | step = 0 29 | model.train() 30 | 31 | for epo in range(args.epoch): 32 | wb_util.log({"train/epoch": epo}) 33 | for data in train_iter: 34 | loss, info = do_step(epo, step, data) 35 | step += 1 36 | if step % 50 == 0: 37 | obj = { 38 | "train/step": step, 39 | "train/loss": loss, 40 | } 41 | obj = {**obj, **info} 42 | print(obj) 43 | wb_util.log(obj) 44 | 45 | if step > 0 and step % args.eval_step == 0: 46 | if eval_cut.eval_short_cut(): 47 | return 48 | 49 | 50 | if __name__ == "__main__": 51 | parser = my_parser.MyParser(epoch=100, batch_size=256, seed=4, model_save_folder=config.model_save_folder, early_stop=10) 52 | parser.custom({ 53 | "batch_per_epoch": 500, 54 | "eval_step": 250, 55 | "load_model": "" 56 | }) 57 | parser.use_wb("sl_project", "barlow") 58 | args = parser.parse() 59 | seed_util.set_seed(args.seed) 60 | 61 | # we found no-shared setting could have better performance 62 | model = my_model.Encoder(shared=False).cuda() 63 | if args.load_model is not None and os.path.exists(args.load_model): 64 | model_util.load_model(args.load_model, model, strict=True) 65 | 66 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 67 | from configs.config import face_emb_dict, voice_emb_dict, emb_eva 68 | 69 | train_iter = voxceleb_loader_for_cae.get_iter(args.batch_size, args.batch_per_epoch * args.batch_size, 70 | face_emb_dict, voice_emb_dict) 71 | eval_cut = Cut(emb_eva, model, args) 72 | fun_barlow = barlow_loss.BarlowTwinsLoss() 73 | train() 74 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Code for *Self-Lifting: A Novel Framework For Unsupervised Voice-Face Association Learning,ICMR,2022* 2 | 3 | 4 | 5 | ## Requirements 6 | 7 | ``` 8 | faiss==1.7.1 9 | pytorch==1.8.1 10 | pytorch-metric-learning==0.9.96 11 | wandb==0.12.10 12 | ``` 13 | 14 | 15 | 16 | ## Dataset 17 | 18 | Download file from [Baidu Disk](https://pan.baidu.com/s/1yCvVOytilWYHdG4dHYHnEw) (code:`9d0a`) or [GoogleDrive](https://drive.google.com/file/d/1NZLfYrvqoa7XGJxITYE0v0SRYN-m33hv/view?usp=sharing) and unzip it to the project root. 19 | The `dataset` folder structure is shown below: 20 | 21 | ``` 22 | dataset/ 23 | └── voxceleb 24 | ├── cluster 25 | │   ├── movie2jpg_path.pkl 26 | │   ├── movie2wav_path.pkl 27 | │   └── train_movie_list.pkl 28 | ├── eval 29 | │   ├── test_matching_10.pkl 30 | │   ├── test_matching_g.pkl 31 | │   ├── test_matching.pkl 32 | │   ├── test_retrieval.pkl 33 | │   ├── test_verification.pkl 34 | │   ├── test_verification_g.pkl 35 | │   └── valid_verification.pkl 36 | ├── face_input.pkl 37 | └── voice_input.pkl 38 | ``` 39 | 40 | 41 | 42 | # Train 43 | 44 | **1. Train Self-Lifting Framework:** 45 | 46 | ``python sl.py`` 47 | 48 | 49 | 50 | **2. Train a baseline:** 51 | 52 | ``python baseline/1_ccae.py`` 53 | 54 | ``python baseline/2_deepcluster.py`` 55 | 56 | ``python baseline/3_barlow.py`` 57 | 58 | 59 | 60 | --- 61 | 62 | *use [wandb](https://wandb.ai) to view the training process:* 63 | 64 | 1. 
Create `wb_config.json` file in the `./configs` folder, using the following content: 65 | 66 | ``` 67 | { 68 | "WB_KEY": "Your wandb auth key" 69 | } 70 | ``` 71 | 72 | 73 | 74 | 2. add `--dryrun=False` to the training command, for example: `python sl.py --dryrun=False` 75 | 76 | 77 | 78 | ## Model Checkpoints 79 | 80 | You can get the final model checkpoints at [here](https://pan.baidu.com/s/1Ol0FtaXUm8BticDDNLJaxg) (code:`4ae6`) 81 | 82 | 83 | 84 | ## Backbone Models 85 | 86 | The Inception-V1 model is based on [facenet_pytorch](https://github.com/timesler/facenet-pytorch). 87 | 88 | The ECAPA-TDNN model is based on [SpeechBrain](https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb). 89 | While this model is trained with Vox1+Vox2, 90 | thus we retrained one only with Vox2. 91 | The checkpoint can be found at [baidu disk](https://pan.baidu.com/s/18vDu8_XxxuplW-k6i4xHZQ?pwd=fdra) or [google-drive](https://drive.google.com/file/d/1SynmHLSva8mkaVlDlSQnMVqIOLSYy36d/view?usp=sharing). 92 | 93 | We also offer demo scripts for extracting the embeddings in `scripts/`. 94 | 95 | -------------------------------------------------------------------------------- /utils/model_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import json 4 | import sys 5 | from utils import pickle_util 6 | 7 | history_array = [] 8 | 9 | 10 | def save_model(epoch, model, optimizer, file_save_path): 11 | dirpath = os.path.abspath(os.path.join(file_save_path, os.pardir)) 12 | if not os.path.exists(dirpath): 13 | print("mkdir:", dirpath) 14 | os.makedirs(dirpath) 15 | 16 | opti = None 17 | if optimizer is not None: 18 | opti = optimizer.state_dict() 19 | 20 | torch.save(obj={ 21 | 'epoch': epoch, 22 | 'model': model.state_dict(), 23 | 'optimizer': opti, 24 | }, f=file_save_path) 25 | 26 | history_array.append(file_save_path) 27 | 28 | 29 | def delete_last_saved_model(): 30 | if len(history_array) == 0: 31 | return 32 | last_path = history_array.pop() 33 | if os.path.exists(last_path): 34 | os.remove(last_path) 35 | print("delete model:", last_path) 36 | 37 | if os.path.exists(last_path + ".json"): 38 | os.remove(last_path + ".json") 39 | 40 | 41 | def load_model(resume_path, model, optimizer=None, strict=True): 42 | checkpoint = torch.load(resume_path) 43 | start_epoch = checkpoint['epoch'] + 1 44 | model.load_state_dict(checkpoint['model'], strict=strict) 45 | if optimizer is not None: 46 | optimizer.load_state_dict(checkpoint['optimizer']) 47 | print("checkpoint loaded!") 48 | return start_epoch 49 | 50 | 51 | def save_model_v2(model, args, model_save_name): 52 | model_save_path = os.path.join(args.model_save_folder, args.project, args.name, model_save_name) 53 | save_model(0, model, None, model_save_path) 54 | print("save:", model_save_path) 55 | 56 | 57 | def save_project_info(args): 58 | run_info = { 59 | "cmd_str": ' '.join(sys.argv[1:]), 60 | "args": vars(args), 61 | } 62 | 63 | name = "run_info.json" 64 | folder = os.path.join(args.model_save_folder, args.project, args.name) 65 | if not os.path.exists(folder): 66 | os.makedirs(folder) 67 | 68 | json_file_path = os.path.join(folder, name) 69 | with open(json_file_path, "w") as f: 70 | json.dump(run_info, f) 71 | 72 | print("save_project_info:", json_file_path) 73 | 74 | 75 | def get_pkl_json(folder): 76 | names = [i for i in os.listdir(folder) if ".pkl.json" in i] 77 | assert len(names) == 1 78 | json_path = os.path.join(folder, names[0]) 79 | obj = pickle_util.read_json(json_path) 80 | 
return obj 81 | -------------------------------------------------------------------------------- /utils/my_parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import ast 3 | 4 | 5 | class MyParser(): 6 | 7 | def __init__(self, epoch, batch_size, worker=0, seed=2526, 8 | max_hour=100, early_stop=5, lr=1e-3, 9 | model_save_folder=None): 10 | super(MyParser, self).__init__() 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--seed", default=seed, type=int) 13 | parser.add_argument("--worker", default=worker, type=int) 14 | parser.add_argument("--epoch", default=epoch, type=int) 15 | parser.add_argument("--batch_size", default=batch_size, type=int) 16 | parser.add_argument("--max_hour", default=max_hour, type=int) 17 | parser.add_argument("--early_stop", default=early_stop, type=int) 18 | parser.add_argument("--lr", default=lr, type=float) 19 | parser.add_argument("--model_save_folder", default=model_save_folder, type=str) 20 | self.core_parser = parser 21 | 22 | def use_wb(self, project, name, dryrun=True): 23 | self.project = project 24 | self.name = name 25 | self.dryrun = dryrun 26 | parser = self.core_parser 27 | parser.add_argument("--project", default=self.project, type=str) 28 | parser.add_argument("--name", default=self.name, type=str) 29 | parser.add_argument("--dryrun", default=self.dryrun, type=ast.literal_eval) 30 | 31 | def custom(self, the_dict): 32 | parser = self.core_parser 33 | for key in the_dict: 34 | value = the_dict[key] 35 | if type(value) == str or value is None: 36 | parser.add_argument("--" + key, default=value, type=str) 37 | elif type(value) == int: 38 | parser.add_argument("--" + key, default=value, type=int) 39 | elif type(value) == float: 40 | parser.add_argument("--" + key, default=value, type=float) 41 | elif type(value) == bool: 42 | parser.add_argument("--" + key, default=value, type=ast.literal_eval) 43 | else: 44 | raise Exception("unsupported type:" + type(value)) 45 | 46 | def parse(self): 47 | args = parse_it(self.core_parser) 48 | return args 49 | 50 | def show(self): 51 | the_dic = vars(self.parse()) 52 | keys = list(the_dic.keys()) 53 | keys.sort() 54 | for key in keys: 55 | print(key, ":", the_dic[key]) 56 | 57 | 58 | def parse_it(parser): 59 | args = parser.parse_args() 60 | return args 61 | 62 | 63 | if __name__ == "__main__": 64 | pass 65 | -------------------------------------------------------------------------------- /baseline/1_ccae.py: -------------------------------------------------------------------------------- 1 | # add python path 2 | import sys 3 | sys.path.append("/".join(sys.path[0].split("/")[:-1])) 4 | 5 | 6 | from utils import my_parser, seed_util, wb_util, model_util 7 | from utils.eval_shortcut import Cut 8 | from models import my_model 9 | import torch 10 | from loaders import voxceleb_loader_for_cae 11 | from configs.config import face_emb_dict, voice_emb_dict, emb_eva, model_save_folder 12 | import os 13 | 14 | 15 | def do_step(epoch, step, data): 16 | optimizer.zero_grad() 17 | data = [i.cuda() for i in data] 18 | voice_data, face_data = data 19 | loss_emb, loss_dec = model(voice_data, face_data) 20 | loss = loss_emb + loss_dec 21 | loss.backward() 22 | optimizer.step() 23 | return loss.item(), {} 24 | 25 | 26 | def train(): 27 | step = 0 28 | model.train() 29 | 30 | for epo in range(args.epoch): 31 | wb_util.log({"train/epoch": epo}) 32 | for data in train_iter: 33 | loss, info = do_step(epo, step, data) 34 | step += 1 35 | if step % 50 == 0: 
36 | obj = { 37 | "train/step": step, 38 | "train/loss": loss, 39 | } 40 | obj = {**obj, **info} 41 | print(obj) 42 | wb_util.log(obj) 43 | 44 | if step > 0 and step % args.eval_step == 0: 45 | if eval_cut.eval_short_cut(): 46 | return 47 | 48 | 49 | if __name__ == "__main__": 50 | parser = my_parser.MyParser(epoch=100, batch_size=256, model_save_folder=model_save_folder, early_stop=10) 51 | parser.custom({ 52 | "batch_per_epoch": 500, 53 | "eval_step": 250, 54 | "load_model": "load_model", 55 | }) 56 | parser.use_wb("sl_project", "CCAE") 57 | args = parser.parse() 58 | seed_util.set_seed(args.seed) 59 | 60 | model = my_model.CAE().cuda() 61 | 62 | if args.load_model is not None and os.path.exists(args.load_model): 63 | tmp_model = my_model.Encoder(shared=True).cuda() 64 | model_util.load_model(args.load_model, tmp_model, strict=True) 65 | model.encoder = tmp_model 66 | model.face_encoder = tmp_model.face_encoder 67 | model.voice_encoder = tmp_model.voice_encoder 68 | 69 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 70 | 71 | train_iter = voxceleb_loader_for_cae.get_iter(args.batch_size, args.batch_per_epoch * args.batch_size, 72 | face_emb_dict, voice_emb_dict) 73 | eval_cut = Cut(emb_eva, model, args) 74 | train() 75 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | dataset/* 2 | outputs/* 3 | jupyters/ 4 | results/ 5 | configs/wb*.json 6 | 临时说明.txt 7 | 8 | .DS_Store 9 | 10 | #idea 11 | .idea 12 | wandb/ 13 | 14 | 15 | # Byte-compiled / optimized / DLL files 16 | __pycache__/ 17 | *.py[cod] 18 | *$py.class 19 | 20 | # C extensions 21 | *.so 22 | 23 | # Distribution / packaging 24 | .Python 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | .eggs/ 31 | lib/ 32 | lib64/ 33 | parts/ 34 | sdist/ 35 | var/ 36 | wheels/ 37 | share/python-wheels/ 38 | *.egg-info/ 39 | .installed.cfg 40 | *.egg 41 | MANIFEST 42 | 43 | # PyInstaller 44 | # Usually these files are written by a python script from a template 45 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 46 | *.manifest 47 | *.spec 48 | 49 | # Installer logs 50 | pip-log.txt 51 | pip-delete-this-directory.txt 52 | 53 | # Unit test / coverage reports 54 | htmlcov/ 55 | .tox/ 56 | .nox/ 57 | .coverage 58 | .coverage.* 59 | .cache 60 | nosetests.xml 61 | coverage.xml 62 | *.cover 63 | *.py,cover 64 | .hypothesis/ 65 | .pytest_cache/ 66 | cover/ 67 | 68 | # Translations 69 | *.mo 70 | *.pot 71 | 72 | # Django stuff: 73 | *.log 74 | local_settings.py 75 | db.sqlite3 76 | db.sqlite3-journal 77 | 78 | # Flask stuff: 79 | instance/ 80 | .webassets-cache 81 | 82 | # Scrapy stuff: 83 | .scrapy 84 | 85 | # Sphinx documentation 86 | docs/_build/ 87 | 88 | # PyBuilder 89 | .pybuilder/ 90 | target/ 91 | 92 | # Jupyter Notebook 93 | .ipynb_checkpoints 94 | 95 | # IPython 96 | profile_default/ 97 | ipython_config.py 98 | 99 | # pyenv 100 | # For a library or package, you might want to ignore these files since the code is 101 | # intended to run in multiple environments; otherwise, check them in: 102 | # .python-version 103 | 104 | # pipenv 105 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
106 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 107 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 108 | # install all needed dependencies. 109 | #Pipfile.lock 110 | 111 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 112 | __pypackages__/ 113 | 114 | # Celery stuff 115 | celerybeat-schedule 116 | celerybeat.pid 117 | 118 | # SageMath parsed files 119 | *.sage.py 120 | 121 | # Environments 122 | .env 123 | .venv 124 | env/ 125 | venv/ 126 | ENV/ 127 | env.bak/ 128 | venv.bak/ 129 | 130 | # Spyder project settings 131 | .spyderproject 132 | .spyproject 133 | 134 | # Rope project settings 135 | .ropeproject 136 | 137 | # mkdocs documentation 138 | /site 139 | 140 | # mypy 141 | .mypy_cache/ 142 | .dmypy.json 143 | dmypy.json 144 | 145 | # Pyre type checker 146 | .pyre/ 147 | 148 | # pytype static type analyzer 149 | .pytype/ 150 | 151 | # Cython debug symbols 152 | cython_debug/ -------------------------------------------------------------------------------- /scripts/2_exract_voice_emb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils import model_util, pickle_util 3 | from loaders import voice_loader 4 | import time 5 | import glob 6 | 7 | 8 | def generate_emb_dict(wav_list, batch_size=16): 9 | loader = voice_loader.get_loader(4, batch_size, wav_list) 10 | the_dict = {} 11 | counter = 0 12 | start_time = time.time() 13 | for data, lens, keys in loader: 14 | try: 15 | core_step(data, lens, model, keys, the_dict) 16 | except Exception as e: 17 | print("error:", e) 18 | continue 19 | 20 | counter += 1 21 | if counter % 10 == 0: 22 | processed = len(the_dict) 23 | progress = processed / len(loader.dataset) 24 | time_cost = time.time() - start_time 25 | total_time = time_cost / progress / 3600.0 26 | print("progress:", progress, "total_time:", total_time) 27 | return the_dict 28 | 29 | 30 | def core_step(wavs, lens, model, keys, the_dict): 31 | with torch.no_grad(): 32 | feats = fun_compute_features(wavs.cuda()) 33 | feats = fun_mean_var_norm(feats, lens) 34 | embedding = model(feats, lens) 35 | embedding_npy = embedding.detach().cpu().numpy().squeeze() 36 | # (batch,192) 37 | for key, emb in zip(keys, embedding_npy): 38 | the_dict[key] = emb 39 | 40 | 41 | def get_ecapa_model(): 42 | from speechbrain.lobes.models.ECAPA_TDNN import ECAPA_TDNN 43 | n_mels = 80 44 | channels = [1024, 1024, 1024, 1024, 3072] 45 | kernel_sizes = [5, 3, 3, 3, 1] 46 | dilations = [1, 2, 3, 4, 1] 47 | attention_channels = 128 48 | lin_neurons = 192 49 | model = ECAPA_TDNN(input_size=n_mels, channels=channels, 50 | kernel_sizes=kernel_sizes, dilations=dilations, 51 | attention_channels=attention_channels, 52 | lin_neurons=lin_neurons 53 | ) 54 | # print(model) 55 | return model 56 | 57 | 58 | def get_fun_compute_features(): 59 | from speechbrain.lobes.features import Fbank 60 | 61 | n_mels = 80 62 | left_frames = 0 63 | right_frames = 0 64 | deltas = False 65 | compute_features = Fbank(n_mels=n_mels, left_frames=left_frames, right_frames=right_frames, deltas=deltas) 66 | return compute_features 67 | 68 | 69 | def get_fun_norm(): 70 | from speechbrain.processing.features import InputNormalization 71 | return InputNormalization(norm_type="sentence", std_norm=False) 72 | 73 | 74 | if __name__ == "__main__": 75 | # 1.get model 76 | model = get_ecapa_model().cuda() 77 | pkl_path = "ecapa_acc0.9854.pkl" 78 | model_util.load_model(pkl_path, model) 79 | 
model.eval() 80 | 81 | fun_compute_features = get_fun_compute_features().cuda() 82 | fun_mean_var_norm = get_fun_norm().cuda() 83 | 84 | # 2.get all wav files 85 | wav_list = glob.glob("/your_path/*.wav") 86 | 87 | the_dict = generate_emb_dict(wav_list) 88 | pickle_util.save_pickle("voice_emb.pkl", the_dict) 89 | -------------------------------------------------------------------------------- /baseline/2_deepcluster.py: -------------------------------------------------------------------------------- 1 | # add python path 2 | import sys 3 | sys.path.append("/".join(sys.path[0].split("/")[:-1])) 4 | from utils import my_parser, seed_util, wb_util, deepcluster_util, model_util, my_softmax_loss 5 | from utils.eval_shortcut import Cut 6 | from models import my_model 7 | from configs import config 8 | import torch 9 | from loaders import voxceleb_loader_for_deepcluster 10 | 11 | 12 | def do_step(epoch, step, data): 13 | optimizer.zero_grad() 14 | data = [i.cuda() for i in data] 15 | voice_data, face_data, label = data 16 | v_emb, f_emb = model(voice_data, face_data) 17 | emb = torch.cat([v_emb, f_emb], dim=0) 18 | label2 = torch.cat([label, label], dim=0).squeeze() 19 | loss = fun_loss_metric(emb, label2) 20 | loss.backward() 21 | optimizer.step() 22 | info = { 23 | } 24 | return loss.item(), info 25 | 26 | 27 | def train(): 28 | step = 0 29 | model.train() 30 | 31 | for epo in range(args.epoch): 32 | wb_util.log({"train/epoch": epo}) 33 | movie2label, _ = deepcluster_util.do_cluster(ordered_iter, args.ncentroids, model=model, input_emb_type=args.input_emb_type) 34 | train_iter = voxceleb_loader_for_deepcluster.get_iter(args.batch_size, args.batch_per_epoch * args.batch_size, 35 | face_emb_dict, voice_emb_dict, movie2label) 36 | 37 | for data in train_iter: 38 | loss, info = do_step(epo, step, data) 39 | step += 1 40 | if step % 50 == 0: 41 | obj = { 42 | "train/step": step, 43 | "train/loss": loss, 44 | } 45 | obj = {**obj, **info} 46 | print(obj) 47 | wb_util.log(obj) 48 | 49 | if step > 0 and step % args.eval_step == 0: 50 | if eval_cut.eval_short_cut(): 51 | return 52 | 53 | 54 | if __name__ == "__main__": 55 | 56 | parser = my_parser.MyParser(epoch=100, batch_size=256, model_save_folder=config.model_save_folder, early_stop=10) 57 | parser.custom({ 58 | "ncentroids": 1000, 59 | "batch_per_epoch": 250, 60 | "eval_step": 250, 61 | "input_emb_type": "all", 62 | "load_model": "", 63 | "shared": False, 64 | }) 65 | parser.use_wb("sl_project", "deepcluster") 66 | args = parser.parse() 67 | seed_util.set_seed(args.seed) 68 | 69 | # model: 70 | model = my_model.Encoder(shared=args.shared).cuda() 71 | import os 72 | 73 | if args.load_model is not None and os.path.exists(args.load_model): 74 | model_util.load_model(args.load_model, model, strict=True) 75 | 76 | fun_loss_metric = my_softmax_loss.MySoftmaxLoss(128, num_class=args.ncentroids).cuda() 77 | model_params = list(model.parameters()) + list(fun_loss_metric.parameters()) 78 | optimizer = torch.optim.Adam(model_params, lr=args.lr) 79 | 80 | # 3. 
loader 81 | from configs.config import face_emb_dict, voice_emb_dict, emb_eva 82 | 83 | eval_cut = Cut(emb_eva, model, args) 84 | ordered_iter = voxceleb_loader_for_deepcluster.get_ordered_iter(args.batch_size, face_emb_dict, voice_emb_dict) 85 | 86 | train() 87 | -------------------------------------------------------------------------------- /loaders/voxceleb_loader_for_deepcluster.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from utils import pickle_util, sample_util, worker_util, vec_util 4 | from torch.utils.data import DataLoader 5 | 6 | 7 | def get_iter(batch_size, full_length, name2face_emb, name2voice_emb, movie2label): 8 | train_iter = DataLoader(DataSet(name2face_emb, name2voice_emb, full_length, movie2label), 9 | batch_size=batch_size, shuffle=False, pin_memory=True, worker_init_fn=worker_util.worker_init_fn) 10 | return train_iter 11 | 12 | 13 | class DataSet(torch.utils.data.Dataset): 14 | 15 | def __init__(self, name2face_emb, name2voice_emb, full_length, movie2label): 16 | self.train_movie_list = list(movie2label.keys()) 17 | 18 | self.movie2wav_path = pickle_util.read_pickle("./dataset/voxceleb/cluster/movie2wav_path.pkl") 19 | self.movie2jpg_path = pickle_util.read_pickle("./dataset/voxceleb/cluster/movie2jpg_path.pkl") 20 | 21 | self.full_length = full_length 22 | self.name2face_emb = name2face_emb 23 | self.name2voice_emb = name2voice_emb 24 | self.movie2label = movie2label 25 | 26 | def __len__(self): 27 | return self.full_length 28 | 29 | def __getitem__(self, index): 30 | movie = sample_util.random_element(self.train_movie_list) 31 | label = self.movie2label[movie] 32 | 33 | img = sample_util.random_element(self.movie2jpg_path[movie]) 34 | wav = sample_util.random_element(self.movie2wav_path[movie]) 35 | wav, img = self.to_tensor([wav, img]) 36 | 37 | return wav, img, torch.LongTensor([label]) 38 | 39 | def to_tensor(self, path_arr): 40 | ans = [] 41 | for path in path_arr: 42 | if ".wav" in path: 43 | emb = self.name2voice_emb[path] 44 | else: 45 | emb = self.name2face_emb[path] 46 | emb = torch.FloatTensor(emb) 47 | ans.append(emb) 48 | return ans 49 | 50 | 51 | class OredredDataSet(torch.utils.data.Dataset): 52 | 53 | def __init__(self, name2face_emb, name2voice_emb): 54 | self.train_movie_list = pickle_util.read_pickle("./dataset/voxceleb/cluster/train_movie_list.pkl") 55 | 56 | self.movie2wav_path = pickle_util.read_pickle("./dataset/voxceleb/cluster/movie2wav_path.pkl") 57 | self.movie2jpg_path = pickle_util.read_pickle("./dataset/voxceleb/cluster/movie2jpg_path.pkl") 58 | 59 | self.name2face_emb = name2face_emb 60 | self.name2voice_emb = name2voice_emb 61 | 62 | def __len__(self): 63 | return len(self.train_movie_list) 64 | 65 | def __getitem__(self, index): 66 | movie = self.train_movie_list[index] 67 | img = np.mean([self.name2face_emb[i] for i in self.movie2jpg_path[movie]], axis=0) 68 | wav = np.mean([self.name2voice_emb[i] for i in self.movie2wav_path[movie]], axis=0) 69 | 70 | img = vec_util.to_unit_vector(img) 71 | wav = vec_util.to_unit_vector(wav) 72 | return torch.FloatTensor(wav), torch.FloatTensor(img), torch.LongTensor([index]) 73 | 74 | 75 | def get_ordered_iter(batch_size, name2face_emb, name2voice_emb): 76 | train_iter = DataLoader(OredredDataSet(name2face_emb, name2voice_emb), 77 | batch_size=batch_size, shuffle=False, 78 | pin_memory=True, worker_init_fn=worker_util.worker_init_fn) 79 | return train_iter 80 | 
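# How the two entry points above are combined (illustrative sketch; the actual
# wiring lives in baseline/2_deepcluster.py together with utils/deepcluster_util.py):
#
#   ordered_iter = get_ordered_iter(256, face_emb_dict, voice_emb_dict)
#   movie2label, _ = deepcluster_util.do_cluster(ordered_iter, 1000, model=model,
#                                                input_emb_type="all")
#   train_iter = get_iter(256, 250 * 256, face_emb_dict, voice_emb_dict, movie2label)
#
# get_ordered_iter visits every training movie once (unit-normalized mean face and
# voice embedding per movie) so that k-means can assign one pseudo-label per movie;
# get_iter then samples random (wav, jpg) embedding pairs from those movies and
# returns the movie's pseudo-label as the training target.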
-------------------------------------------------------------------------------- /loaders/voxceleb_cluster_ordered_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from utils import pickle_util, worker_util 4 | from utils.path_util import look_up 5 | 6 | from torch.utils.data import DataLoader 7 | import collections 8 | 9 | 10 | def extract_embeddings(name2face_emb, name2voice_emb, model): 11 | face_iter = get_ordered_iter(1024, name2face_emb, name2voice_emb, is_face=True) 12 | movies, emb_face = extract_embeddings_core(face_iter, model.face_encoder) 13 | 14 | voice_iter = get_ordered_iter(1024, name2face_emb, name2voice_emb, is_face=False) 15 | movies2, emb_voice = extract_embeddings_core(voice_iter, model.voice_encoder) 16 | 17 | assert len(movies2) == len(movies) 18 | final_emb = np.hstack([emb_voice, emb_face]) 19 | return movies, final_emb, emb_voice, emb_face 20 | 21 | 22 | def extract_embeddings_core(ordered_iter, encoder): 23 | # 1.extract embedding 24 | encoder.eval() 25 | the_dict = collections.defaultdict(list) 26 | for data in ordered_iter: 27 | with torch.no_grad(): 28 | batch_movie, tensor = data 29 | # ipdb.set_trace() 30 | batch_emb = encoder(tensor.cuda()).detach().cpu().numpy() 31 | for emb, movie in zip(batch_emb, batch_movie): 32 | the_dict[movie].append(emb) 33 | encoder.train() 34 | 35 | # 2. merge embedding by video 36 | final_dict = {} 37 | for key, arr in the_dict.items(): 38 | # arr:[batch,emb] 39 | final_dict[key] = np.mean(arr, axis=0) 40 | 41 | # 3.sort 42 | videos = list(final_dict.keys()) 43 | videos.sort() 44 | emb_array = np.array([final_dict[key] for key in videos]) 45 | 46 | return videos, emb_array 47 | 48 | 49 | def get_ordered_iter(batch_size, name2face_emb, name2voice_emb, is_face): 50 | train_iter = DataLoader(OredredDataSet(is_face, name2face_emb, name2voice_emb), 51 | batch_size=batch_size, shuffle=False, 52 | pin_memory=True, worker_init_fn=worker_util.worker_init_fn) 53 | return train_iter 54 | 55 | 56 | class OredredDataSet(torch.utils.data.Dataset): 57 | 58 | def __init__(self, is_face, name2face_emb, name2voice_emb): 59 | train_movie_list = pickle_util.read_pickle(look_up("./dataset/voxceleb/cluster/train_movie_list.pkl")) 60 | 61 | movie2wav_path = pickle_util.read_pickle(look_up("./dataset/voxceleb/cluster/movie2wav_path.pkl")) 62 | # ['id10001/Y8hIVOBuels/00008.wav',.... 63 | 64 | movie2jpg_path = pickle_util.read_pickle(look_up("./dataset/voxceleb/cluster/movie2jpg_path.pkl")) 65 | # ['A.J._Buckley/1.6/Y8hIVOBuels/0005175.jpg',.... 
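# build flat (movie, short_path) index lists so every jpg / wav embedding is
# visited exactly once and can later be averaged back per movie
# (the "数据" in the next comment simply means "data")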
66 | 67 | # 3.数据 68 | all_jpgs = [] 69 | all_wavs = [] 70 | for movie in train_movie_list: 71 | for short_path in movie2jpg_path[movie]: 72 | all_jpgs.append([movie, short_path]) 73 | 74 | for short_path in movie2wav_path[movie]: 75 | all_wavs.append([movie, short_path]) 76 | 77 | if is_face: 78 | self.data = all_jpgs 79 | self.name2emb = name2face_emb 80 | else: 81 | self.data = all_wavs 82 | self.name2emb = name2voice_emb 83 | 84 | def __len__(self): 85 | return len(self.data) 86 | 87 | def __getitem__(self, index): 88 | movie, short_path = self.data[index] 89 | tensor = torch.FloatTensor(self.name2emb[short_path]) 90 | return movie, tensor 91 | -------------------------------------------------------------------------------- /models/my_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Encoder(torch.nn.Module): 5 | def __init__(self, voice_size=192, face_size=512, embedding_size=128, shared=True): 6 | super(Encoder, self).__init__() 7 | # input->drop-fc256-relu-[fc256-relu-fc128] 8 | mid_dim = 256 9 | 10 | def create_front(input_size): 11 | return torch.nn.Sequential( 12 | torch.nn.Dropout(), 13 | torch.nn.Linear(input_size, mid_dim), 14 | torch.nn.ReLU(), 15 | ) 16 | 17 | def create_rare(): 18 | return torch.nn.Sequential( 19 | torch.nn.Linear(mid_dim, mid_dim), 20 | torch.nn.ReLU(), 21 | torch.nn.Linear(mid_dim, embedding_size), 22 | ) 23 | 24 | face_rare = create_rare() 25 | if shared: 26 | voice_rare = face_rare 27 | else: 28 | voice_rare = create_rare() 29 | 30 | self.face_encoder = torch.nn.Sequential( 31 | create_front(face_size), 32 | face_rare 33 | ) 34 | self.voice_encoder = torch.nn.Sequential( 35 | create_front(voice_size), 36 | voice_rare 37 | ) 38 | 39 | def forward(self, voice_data, face_data): 40 | v_emb = self.voice_encoder(voice_data) 41 | f_emb = self.face_encoder(face_data) 42 | return v_emb, f_emb 43 | 44 | 45 | class Decoder(torch.nn.Module): 46 | def __init__(self, voice_size=192, face_size=512, embedding_size=128, shared=True): 47 | super(Decoder, self).__init__() 48 | # 128->Drop-fc256-Relu-fc256-Relu-xxx 49 | 50 | mid_dim = 256 51 | 52 | def create_rare(): 53 | return torch.nn.Sequential( 54 | torch.nn.Dropout(), 55 | torch.nn.Linear(embedding_size, mid_dim), 56 | torch.nn.ReLU(), 57 | torch.nn.Linear(mid_dim, mid_dim), 58 | torch.nn.ReLU(), 59 | ) 60 | 61 | face_rare = create_rare() 62 | if shared: 63 | voice_rare = face_rare 64 | else: 65 | voice_rare = create_rare() 66 | 67 | self.dec_face = torch.nn.Sequential( 68 | face_rare, 69 | torch.nn.Linear(mid_dim, voice_size), 70 | ) 71 | self.dec_voice = torch.nn.Sequential( 72 | voice_rare, 73 | torch.nn.Linear(mid_dim, face_size) 74 | ) 75 | 76 | def forward(self, v_emb, f_emb): 77 | f_out = self.dec_voice(v_emb) 78 | v_out = self.dec_face(f_emb) 79 | return v_out, f_out 80 | 81 | 82 | class CAE(torch.nn.Module): 83 | def __init__(self): 84 | super(CAE, self).__init__() 85 | self.encoder = Encoder(shared=True) 86 | self.face_encoder = self.encoder.face_encoder 87 | self.voice_encoder = self.encoder.voice_encoder 88 | 89 | self.decoder = Decoder(shared=False) 90 | self.fun_loss_mse = torch.nn.MSELoss() 91 | 92 | def forward(self, voice_data, face_data, only_emb=False): 93 | v_emb = self.voice_encoder(voice_data) 94 | f_emb = self.face_encoder(face_data) 95 | if only_emb: 96 | return v_emb, f_emb 97 | v_out, f_out = self.decoder(v_emb, f_emb) 98 | 99 | fun_loss_mse = self.fun_loss_mse 100 | loss_dec = fun_loss_mse(voice_data, v_out) + 
fun_loss_mse(face_data, f_out) 101 | loss_emb = fun_loss_mse(v_emb, f_emb) 102 | return loss_emb, loss_dec 103 | -------------------------------------------------------------------------------- /sl.py: -------------------------------------------------------------------------------- 1 | from utils import my_parser, seed_util, wb_util, deepcluster_util 2 | from utils.eval_shortcut import Cut 3 | from models import my_model 4 | import torch 5 | from loaders import voxceleb_cluster_ordered_loader 6 | from loaders import voxceleb_loader_for_deepcluster 7 | from pytorch_metric_learning import losses 8 | from utils import model_util 9 | from configs.config import face_emb_dict, voice_emb_dict, emb_eva, model_save_folder 10 | import os 11 | 12 | 13 | def do_step(epoch, step, data): 14 | optimizer.zero_grad() 15 | data = [i.cuda() for i in data] 16 | voice_data, face_data, label = data 17 | v_emb, f_emb = model(voice_data, face_data) 18 | emb = torch.cat([v_emb, f_emb], dim=0) 19 | label2 = torch.cat([label, label], dim=0).squeeze() 20 | 21 | if args.ratio_mse > 0: 22 | loss_mse = fun_loss_mse(v_emb, f_emb) * args.ratio_mse 23 | else: 24 | loss_mse = 0 25 | 26 | loss = fun_loss_metric(emb, label2) + loss_mse 27 | loss.backward() 28 | optimizer.step() 29 | info = { 30 | } 31 | return loss.item(), info 32 | 33 | 34 | def get_ratio(loss, total_loss): 35 | if type(loss) == torch.Tensor: 36 | loss = loss.item() 37 | return loss / total_loss.item() 38 | 39 | 40 | def train(): 41 | step = 0 42 | model.train() 43 | 44 | for epo in range(args.epoch): 45 | wb_util.log({"train/epoch": epo}) 46 | # do cluster 47 | all_keys, all_emb, all_emb_v, all_emb_f = voxceleb_cluster_ordered_loader.extract_embeddings(face_emb_dict, voice_emb_dict, model) 48 | movie2label, _ = deepcluster_util.do_cluster_v2(all_keys, all_emb, all_emb_v, all_emb_f, args.ncentroids, input_emb_type=args.cluster_type) 49 | # create dataset 50 | train_iter = voxceleb_loader_for_deepcluster.get_iter(args.batch_size, 51 | args.batch_per_epoch * args.batch_size, 52 | face_emb_dict, 53 | voice_emb_dict, 54 | movie2label) 55 | 56 | for data in train_iter: 57 | loss, info = do_step(epo, step, data) 58 | step += 1 59 | if step % 50 == 0: 60 | obj = { 61 | "train/step": step, 62 | "train/loss": loss, 63 | } 64 | obj = {**obj, **info} 65 | print(obj) 66 | wb_util.log(obj) 67 | 68 | if step % args.eval_step == 0: 69 | if eval_cut.eval_short_cut(): 70 | return 71 | 72 | 73 | if __name__ == "__main__": 74 | parser = my_parser.MyParser(epoch=100, batch_size=256, model_save_folder=model_save_folder, early_stop=10) 75 | parser.custom({ 76 | "ncentroids": 1000, 77 | "batch_per_epoch": 250, 78 | "eval_step": 250, 79 | 80 | "ratio_mse": 0.0, 81 | 82 | "mts_alpha": 2.0, 83 | "mts_beta": 50.0, 84 | "mts_base": 1.0, 85 | 86 | "load_model": "", 87 | 88 | "cluster_type": "all", 89 | }) 90 | parser.use_wb("sl_project", "SL") 91 | args = parser.parse() 92 | seed_util.set_seed(args.seed) 93 | assert args.cluster_type in ["v", "f", "all"] 94 | 95 | # 1.model: 96 | model = my_model.Encoder().cuda() 97 | if args.load_model is not None and os.path.exists(args.load_model): 98 | model_util.load_model(args.load_model, model, strict=True) 99 | 100 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 101 | 102 | # 2. 
loader 103 | eval_cut = Cut(emb_eva, model, args) 104 | 105 | # 3.loss 106 | fun_loss_metric = losses.MultiSimilarityLoss(alpha=args.mts_alpha, beta=args.mts_beta, base=args.mts_base) 107 | 108 | fun_loss_mse = torch.nn.MSELoss() 109 | 110 | train() 111 | -------------------------------------------------------------------------------- /utils/deepcluster_util.py: -------------------------------------------------------------------------------- 1 | import mkl 2 | import collections 3 | 4 | mkl.get_max_threads() 5 | import faiss 6 | from utils import wb_util, distance_util 7 | import numpy as np 8 | import torch 9 | 10 | 11 | def do_k_means(matrix, ncentroids): 12 | niter = 20 13 | verbose = True 14 | d = matrix.shape[1] 15 | 16 | kmeans = faiss.Kmeans(d, 17 | ncentroids, 18 | niter=niter, 19 | verbose=verbose, 20 | spherical=False, 21 | min_points_per_centroid=3, 22 | max_points_per_centroid=100000, 23 | gpu=False, 24 | ) 25 | 26 | kmeans.train(matrix) 27 | 28 | D, I = kmeans.index.search(matrix, 1) 29 | 30 | cluster_label = I.squeeze() 31 | similarity_array = [] 32 | for i in range(len(matrix)): 33 | sample_vec = matrix[i] 34 | sample_label = I[i][0] 35 | center_vec = kmeans.centroids[sample_label] 36 | similarity = distance_util.cosine_similarity(sample_vec, center_vec) 37 | similarity_array.append(similarity) 38 | similarity_array = np.array(similarity_array) 39 | 40 | sorted_similarity_array = similarity_array.copy() 41 | sorted_similarity_array.sort() 42 | 43 | return cluster_label, similarity_array 44 | 45 | 46 | def get_center_matrix(v_emb, f_emb, cluster_label, ncentroids): 47 | tmp_dict = collections.defaultdict(list) 48 | 49 | for v, f, label in zip(v_emb, f_emb, cluster_label): 50 | tmp_dict[label].append(v) 51 | tmp_dict[label].append(f) 52 | 53 | tmp_arr = [] 54 | for i in range(ncentroids): 55 | vec = np.mean(tmp_dict[i], axis=0) 56 | tmp_arr.append(vec) 57 | 58 | center_matrix = np.array(tmp_arr) 59 | return center_matrix 60 | 61 | 62 | def extract_embeddings(ordered_iter, model): 63 | model.eval() 64 | all_emb = [] 65 | all_emb_v = [] 66 | all_emb_f = [] 67 | all_keys = [] 68 | for data in ordered_iter: 69 | with torch.no_grad(): 70 | data = [i.cuda() for i in data] 71 | voice_data, face_data, label = data 72 | v_emb, f_emb = model(voice_data, face_data) 73 | # [v-f] 74 | the_emb = torch.cat([v_emb, f_emb], dim=1).detach().cpu().numpy() 75 | label_npy = label.squeeze().detach().cpu().numpy().tolist() 76 | 77 | for emb, label_int in zip(the_emb, label_npy): 78 | all_emb.append(emb) 79 | all_emb_v.append(emb[0:128]) 80 | all_emb_f.append(emb[128:]) 81 | all_keys.append(ordered_iter.dataset.train_movie_list[label_int]) 82 | model.train() 83 | return all_keys, all_emb, all_emb_v, all_emb_f 84 | 85 | 86 | def do_cluster(ordered_iter, ncentroids, model, input_emb_type="all"): 87 | all_keys, all_emb, all_emb_v, all_emb_f = extract_embeddings(ordered_iter, model) 88 | 89 | if input_emb_type == "v": 90 | input_emb = np.array(all_emb_v) 91 | elif input_emb_type == "f": 92 | input_emb = np.array(all_emb_f) 93 | elif input_emb_type == "all": 94 | input_emb = np.array(all_emb) 95 | else: 96 | raise Exception("wrong type") 97 | 98 | cluster_label, similarity_array = do_k_means(input_emb, ncentroids) 99 | 100 | movie2label = {} 101 | for label, key, sim in zip(cluster_label, all_keys, similarity_array): 102 | movie2label[key] = label 103 | 104 | center_vector = get_center_matrix(all_emb_v, all_emb_f, cluster_label, ncentroids) 105 | return movie2label, center_vector 106 | 107 | 108 | def 
do_cluster_v2(all_keys, all_emb, all_emb_v, all_emb_f, ncentroids, input_emb_type="all"): 109 | if input_emb_type == "v": 110 | input_emb = np.array(all_emb_v) 111 | elif input_emb_type == "f": 112 | input_emb = np.array(all_emb_f) 113 | elif input_emb_type == "all": 114 | input_emb = np.array(all_emb) 115 | else: 116 | raise Exception("wrong type") 117 | 118 | cluster_label, similarity_array = do_k_means(input_emb, ncentroids) 119 | 120 | movie2label = {} 121 | for label, key, sim in zip(cluster_label, all_keys, similarity_array): 122 | movie2label[key] = label 123 | 124 | center_vector = get_center_matrix(all_emb_v, all_emb_f, cluster_label, ncentroids) 125 | return movie2label, center_vector 126 | -------------------------------------------------------------------------------- /models/incep.py: -------------------------------------------------------------------------------- 1 | from facenet_pytorch.models.inception_resnet_v1 import * 2 | 3 | 4 | class InceptionResnetV1(nn.Module): 5 | """Inception Resnet V1 model with optional loading of pretrained weights. 6 | 7 | Model parameters can be loaded based on pretraining on the VGGFace2 or CASIA-Webface 8 | datasets. Pretrained state_dicts are automatically downloaded on model instantiation if 9 | requested and cached in the torch cache. Subsequent instantiations use the cache rather than 10 | redownloading. 11 | 12 | Keyword Arguments: 13 | pretrained {str} -- Optional pretraining dataset. Either 'vggface2' or 'casia-webface'. 14 | (default: {None}) 15 | classify {bool} -- Whether the model should output classification probabilities or feature 16 | embeddings. (default: {False}) 17 | num_classes {int} -- Number of output classes. If 'pretrained' is set and num_classes not 18 | equal to that used for the pretrained model, the final linear layer will be randomly 19 | initialized. (default: {None}) 20 | dropout_prob {float} -- Dropout probability. 
(default: {0.6}) 21 | """ 22 | 23 | def __init__(self, pretrained=None, classify=False, num_classes=None, dropout_prob=0.6, device=None): 24 | super().__init__() 25 | 26 | # Set simple attributes 27 | self.pretrained = pretrained 28 | self.classify = classify 29 | self.num_classes = num_classes 30 | 31 | if pretrained == 'vggface2': 32 | tmp_classes = 8631 33 | elif pretrained == 'casia-webface': 34 | tmp_classes = 10575 35 | elif pretrained is None and self.classify and self.num_classes is None: 36 | raise Exception('If "pretrained" is not specified and "classify" is True, "num_classes" must be specified') 37 | 38 | # Define layers 39 | self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2) 40 | self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1) 41 | self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1) 42 | self.maxpool_3a = nn.MaxPool2d(3, stride=2) 43 | self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1) 44 | self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1) 45 | self.conv2d_4b = BasicConv2d(192, 256, kernel_size=3, stride=2) 46 | self.repeat_1 = nn.Sequential( 47 | Block35(scale=0.17), 48 | Block35(scale=0.17), 49 | Block35(scale=0.17), 50 | Block35(scale=0.17), 51 | Block35(scale=0.17), 52 | ) 53 | self.mixed_6a = Mixed_6a() 54 | self.repeat_2 = nn.Sequential( 55 | Block17(scale=0.10), 56 | Block17(scale=0.10), 57 | Block17(scale=0.10), 58 | Block17(scale=0.10), 59 | Block17(scale=0.10), 60 | Block17(scale=0.10), 61 | Block17(scale=0.10), 62 | Block17(scale=0.10), 63 | Block17(scale=0.10), 64 | Block17(scale=0.10), 65 | ) 66 | self.mixed_7a = Mixed_7a() 67 | self.repeat_3 = nn.Sequential( 68 | Block8(scale=0.20), 69 | Block8(scale=0.20), 70 | Block8(scale=0.20), 71 | Block8(scale=0.20), 72 | Block8(scale=0.20), 73 | ) 74 | self.block8 = Block8(noReLU=True) 75 | self.avgpool_1a = nn.AdaptiveAvgPool2d(1) 76 | self.dropout = nn.Dropout(dropout_prob) 77 | self.last_linear = nn.Linear(1792, 512, bias=False) 78 | self.last_bn = nn.BatchNorm1d(512, eps=0.001, momentum=0.1, affine=True) 79 | 80 | if pretrained is not None: 81 | self.logits = nn.Linear(512, tmp_classes) 82 | load_weights(self, pretrained) 83 | 84 | if self.classify and self.num_classes is not None: 85 | self.logits = nn.Linear(512, self.num_classes) 86 | 87 | self.device = torch.device('cpu') 88 | if device is not None: 89 | self.device = device 90 | self.to(device) 91 | 92 | def forward(self, x): 93 | """Calculate embeddings or logits given a batch of input image tensors. 94 | 95 | Arguments: 96 | x {torch.tensor} -- Batch of image tensors representing faces. 97 | 98 | Returns: 99 | torch.tensor -- Batch of embedding vectors or multinomial logits. 
100 | """ 101 | x = self.conv2d_1a(x) 102 | x = self.conv2d_2a(x) 103 | x = self.conv2d_2b(x) 104 | x = self.maxpool_3a(x) 105 | x = self.conv2d_3b(x) 106 | x = self.conv2d_4a(x) 107 | x = self.conv2d_4b(x) 108 | x = self.repeat_1(x) 109 | x = self.mixed_6a(x) 110 | x = self.repeat_2(x) 111 | x = self.mixed_7a(x) 112 | x = self.repeat_3(x) 113 | x = self.block8(x) 114 | x = self.avgpool_1a(x) 115 | x = self.dropout(x) 116 | x = self.last_linear(x.view(x.shape[0], -1)) 117 | x = self.last_bn(x) 118 | # x1 = self.logits(x) 119 | # x2 = F.normalize(x, p=2, dim=1) 120 | # return x2, x1 121 | return x 122 | -------------------------------------------------------------------------------- /utils/eva_emb_full.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from utils import distance_util, path_util 4 | from utils import map_evaluate 5 | from sklearn.metrics import roc_auc_score 6 | import scipy.spatial 7 | import numpy as np 8 | import collections 9 | from utils import pickle_util 10 | 11 | 12 | class EmbEva: 13 | 14 | def __init__(self, 15 | voice_embedding_dict, 16 | face_embedding_dict, 17 | batch_size=512, 18 | ): 19 | self.voice_embedding_dict = voice_embedding_dict 20 | self.face_embedding_dict = face_embedding_dict 21 | self.batch_size = batch_size 22 | 23 | def do_valid(self, model): 24 | obj = {"valid/auc": self.do_verification(model, "./dataset/voxceleb/eval/valid_verification.pkl")} 25 | return obj 26 | 27 | def do_full_test(self, model): 28 | obj = {} 29 | # 1.verification 30 | obj["test/auc"] = self.do_verification(model, "./dataset/voxceleb/eval/test_verification.pkl") 31 | obj["test/auc_g"] = self.do_verification(model, "./dataset/voxceleb/eval/test_verification_g.pkl") 32 | 33 | # 2.retrieval 34 | obj["test/map_v2f"], obj["test/map_f2v"] = self.do_retrival(model, "./dataset/voxceleb/eval/test_retrieval.pkl") 35 | 36 | # 3.matching 37 | obj["test/ms_v2f"], obj["test/ms_f2v"] = self.do_matching(model, "./dataset/voxceleb/eval/test_matching.pkl") 38 | obj["test/ms_v2f_g"], obj["test/ms_f2v_g"] = self.do_matching(model, "./dataset/voxceleb/eval/test_matching_g.pkl") 39 | return obj 40 | 41 | def do_1_N_matching(self, model): 42 | data = pickle_util.read_pickle(path_util.look_up("./dataset/voxceleb/eval/test_matching_10.pkl")) 43 | v2emb, f2emb = self.to_emb_dict(model, data["jpg_set"], data["wav_set"]) 44 | key2emb = {**v2emb, **f2emb} 45 | ans = {} 46 | ans["v2f"] = handle_1_n(data["match_list"], is_v2f=True, key2emb=key2emb) 47 | ans["f2v"] = handle_1_n(data["match_list"], is_v2f=False, key2emb=key2emb) 48 | return ans 49 | 50 | def do_matching(self, model, pkl_path): 51 | data = pickle_util.read_pickle(pkl_path) 52 | v2emb, f2emb = self.to_emb_dict(model, data["jpg_set"], data["wav_set"]) 53 | ms_vf, ms_fv = calc_ms(data["match_list"], v2emb, f2emb) 54 | return ms_vf, ms_fv 55 | 56 | def do_verification(self, model, pkl_path): 57 | data = pickle_util.read_pickle(pkl_path) 58 | v2emb, f2emb = self.to_emb_dict(model, data["jpg_set"], data["wav_set"]) 59 | return calc_vrification(data["list"], v2emb, f2emb) 60 | 61 | def do_retrival(self, model, pkl_path): 62 | data = pickle_util.read_pickle(pkl_path) 63 | v2emb, f2emb = self.to_emb_dict(model, data["jpg_set"], data["wav_set"]) 64 | map_vf, map_fv = calc_map_value(data["retrieval_lists"], v2emb, f2emb) 65 | return map_vf, map_fv 66 | 67 | def to_emb_dict(self, model, all_jpg_set, all_wav_set): 68 | model.eval() 69 | batch_size = 
self.batch_size 70 | image_loader = DataLoader(DataSet(list(all_jpg_set), self.face_embedding_dict), batch_size=batch_size, shuffle=False, pin_memory=True) 71 | voice_loader = DataLoader(DataSet(list(all_wav_set), self.voice_embedding_dict), batch_size=batch_size, shuffle=False, pin_memory=True) 72 | f2emb = get_path2emb(image_loader.dataset.data, model.face_encoder, image_loader) 73 | v2emb = get_path2emb(voice_loader.dataset.data, model.voice_encoder, voice_loader) 74 | model.train() 75 | return v2emb, f2emb 76 | 77 | 78 | def calc_ms(all_data, v2emb, f2emb): 79 | voice1_emb = [] 80 | voice2_emb = [] 81 | face1_emb = [] 82 | face2_emb = [] 83 | 84 | for name1, voice1, face1, name2, voice2, face2 in all_data: 85 | voice1_emb.append(v2emb[voice1]) 86 | voice2_emb.append(v2emb[voice2]) 87 | face1_emb.append(f2emb[face1]) 88 | face2_emb.append(f2emb[face2]) 89 | 90 | voice1_emb = np.array(voice1_emb) 91 | voice2_emb = np.array(voice2_emb) 92 | face1_emb = np.array(face1_emb) 93 | face2_emb = np.array(face2_emb) 94 | 95 | dist_vf1 = distance_util.parallel_distance_cosine_based_distance(voice1_emb, face1_emb) 96 | dist_vf2 = distance_util.parallel_distance_cosine_based_distance(voice1_emb, face2_emb) 97 | dist_fv1 = distance_util.parallel_distance_cosine_based_distance(face1_emb, voice1_emb) 98 | dist_fv2 = distance_util.parallel_distance_cosine_based_distance(face1_emb, voice2_emb) 99 | 100 | vf_result = dist_vf1 < dist_vf2 101 | fv_result = dist_fv1 < dist_fv2 102 | ms_vf = np.mean(vf_result) 103 | ms_fv = np.mean(fv_result) 104 | 105 | obj = { 106 | "dist_vf1": dist_vf1, 107 | "dist_vf2": dist_vf2, 108 | "dist_fv1": dist_fv1, 109 | "dist_fv2": dist_fv2, 110 | "test_data": all_data, # name1, voice1, face1, name2, voice2, face2 111 | "result_fv": fv_result, 112 | "result_vf": vf_result, 113 | "score_vf": ms_vf, 114 | "score_fv": ms_fv, 115 | } 116 | return ms_vf, ms_fv 117 | 118 | 119 | def calc_map_value(retrieval_lists, v2emb, f2emb): 120 | tmp_dic = collections.defaultdict(list) 121 | for arr in retrieval_lists: 122 | map_vf, map_fv = calc_map_recall_at_k(arr, v2emb, f2emb) 123 | tmp_dic["map_vf"].append(map_vf) 124 | tmp_dic["map_fv"].append(map_fv) 125 | map_fv = np.mean(tmp_dic["map_fv"]) 126 | map_vf = np.mean(tmp_dic["map_vf"]) 127 | return map_vf, map_fv 128 | 129 | 130 | class DataSet(torch.utils.data.Dataset): 131 | 132 | def __init__(self, data, path2emb): 133 | self.data = data 134 | self.path2emb = path2emb 135 | 136 | def __len__(self): 137 | return len(self.data) 138 | 139 | def __getitem__(self, index): 140 | short_path = self.data[index] 141 | data = self.path2emb[short_path] 142 | data = torch.FloatTensor(data) 143 | return data, index 144 | 145 | 146 | def handle_1_n(match_list, is_v2f, key2emb): 147 | tmp_dict = collections.defaultdict(list) 148 | for voices, faces in match_list: 149 | if is_v2f: 150 | prob = voices[0] 151 | gallery = faces 152 | else: 153 | prob = faces[0] 154 | gallery = voices 155 | 156 | # 1. to vector 157 | prob_vec = np.array([key2emb[prob]]) 158 | gallery_vec = np.array([key2emb[i] for i in gallery]) 159 | 160 | # 2. calc similarity 161 | distances = scipy.spatial.distance.cdist(prob_vec, gallery_vec, 'cosine') 162 | distances = distances.squeeze() 163 | assert len(distances) == len(gallery_vec) 164 | 165 | # 3. 
get results of 2~N matching 166 | for index in range(2, len(gallery) + 1): 167 | arr = distances[:index] 168 | is_correct = int(np.argmin(arr) == 0) 169 | tmp_dict[index].append(is_correct) 170 | 171 | for key, arr in tmp_dict.items(): 172 | tmp_dict[key] = np.mean(arr) 173 | return tmp_dict 174 | 175 | 176 | # 177 | 178 | def get_path2emb(all_path_list, encoder, loader): 179 | f2emb = {} 180 | for data, path_indexes in loader: 181 | emb_batch = encoder(data.cuda()).detach().cpu().numpy() 182 | path_indexes = path_indexes.detach().cpu().numpy() 183 | for p_index, emb in zip(path_indexes, emb_batch): 184 | the_path = all_path_list[p_index] 185 | f2emb[the_path] = emb 186 | 187 | return f2emb 188 | 189 | 190 | def cosine_similarity(a, b): 191 | assert len(a.shape) == 2 192 | assert a.shape == b.shape 193 | 194 | ab = np.sum(a * b, axis=1) 195 | # (batch_size,) 196 | 197 | a_norm = np.sqrt(np.sum(a * a, axis=1)) 198 | b_norm = np.sqrt(np.sum(b * b, axis=1)) 199 | cosine = ab / (a_norm * b_norm) 200 | # [-1,1] 201 | prob = (cosine + 1) / 2.0 202 | return prob 203 | 204 | 205 | def calc_vrification(the_list, v2emb, f2emb): 206 | voice_emb = np.array([v2emb[tup[0]] for tup in the_list]) 207 | face_emb = np.array([f2emb[tup[1]] for tup in the_list]) 208 | real_label = np.array([tup[2] for tup in the_list]) 209 | 210 | # AUC 211 | prob = cosine_similarity(voice_emb, face_emb) 212 | auc = roc_auc_score(real_label, prob) 213 | return auc 214 | 215 | 216 | def calc_map_recall_at_k(all_data, v2emb, f2emb): 217 | # 1.get embedding 218 | labels = [] 219 | v_emb_list = [] 220 | f_emb_list = [] 221 | for v, f, name in all_data: 222 | labels.append(name) 223 | v_emb_list.append(v2emb[v]) 224 | f_emb_list.append(f2emb[f]) 225 | 226 | v_emb_list = np.array(v_emb_list) 227 | f_emb_list = np.array(f_emb_list) 228 | 229 | # 2. 
calculate distance 230 | vf_dist = scipy.spatial.distance.cdist(v_emb_list, f_emb_list, 'cosine') 231 | fv_dist = vf_dist.T 232 | 233 | # 3.map value 234 | map_vf = map_evaluate.fx_calc_map_label_v2(vf_dist, labels) 235 | map_fv = map_evaluate.fx_calc_map_label_v2(fv_dist, labels) 236 | return map_vf, map_fv 237 | 238 | 239 | def calc_ms_f2v(all_data, v2emb, f2emb): 240 | voice1_emb = [] 241 | voice2_emb = [] 242 | face1_emb = [] 243 | 244 | for face1, voice1, voice2 in all_data: 245 | voice1_emb.append(v2emb[voice1]) 246 | voice2_emb.append(v2emb[voice2]) 247 | face1_emb.append(f2emb[face1]) 248 | 249 | voice1_emb = np.array(voice1_emb) 250 | voice2_emb = np.array(voice2_emb) 251 | face1_emb = np.array(face1_emb) 252 | 253 | dist_fv1 = distance_util.parallel_distance_cosine_based_distance(face1_emb, voice1_emb) 254 | dist_fv2 = distance_util.parallel_distance_cosine_based_distance(face1_emb, voice2_emb) 255 | 256 | fv_result = dist_fv1 < dist_fv2 257 | ms_fv = np.mean(fv_result) 258 | return ms_fv 259 | 260 | 261 | def calc_ms_v2f(all_data, v2emb, f2emb): 262 | voice1_emb = [] 263 | face1_emb = [] 264 | face2_emb = [] 265 | 266 | for voice1, face1, face2 in all_data: 267 | voice1_emb.append(v2emb[voice1]) 268 | face1_emb.append(f2emb[face1]) 269 | face2_emb.append(f2emb[face2]) 270 | 271 | voice1_emb = np.array(voice1_emb) 272 | face1_emb = np.array(face1_emb) 273 | face2_emb = np.array(face2_emb) 274 | 275 | dist_vf1 = distance_util.parallel_distance_cosine_based_distance(voice1_emb, face1_emb) 276 | dist_vf2 = distance_util.parallel_distance_cosine_based_distance(voice1_emb, face2_emb) 277 | 278 | vf_result = dist_vf1 < dist_vf2 279 | ms_vf = np.mean(vf_result) 280 | return ms_vf 281 | --------------------------------------------------------------------------------
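The Encoder/Decoder/CAE classes in models/my_model.py operate on precomputed voice (192-d) and face (512-d) feature vectors and return an embedding-alignment MSE plus a cross-modal reconstruction MSE. A minimal smoke-test sketch with the default sizes; the random tensors are stand-ins for the real face/voice embeddings loaded in configs/config.py:

import torch

from models.my_model import CAE

# Toy batch: 4 voice embeddings (192-d) and 4 face embeddings (512-d),
# placeholders for the precomputed inputs from the dataset pickles.
voice = torch.randn(4, 192)
face = torch.randn(4, 512)

model = CAE()
model.train()

# Full forward pass: MSE between the two 128-d embeddings (loss_emb)
# plus MSE between each input and its cross-modal reconstruction (loss_dec).
loss_emb, loss_dec = model(voice, face)
(loss_emb + loss_dec).backward()

# only_emb=True skips the decoder and returns just the 128-d embeddings.
with torch.no_grad():
    v_emb, f_emb = model(voice, face, only_emb=True)
print(v_emb.shape, f_emb.shape)  # torch.Size([4, 128]) torch.Size([4, 128])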
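sl.py trains the shared Encoder against cluster pseudo-labels: voice and face embeddings are stacked into one batch, the labels are duplicated, and MultiSimilarityLoss from pytorch_metric_learning treats any voice/face pair carrying the same pseudo-label as a positive. A stripped-down sketch of one optimization step, assuming pytorch_metric_learning is installed; the random features and labels stand in for a real clustered batch, and the CUDA transfer from do_step is omitted:

import torch
from pytorch_metric_learning import losses

from models.my_model import Encoder

model = Encoder()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # lr stands in for args.lr
fun_loss_metric = losses.MultiSimilarityLoss(alpha=2.0, beta=50.0, base=1.0)

# One toy batch: 8 voice/face pairs with cluster pseudo-labels in [0, 4).
voice_data = torch.randn(8, 192)
face_data = torch.randn(8, 512)
label = torch.randint(0, 4, (8,))

v_emb, f_emb = model(voice_data, face_data)

# Stack both modalities into one batch and repeat the labels, so a voice and
# a face with the same pseudo-label count as a positive pair.
emb = torch.cat([v_emb, f_emb], dim=0)
label2 = torch.cat([label, label], dim=0)

optimizer.zero_grad()
loss = fun_loss_metric(emb, label2)
loss.backward()
optimizer.step()
print("metric loss:", loss.item())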
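utils/deepcluster_util.py derives those pseudo-labels by running faiss k-means over the concatenated [voice | face] embeddings and recording each sample's cosine similarity to its assigned centroid. A self-contained sketch of the same clustering and similarity computation on synthetic data, assuming faiss is installed; the 256-d vectors stand in for the concatenated 128 + 128 embeddings, and the extra Kmeans options used in do_k_means are left at their defaults here:

import faiss
import numpy as np

# 1000 synthetic 256-d vectors standing in for the [v_emb | f_emb] rows
# produced by extract_embeddings (faiss expects float32).
matrix = np.random.rand(1000, 256).astype("float32")
ncentroids = 10

kmeans = faiss.Kmeans(matrix.shape[1], ncentroids, niter=20, verbose=True)
kmeans.train(matrix)

# Nearest centroid per sample: I holds the cluster id, D the squared L2 distance.
D, I = kmeans.index.search(matrix, 1)
cluster_label = I.squeeze()

# Cosine similarity of each sample to its assigned centroid, vectorized
# instead of the per-sample loop used in do_k_means.
centers = kmeans.centroids[cluster_label]
sims = np.sum(matrix * centers, axis=1) / (
    np.linalg.norm(matrix, axis=1) * np.linalg.norm(centers, axis=1)
)
print(cluster_label[:5], sims[:5])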
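For the verification benchmarks, utils/eva_emb_full.py rescales the voice/face cosine similarity to [0, 1] with (cos + 1) / 2 and reports ROC-AUC against the binary same-identity labels. A toy illustration of that scoring path, reusing the module's own cosine_similarity helper; the random embeddings and labels are placeholders, so the AUC should land near 0.5:

import numpy as np
from sklearn.metrics import roc_auc_score

from utils.eva_emb_full import cosine_similarity

rng = np.random.default_rng(0)

# 100 voice/face embedding pairs plus binary "same person" labels.
voice_emb = rng.normal(size=(100, 128))
face_emb = rng.normal(size=(100, 128))
real_label = rng.integers(0, 2, size=100)

# Cosine similarity rescaled to [0, 1], exactly as in calc_vrification.
prob = cosine_similarity(voice_emb, face_emb)
auc = roc_auc_score(real_label, prob)
print("verification AUC on random data:", auc)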