├── .gitignore ├── AgeGender ├── Detector.py ├── __init__.py ├── data.npy ├── from_camera.py └── models │ ├── __init__.py │ ├── model.py │ └── train.py ├── ArticelTranslate └── RetinaFace.txt ├── FacialExpression ├── FaceExpression.py ├── __init__.py ├── from_camera.py └── models │ ├── __init__.py │ ├── densenet.py │ └── resnet.py ├── FullPipeline.py ├── GUI.py ├── InsightFace ├── __init__.py ├── add_face_from_camera.py ├── add_face_from_dir.py ├── camera_verify.py ├── data │ ├── __init__.py │ ├── config.py │ └── data_pipe.py ├── models │ ├── Learner.py │ ├── evaluatation.py │ ├── model.py │ └── train.py ├── utils.py └── version ├── README.md ├── Retinaface ├── Retinaface.py ├── __init__.py ├── from_camera.py ├── models │ ├── __init__.py │ ├── net.py │ ├── retinaface.py │ ├── rfb.py │ └── slim.py └── utils │ ├── __init__.py │ ├── alignment.py │ ├── box_utils.py │ └── config.py ├── animate.py ├── application.py ├── application_backend.py ├── audio.py ├── audio_hparams.py ├── audio_processing.py ├── augmentation.py ├── batch_inference.py ├── config.json ├── config ├── bair-256.yaml ├── fashion-256.yaml ├── mgif-256.yaml ├── nemo-256.yaml ├── taichi-256.yaml ├── taichi-adv-256.yaml ├── vox-256.yaml └── vox-adv-256.yaml ├── convert_model.py ├── crop-video.py ├── data.py ├── data ├── bair256.csv ├── cmudict_dictionary ├── heteronyms ├── taichi-loading │ ├── README.md │ ├── load_videos.py │ └── taichi-metadata.csv └── taichi256.csv ├── data_utils.py ├── demo.py ├── denoiser.py ├── discriminator.py ├── distributed.py ├── filelists ├── README.md ├── libritts_speakerinfo.txt ├── libritts_train_clean_100_audiopath_text_sid_atleast5min_val_filelist.txt ├── libritts_train_clean_100_audiopath_text_sid_shorterthan10s_atleast5min_train_filelist.txt ├── ljs_audio_text_test_filelist.txt ├── ljs_audio_text_train_filelist.txt ├── ljs_audio_text_val_filelist.txt ├── ljs_audiopaths_text_sid_train_filelist.txt ├── ljs_audiopaths_text_sid_val_filelist.txt └── train.txt ├── flowtron.py ├── flowtron_logger.py ├── flowtron_plotting_utils.py ├── frames_dataset.py ├── generate_image.py ├── generation_resources ├── 0.gif ├── 1.gif ├── 10.gif ├── 11.gif ├── 12.gif ├── 2.gif ├── 3.gif ├── 4.gif ├── 5.gif ├── 6.gif ├── 7.gif ├── 8.gif └── 9.gif ├── generator.py ├── glow.py ├── glow_old.py ├── hparams.py ├── inference.py ├── interface.py ├── layers.py ├── log_file ├── logger.py ├── logger_firstorder.py ├── loss_function.py ├── loss_scaler.py ├── mel2samp.py ├── model.py ├── modules ├── dense_motion.py ├── discriminator.py ├── generator.py ├── keypoint_detector.py ├── model.py └── util.py ├── multiproc.py ├── output └── audio │ └── test.wav ├── plotting_utils.py ├── preprocess.py ├── reconstruction.py ├── requirements.txt ├── run.py ├── stft.py ├── stylegan2 ├── __init__.py ├── external_models │ ├── __init__.py │ ├── inception.py │ └── lpips.py ├── loss_fns.py ├── metrics │ ├── __init__.py │ ├── fid.py │ └── ppl.py ├── models.py ├── modules.py ├── project.py ├── train.py └── utils.py ├── sync_batchnorm ├── __init__.py ├── batchnorm.py ├── comm.py ├── replicate.py └── unittest.py ├── temp.png ├── text ├── LICENSE ├── __init__.py ├── acronyms.py ├── cleaners.py ├── cmudict.py ├── datestime.py ├── numbers.py └── symbols.py ├── train.py ├── train_firstorder.py ├── train_lipgan.py ├── train_unet.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | *.pyc 3 | *.mp4 4 | models/stylegan_Gs.pth 5 | *.pth 6 | *.dat 7 | *.tar 8 | *.pt 9 | *.h5 10 | 
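# compiled Python, rendered videos, and downloaded model weights are regenerated or fetched separately, so they are kept out of version control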
-------------------------------------------------------------------------------- /AgeGender/Detector.py: -------------------------------------------------------------------------------- 1 | from AgeGender.models.model import ShuffleneTiny, ShuffleneFull 2 | import torch 3 | 4 | 5 | class AgeGender: 6 | 7 | def __init__(self, name, weight_path, device): 8 | """ 9 | Age and gender detector 10 | :param name: name of the backbone ('full' or 'tiny') 11 | :param device: device to run the model on ('cuda' or 'cpu') 12 | :param weight_path: path to the network weights 13 | 14 | Notice: the expected image size is 112x112, 15 | but 224x224 also works 16 | 17 | Method detect: 18 | :param faces: 4d tensor of faces, for example of size (1, 3, 112, 112) 19 | :returns: a list of genders and a list of ages 20 | """ 21 | if name == 'tiny': 22 | model = ShuffleneTiny() 23 | elif name == 'full': 24 | model = ShuffleneFull() 25 | else: 26 | exit('from AgeGender Detector: model is not supported, choose one of (tiny, full)') 27 | 28 | model.load_state_dict(torch.load(weight_path)) 29 | model.to(device).eval() 30 | self.model = model 31 | self.device = device 32 | 33 | def detect(self, faces): 34 | faces = faces.permute(0, 3, 1, 2) 35 | faces = faces.float().div(255).to(self.device) 36 | 37 | mu = torch.as_tensor([0.485, 0.456, 0.406], dtype=faces.dtype, device=faces.device) 38 | std = torch.as_tensor([0.229, 0.224, 0.225], dtype=faces.dtype, device=faces.device) 39 | faces[:].sub_(mu[:, None, None]).div_(std[:, None, None]) 40 | 41 | outputs = self.model(faces) 42 | genders = [] 43 | ages = [] 44 | for out in outputs: 45 | gender = torch.argmax(out[:2]) 46 | gender = 'Male' if gender == 0 else 'Female' 47 | genders.append(gender) 48 | ages.append(int(out[-1])) 49 | 50 | return genders, ages 51 | -------------------------------------------------------------------------------- /AgeGender/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/singularitai/Morphling/e7a3af969123c0d3c0f3c6f1036a97e9be0b289c/AgeGender/__init__.py -------------------------------------------------------------------------------- /AgeGender/data.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/singularitai/Morphling/e7a3af969123c0d3c0f3c6f1036a97e9be0b289c/AgeGender/data.npy -------------------------------------------------------------------------------- /AgeGender/from_camera.py: -------------------------------------------------------------------------------- 1 | from Retinaface.Retinaface import FaceDetector 2 | from AgeGender.Detector import AgeGender 3 | import cv2 4 | from time import time 5 | 6 | face_detector = FaceDetector(name='mobilenet', weight_path='../Retinaface/weights/mobilenet.pth', device='cuda') 7 | age_gender_detector = AgeGender(name='full', weight_path='weights/ShufflenetFull.pth', device='cuda') 8 | 9 | vid = cv2.VideoCapture(0) 10 | vid.set(3, 1280) 11 | vid.set(4, 720) 12 | while True: 13 | ret, frame = vid.read() 14 | faces, boxes, scores, landmarks = face_detector.detect_align(frame) 15 | if len(faces.shape) > 1: 16 | tic = time() 17 | genders, ages = age_gender_detector.detect(faces) 18 | print(time()-tic) 19 | for i, b in enumerate(boxes): 20 | cv2.putText(frame, f'{genders[i]},{ages[i]}', (int(b[0]), int(b[1]-10)), cv2.FONT_HERSHEY_SIMPLEX, 1.1, [0, 200, 0], 3) 21 | cv2.rectangle(frame, (int(b[0]), int(b[1])), (int(b[2]), int(b[3])), (255, 0, 0), 3) 22 | 23 | for p in landmarks: 24 | for i in range(5): 25 |
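# mark the five aligned landmark points (left eye, right eye, nose tip, left and right mouth corners) returned by the detector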
cv2.circle(frame, (p[i][0], p[i][1]), 3, (0, 255, 0), -1) 26 | 27 | cv2.imshow('frame', frame) 28 | if cv2.waitKey(1) == ord('q'): 29 | break 30 | 31 | cv2.destroyAllWindows() -------------------------------------------------------------------------------- /AgeGender/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/singularitai/Morphling/e7a3af969123c0d3c0f3c6f1036a97e9be0b289c/AgeGender/models/__init__.py -------------------------------------------------------------------------------- /AgeGender/models/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torchvision.models as models 4 | import time 5 | import sys 6 | import torch 7 | mean = [0.485, 0.456, 0.406] 8 | std = [0.229, 0.224, 0.225] 9 | sz = 112 10 | 11 | 12 | class ShuffleneTiny(nn.Module): 13 | 14 | def __init__(self): 15 | super(ShuffleneTiny, self).__init__() 16 | self.model = models.shufflenet_v2_x0_5() 17 | self.model.fc = nn.Sequential( 18 | nn.BatchNorm1d(1024), 19 | nn.Linear(1024, 3)) 20 | 21 | def forward(self, x): 22 | return self.model(x) 23 | 24 | 25 | class ShuffleneFull(nn.Module): 26 | 27 | def __init__(self): 28 | super(ShuffleneFull, self).__init__() 29 | self.model = models.shufflenet_v2_x1_0() 30 | self.model.fc = nn.Sequential( 31 | nn.BatchNorm1d(1024), 32 | nn.Linear(1024, 3)) 33 | 34 | def forward(self, x): 35 | return self.model(x) 36 | 37 | 38 | class TrainModel: 39 | 40 | def __init__(self, model, train_dl, valid_dl, optimizer, criterion, scheduler, num_epochs): 41 | 42 | self.num_epochs = num_epochs 43 | self.model = model 44 | self.scheduler = scheduler 45 | self.train_dl = train_dl 46 | self.valid_dl = valid_dl 47 | self.optimizer = optimizer 48 | self.criterion = criterion 49 | 50 | self.loss_history = [] 51 | self.best_acc_valid = 0.0 52 | self.best_weight = None 53 | 54 | self.training() 55 | 56 | def training(self): 57 | 58 | valid_acc = 0 59 | for epoch in range(self.num_epochs): 60 | 61 | print('Epoch %2d/%2d' % (epoch + 1, self.num_epochs)) 62 | print('-' * 15) 63 | 64 | t0 = time.time() 65 | train_acc = self.train_model() 66 | valid_acc = self.valid_model() 67 | if self.scheduler: 68 | self.scheduler.step() 69 | 70 | time_elapsed = time.time() - t0 71 | print(' Training complete in: %.0fm %.0fs' % (time_elapsed // 60, time_elapsed % 60)) 72 | print('| val_acc_gender | val_l1_loss | acc_gender | l1_loss |') 73 | print('| %.3f | %.3f | %.3f | %.3f \n' % (valid_acc[0], valid_acc[1], train_acc[0], train_acc[1])) 74 | 75 | if valid_acc[0] > self.best_acc_valid: 76 | self.best_acc_valid = valid_acc[0] 77 | self.best_weight = self.model.state_dict().copy() 78 | return 79 | 80 | def train_model(self): 81 | 82 | self.model.train() 83 | N = len(self.train_dl.dataset) 84 | step = N // self.train_dl.batch_size 85 | 86 | avg_loss = 0.0 87 | acc_gender = 0.0 88 | loss_age = 0.0 89 | 90 | for i, (x, y) in enumerate(self.train_dl): 91 | x, y = x.cuda(), y.cuda() 92 | # forward 93 | pred_8 = self.model(x) 94 | # loss 95 | loss = self.criterion(pred_8, y) 96 | # backward 97 | self.optimizer.zero_grad() 98 | loss.backward() 99 | self.optimizer.step() 100 | # statistics of model training 101 | avg_loss = (avg_loss * i + loss) / (i + 1) 102 | acc_gender += accuracy_gender(pred_8, y) 103 | loss_age += l1loss_age(pred_8, y) 104 | 105 | self.loss_history.append(avg_loss) 106 | 107 | # report statistics 108 | sys.stdout.flush() 109 |
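# '\r' returns the cursor to the start of the line so the progress counter is rewritten in place instead of printing a new line per step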
sys.stdout.write("\r Train_Step: %d/%d | running_loss: %.4f" % (i + 1, step, avg_loss)) 110 | 111 | sys.stdout.flush() 112 | return torch.tensor([acc_gender, loss_age]) / N 113 | 114 | def valid_model(self): 115 | print() 116 | self.model.eval() 117 | N = len(self.valid_dl.dataset) 118 | step = N // self.valid_dl.batch_size 119 | acc_gender = 0.0 120 | loss_age = 0.0 121 | 122 | with torch.no_grad(): 123 | for i, (x, y) in enumerate(self.valid_dl): 124 | x, y = x.cuda(), y.cuda() 125 | 126 | score = self.model(x) 127 | acc_gender += accuracy_gender(score, y) 128 | loss_age += l1loss_age(score, y) 129 | 130 | sys.stdout.flush() 131 | sys.stdout.write("\r Valid_Step: %d/%d" % (i, step)) 132 | 133 | sys.stdout.flush() 134 | return torch.tensor([acc_gender, loss_age]) / N 135 | 136 | 137 | def accuracy_gender(input, targs): 138 | pred = torch.argmax(input[:, :2], dim=1) 139 | y = targs[:, 0] 140 | return torch.sum(pred == y) 141 | 142 | 143 | def l1loss_age(input, targs): 144 | return F.l1_loss(input[:, -1], targs[:, -1]).mean() 145 | -------------------------------------------------------------------------------- /AgeGender/models/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Sajjad Ayoubi: Age Gender Detection 3 | I use the UTKFace dataset 4 | from: https://susanqq.github.io/UTKFace/ 5 | and I created an annotation file, annot.npy, 6 | which is in the weights folder 7 | """ 8 | 9 | 10 | 11 | from PIL import Image 12 | import numpy as np 13 | import torch 14 | import torch.nn.functional as F 15 | import torch.optim as optim 16 | from torch.utils.data import DataLoader, Dataset 17 | import torchvision.transforms as transforms 18 | from AgeGender.models.model import ShuffleneFull, TrainModel 19 | 20 | 21 | class MultitaskDataset(Dataset): 22 | 23 | def __init__(self, data, tfms, root='FaceSet/'): 24 | self.root = root 25 | self.tfms = tfms 26 | self.ages = data[:, 3] 27 | self.races = data[:, 2] 28 | self.genders = data[:, 1] 29 | self.imgs = data[:, 0] 30 | 31 | def __len__(self): 32 | return len(self.imgs) 33 | 34 | def __getitem__(self, i): 35 | return self.tfms(Image.open(self.root + self.imgs[i])), torch.tensor( 36 | [self.genders[i], self.races[i], self.ages[i]]).float() 37 | 38 | def __repr__(self): 39 | return f'{type(self).__name__} of len {len(self)}' 40 | 41 | 42 | mean = [0.485, 0.456, 0.406] 43 | std = [0.229, 0.224, 0.225] 44 | sz = 112 45 | bs = 256 46 | 47 | tf = {'train': transforms.Compose([ 48 | transforms.RandomRotation(degrees=0.2), 49 | transforms.RandomHorizontalFlip(p=.5), 50 | transforms.RandomGrayscale(p=.2), 51 | transforms.Resize((sz, sz)), 52 | transforms.ToTensor(), 53 | transforms.Normalize(mean, std)]), 54 | 'test': transforms.Compose([ 55 | transforms.Resize((sz, sz)), 56 | transforms.ToTensor(), 57 | transforms.Normalize(mean, std)])} 58 | 59 | data = np.load('data.npy', allow_pickle=True) 60 | train_data = data[data[:, -1] == 1] 61 | valid_data = data[data[:, -1] == 0] 62 | 63 | valid_ds = MultitaskDataset(data=valid_data, tfms=tf['test']) 64 | train_ds = MultitaskDataset(data=train_data, tfms=tf['train']) 65 | 66 | train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True, num_workers=8) 67 | valid_dl = DataLoader(valid_ds, batch_size=bs, shuffle=True, num_workers=4) 68 | 69 | 70 | def multitask_loss(input, target): 71 | input_gender = input[:, :2] 72 | input_age = input[:, -1] 73 | 74 | loss_gender = F.cross_entropy(input_gender, target[:, 0].long()) 75 | loss_age = F.l1_loss(input_age, target[:, 2]) 76 |
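# weight the two task losses: the gender cross-entropy is scaled by 1/0.16 (about 6.25x) and the age L1 term is doubled, presumably to balance their magnitudes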
77 | return loss_gender / (.16) + loss_age * 2 78 | 79 | 80 | model = ShuffleneFull().cuda() 81 | optimizer = optim.Adam(params=model.parameters(), lr=0.001, weight_decay=0.001) 82 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9) 83 | history = TrainModel(model, train_dl, valid_dl, optimizer, multitask_loss, scheduler, 5) 84 | print(history) -------------------------------------------------------------------------------- /ArticelTranslate/RetinaFace.txt: -------------------------------------------------------------------------------- 1 | ریتانافیس: ردیاب موقعیت چهره تک مرحله ای در محیط طبیعی 2 | اگرچه اقدامات بسیار چشمگیری در زمینه ردیابی چهره در شرایط واقعی و کنترل نشده صورت گرفته است ، اما محلی سازی دقیق و کارآمد چهره در محیط طبیعی همچنان یک چالش جدی است. در این مقاله یک ردیاب چهره تک مرحله با اسم مستقل ارائه شده است 3 | ریتانافیس که ردیابی صورت را در سطح پیکسل های تصویر انجام می دهد. 4 | با استفاده از مزایای مشترک یادگیری تحت نظارت فوق العاده و به ویژه یادگیری چند وظیفه ای نظارت شده می توان چهره را در مقیاس های مختلف شناسایی کرد. 5 | ما در پنج مورد زیر مشارکت ٪ می کنیم(-1-) 6 | ما پنج قسمت از صورت (نوک بینی چشم چپ چشم راست سمت چپ لب سمت راست لب) را در مجوعه داده WIDER FACE به صورت دستی حاشیه نویسی کرده ایم. 7 | در تشخیص چهره های سخت(چهره کوچک و چهره های نا متعارف) با کمک این سیگنال نظارت که اضافه کردیم پیشرفت چشمگیری حاصل شد.(-2-) 8 | ما همچنین برای پیش بینی پیکسل های چهره، شاخه ای برای تشخیص اطلاعات چهره با شکل سه بعدی پیکسل به موازات شعب نظارت شده موجود(-3-) 9 | در مجموعه سخت ارزیابی شبکه عصبی WIDER FACE 10 | ریتانافیس بالاتر از بهترین نتایج موجود در این زمینه عمل کرد. دقت متوسط ​​(AP) با 1.1٪ (دستیابی به AP برابر 91.4٪) 11 | در مجموعه داده تست IJB-C: 12 | ریتانافیس قادر است با استفاده از روش ArcFace نتایج به دست امده روی مسله تایید چهره را به طور چشمگیری ارتقا دهد و در این زمینه بهترین باشد. (TAR : 89.59%) 13 | با استفاده از شبکه های عصبی ستون فقرات سبک . ریتانافیس را می توان روی یک هسته CPU به صورت بلادرنگ برای تصاویر VGA اجرا کرد. 14 | توضیحات بیشتر و کد در : https://github.com/deepinsight/insightface/tree/master/RetinaFace. 15 | 1. مقدمه: 16 | ردیابی موقعیت چهره به صورت خودکار یک پیشنیاز برای تحلیل چهره در بسیاری از برنامه های کاربردی نظیر تحلیل احساسات و سن و افراد است. تعریف محدودی از شناسایی موقیت چهره ممکن است به تشخیص سنتی صورت [53 ، 62] اشاره کند، که هدف آن تخمین کادر مستطیلی دور صورت بدون مقیاس و موقعیت قبلی است. با این وجود، در این مقاله اشکال صورت ویژگی های بهتری برای طبقه بندی چهره ارائه می شود. با الهام از [6] [MTCNN [66] , STN [5 به طور همزمان چهره ها و پنج نشانه از صورت شناسایی می شود. 17 | 18 | به دلیل محدودیت داده های آموزشی تایید نشد که آیا [JDA [6] ، MTCNN [66 و [STN [5 می توانند از علائم صورت برای شناسای چهره های کوچک در تصویر استفاده کنند یا نه. یكی از سؤالاتی كه در این مقاله به آن پاسخ می دهیم این است كه آیا ما می توانیم با استفاده از سیگنال نظارت اضافی ساخته شده از پنج علامت صورت بهترین عملکرد فعلی (90.3% [67]) را در مجموعه تست سخت [WIDER FACE [60 بهبود دهیم. در [Mask R-CNN [20 ، با اضافه کردن یک شاخه برای پیش بینی یک ماسک شی به طور موازی با شاخه موجود برای محدود کردن تشخیص جعبه محاطی و رگرسیون ، عملکرد تشخیص به طور قابل توجهی بهبود می یابد. این امر تأیید می کند که حاشیه نویسی متراکم پیکسلی نیز برای بهبود تشخیص مفید است. در [FAN [56 ، یک نقشه ویژگی به سطح لنگر برای بهبود تشخیص چهره های مسدود شده ارائه شده است. با این وجود ، نقشه ویژگی پیشنهادی کاملاً درشت است و حاوی اطلاعات معنایی نیست. به تازگی ، مدل های اصلاح شده سه بعدی خود نظارت [14 ، 51 ، 52 ، 70] به مدل سازی چهره ای 3 بعدی در محیط واقعی دست یافته اند. 
به خصوص ، به طور خاص ، [MeshDecoder [70 با استفاده از غلطک ها [10 ، 40] در شکل و بافت، به سرعت بیشتری نسبت به زمان بلادرنگ دست می یابد. با این حال ، اصلی ترین چالش های اعمال [MeshDecoder [70 در ردیاب تک مرحله ای عبارتند از: (1) پارامترهای دوربین برای تخمین دقیق سخت است ، و (2) شکل نهان مشترک و بازنمایی بافت از یک بردار ویژگی تنها (1 × 1 Conv روی هرم ویژگی) به جای ویژگی جمع شده RoI پیش بینی شده است ، که نشان دهنده خطر ابتلا به ویژگی تغییر در این مقاله ، ما از شعبه MeshDecoder [70] از طریق یادگیری خود نظارت برای پیش بینی شکل چهره سه بعدی هوشمند با پیکسل به موازات شاخه های نظارت موجود استفاده می کنیم. 19 | 20 | -------------------------------------------------------------------------------- /FacialExpression/FaceExpression.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from .models.densenet import densenet121 4 | from .models.resnet import resnet34 5 | 6 | labels = np.array(['angry', 'disgust', 'fear', 'happy', 'sad', 'surprise', 'neutral']) 7 | 8 | 9 | class EmotionDetector: 10 | 11 | def __init__(self, name='resnet34', device='cpu', weight_path='weights/resnet34.pth'): 12 | """ 13 | Residual Masking Emotion Detector from a list of labels 14 | :param name: name of backbone of networks (resnet34, densenet121) 15 | :param device: model run in cpu or gpu (cuda, cpu) 16 | :param weight_path: path of network weight 17 | 18 | Notice: image size must be 224x224 19 | 20 | Method detect_emotion: 21 | :param faces: 4d tensor of face for example size(1, 3, 224, 224) 22 | :returns emotions list and probability of emotions 23 | """ 24 | 25 | self.device = device 26 | self.model = None 27 | if name == 'resnet34': 28 | self.model = resnet34() 29 | elif name == 'densnet121': 30 | self.model = densenet121() 31 | else: 32 | exit('EmotionDetector: Network does not support!! 
\n just(resnet34, densnet121)') 33 | 34 | self.model.load_state_dict(torch.load(weight_path)) 35 | self.model.to(device).eval() 36 | 37 | 38 | def detect_emotion(self, faces): 39 | if len(faces) > 0: 40 | faces = faces.permute(0, 3, 1, 2) 41 | faces = faces.float().div(255).to(self.device) 42 | emotions = self.model(faces) 43 | prob = torch.softmax(emotions, dim=1) 44 | emo_prob, emo_idx = torch.max(prob, dim=1) 45 | return labels[emo_idx.tolist()], emo_prob.tolist() 46 | else: 47 | return 0, 0 48 | -------------------------------------------------------------------------------- /FacialExpression/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/singularitai/Morphling/e7a3af969123c0d3c0f3c6f1036a97e9be0b289c/FacialExpression/__init__.py -------------------------------------------------------------------------------- /FacialExpression/from_camera.py: -------------------------------------------------------------------------------- 1 | from Retinaface.Retinaface import FaceDetector 2 | from FacialExpression.FaceExpression import EmotionDetector 3 | import cv2 4 | 5 | face_detector = FaceDetector(name='mobilenet', weight_path='../Retinaface/weights/mobilenet.pth', device='cuda', face_size=(224, 224)) 6 | emotion_detector = EmotionDetector(name='densnet121', weight_path='weights/densnet121.pth', device='cuda') 7 | 8 | 9 | vid = cv2.VideoCapture(0) 10 | vid.set(3, 1280) 11 | vid.set(4, 720) 12 | while True: 13 | ret, frame = vid.read() 14 | faces, boxes, scores, landmarks = face_detector.detect_align(frame) 15 | if len(faces.shape) > 1: 16 | emotions, emo_probs = emotion_detector.detect_emotion(faces) 17 | 18 | for b in boxes: 19 | cv2.putText(frame, emotions[0], (int(b[0]), int(b[1])-5), cv2.FONT_HERSHEY_SIMPLEX, 1.1, [0, 200, 0], 3) 20 | cv2.rectangle(frame, (int(b[0]), int(b[1])), (int(b[2]), int(b[3])), (255, 0, 0), 3) 21 | 22 | for p in landmarks: 23 | for i in range(5): 24 | cv2.circle(frame, (p[i][0], p[i][1]), 3, (0, 255, 0), -1) 25 | 26 | cv2.imshow('frame', frame) 27 | if cv2.waitKey(1) == ord('q'): 28 | break 29 | 30 | cv2.destroyAllWindows() -------------------------------------------------------------------------------- /FacialExpression/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/singularitai/Morphling/e7a3af969123c0d3c0f3c6f1036a97e9be0b289c/FacialExpression/models/__init__.py -------------------------------------------------------------------------------- /GUI.py: -------------------------------------------------------------------------------- 1 | from PyQt5 import uic 2 | 3 | with open('interface_new.py', 'w') as fout: 4 | uic.compileUi('Mock.ui', fout) -------------------------------------------------------------------------------- /InsightFace/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/singularitai/Morphling/e7a3af969123c0d3c0f3c6f1036a97e9be0b289c/InsightFace/__init__.py -------------------------------------------------------------------------------- /InsightFace/add_face_from_camera.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import argparse 3 | from Retinaface.Retinaface import FaceDetector 4 | from pathlib import Path 5 | 6 | parser = argparse.ArgumentParser(description='take a picture') 7 | parser.add_argument('--name', '-n', default='unknown', type=str, 
help='input the name of the recording person') 8 | args = parser.parse_args() 9 | 10 | save_path = Path('data/facebank')/args.name 11 | if not save_path.exists(): 12 | save_path.mkdir() 13 | 14 | # init camera 15 | cap = cv2.VideoCapture(1) 16 | cap.set(3, 1280) 17 | cap.set(4, 720) 18 | # init detector 19 | detector = FaceDetector(name='resnet', weight_path='Retinaface/weights/resnet50.pth', device='cuda') 20 | count = 4 21 | while cap.isOpened(): 22 | _, frame = cap.read() 23 | frame = cv2.putText( 24 | frame, f'Press t to take {count} pictures, then finish...', (10, 50), 25 | cv2.FONT_HERSHEY_SIMPLEX, 2, (0,100,0), 3, cv2.LINE_AA) 26 | 27 | if cv2.waitKey(1) & 0xFF == ord('t'): 28 | count -= 1 29 | faces = detector.detect_align(frame)[0].cpu().numpy() 30 | if len(faces.shape) > 1: 31 | cv2.imwrite(f'{save_path}/{args.name}_{count}.jpg', faces[0]) 32 | if count <= 0: 33 | break 34 | else: 35 | print('there is not face in this frame') 36 | 37 | cv2.imshow("My Capture", frame) 38 | 39 | cap.release() 40 | cv2.destroyAllWindows() 41 | -------------------------------------------------------------------------------- /InsightFace/add_face_from_dir.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import argparse 3 | from Retinaface.Retinaface import FaceDetector 4 | from pathlib import Path 5 | 6 | parser = argparse.ArgumentParser(description='take a picture') 7 | parser.add_argument('--path', '-p', default='unknown', type=str, help='path of dir of person images') 8 | args = parser.parse_args() 9 | print('only a face in each image and all image from the same person') 10 | 11 | 12 | dir_path = Path(args.path) 13 | if not dir_path.is_dir(): 14 | exit('dir does not exists !!') 15 | save_path = Path(f'data/facebank/{dir_path.name}') 16 | if not save_path.exists(): 17 | save_path.mkdir() 18 | 19 | # init detector 20 | detector = FaceDetector(name='mobilenet', weight_path='Retinaface/weights/mobilenet.pth', device='cuda') 21 | 22 | counter = 0 23 | for img_path in dir_path.iterdir(): 24 | img = cv2.imread(str(img_path)) 25 | face = detector.detect_align(img)[0].cpu().numpy() 26 | if len(face.shape) > 1: 27 | save_name = f'{save_path}/{dir_path.name}_{counter}.jpg' 28 | cv2.imwrite(save_name, face[0]) 29 | counter += 1 30 | else: 31 | print(img_path, 'in this image did not detect any face') -------------------------------------------------------------------------------- /InsightFace/camera_verify.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import argparse 3 | from data.config import get_config 4 | from models.Learner import face_learner 5 | from utils import update_facebank, load_facebank, special_draw 6 | from Retinaface.Retinaface import FaceDetector 7 | 8 | 9 | if __name__ == '__main__': 10 | parser = argparse.ArgumentParser(description='for face verification') 11 | parser.add_argument('-th', '--threshold', help='threshold to decide identical faces', default=1.45, type=float) 12 | parser.add_argument("-u", "--update", default=True, help="whether perform update the facebank") 13 | parser.add_argument("-tta", "--tta", default=True, help="whether test time augmentation") 14 | parser.add_argument('-m', '--mobilenet', default=False, help="use mobilenet for backbone") 15 | args = parser.parse_args() 16 | 17 | conf = get_config(training=False) 18 | detector = FaceDetector(name='mobilenet', weight_path='Retinaface/weights/mobilenet.pth', device=conf.device) 19 | conf.use_mobilfacenet = 
args.mobilenet 20 | face_rec = face_learner(conf, inference=True) 21 | face_rec.threshold = args.threshold 22 | face_rec.model.eval() 23 | 24 | if args.update: 25 | targets, names = update_facebank(conf, face_rec.model, detector, tta=args.tta) 26 | else: 27 | targets, names = load_facebank(conf) 28 | 29 | # init camera 30 | cap = cv2.VideoCapture(1) 31 | cap.set(3, 1280) 32 | cap.set(4, 720) 33 | # frame rate 6 due to my laptop is quite slow... 34 | while cap.isOpened(): 35 | _, frame = cap.read() 36 | faces, boxes, scores, landmarks = detector.detect_align(frame) 37 | if len(faces.shape) > 1: 38 | results, score = face_rec.infer(conf, faces, targets, args.tta) 39 | for idx, bbox in enumerate(boxes): 40 | special_draw(frame, bbox, landmarks[idx], names[results[idx] + 1], score[idx]) 41 | cv2.imshow('face Capture', frame) 42 | if cv2.waitKey(1) & 0xFF == ord('q'): 43 | break 44 | 45 | cap.release() 46 | cv2.destroyAllWindows() -------------------------------------------------------------------------------- /InsightFace/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/singularitai/Morphling/e7a3af969123c0d3c0f3c6f1036a97e9be0b289c/InsightFace/data/__init__.py -------------------------------------------------------------------------------- /InsightFace/data/config.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | from pathlib import Path 3 | import torch 4 | from torch.nn import CrossEntropyLoss 5 | 6 | 7 | def get_config(training=True): 8 | conf = edict() 9 | conf.data_path = Path('data') 10 | conf.work_path = Path('weights/') 11 | conf.model_path = conf.work_path / 'models' 12 | conf.log_path = conf.work_path / 'log' 13 | conf.save_path = conf.work_path 14 | conf.input_size = [112, 112] 15 | conf.embedding_size = 512 16 | conf.use_mobilfacenet = False 17 | conf.net_depth = 50 18 | conf.drop_ratio = 0.6 19 | conf.net_mode = 'ir_se' # or 'ir' 20 | conf.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 21 | conf.data_mode = 'emore' 22 | conf.vgg_folder = conf.data_path / 'faces_vgg_112x112' 23 | conf.ms1m_folder = conf.data_path / 'faces_ms1m_112x112' 24 | conf.emore_folder = conf.data_path / 'faces_emore' 25 | conf.batch_size = 100 # irse net depth 50 26 | # conf.batch_size = 200 # mobilefacenet 27 | # --------------------Training Config ------------------------ 28 | if training: 29 | conf.log_path = conf.work_path / 'log' 30 | conf.save_path = conf.work_path / 'save' 31 | # conf.weight_decay = 5e-4 32 | conf.lr = 1e-3 33 | conf.momentum = 0.9 34 | conf.pin_memory = True 35 | # conf.num_workers = 4 # when batchsize is 200 36 | conf.num_workers = 3 37 | conf.ce_loss = CrossEntropyLoss() 38 | # --------------------Inference Config ------------------------ 39 | else: 40 | conf.facebank_path = conf.data_path / 'facebank' 41 | conf.threshold = 1.5 42 | conf.face_limit = 10 43 | # when inference, at maximum detect 10 faces in one image, my laptop is slow 44 | return conf 45 | -------------------------------------------------------------------------------- /InsightFace/data/data_pipe.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, ConcatDataset, DataLoader 2 | from torchvision import transforms as trans 3 | from torchvision.datasets import ImageFolder 4 | from PIL import ImageFile 5 | 6 | ImageFile.LOAD_TRUNCATED_IMAGES = True 7 | import 
numpy as np 8 | import bcolz 9 | 10 | 11 | def de_preprocess(tensor): 12 | return tensor * 0.5 + 0.5 13 | 14 | 15 | def get_train_dataset(imgs_folder): 16 | train_transform = trans.Compose([ 17 | trans.RandomHorizontalFlip(), 18 | trans.ToTensor(), 19 | trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 20 | ]) 21 | ds = ImageFolder(imgs_folder, train_transform) 22 | class_num = ds[-1][1] + 1 23 | return ds, class_num 24 | 25 | 26 | def get_train_loader(conf): 27 | if conf.data_mode in ['ms1m', 'concat']: 28 | ms1m_ds, ms1m_class_num = get_train_dataset(conf.ms1m_folder / 'imgs') 29 | print('ms1m loader generated') 30 | if conf.data_mode in ['vgg', 'concat']: 31 | vgg_ds, vgg_class_num = get_train_dataset(conf.vgg_folder / 'imgs') 32 | print('vgg loader generated') 33 | if conf.data_mode == 'vgg': 34 | ds = vgg_ds 35 | class_num = vgg_class_num 36 | elif conf.data_mode == 'ms1m': 37 | ds = ms1m_ds 38 | class_num = ms1m_class_num 39 | elif conf.data_mode == 'concat': 40 | for i, (url, label) in enumerate(vgg_ds.imgs): 41 | vgg_ds.imgs[i] = (url, label + ms1m_class_num) 42 | ds = ConcatDataset([ms1m_ds, vgg_ds]) 43 | class_num = vgg_class_num + ms1m_class_num 44 | elif conf.data_mode == 'emore': 45 | ds, class_num = get_train_dataset(conf.emore_folder / 'imgs') 46 | loader = DataLoader(ds, batch_size=conf.batch_size, shuffle=True, pin_memory=conf.pin_memory, 47 | num_workers=conf.num_workers) 48 | return loader, class_num 49 | 50 | 51 | def get_val_pair(path, name): 52 | carray = bcolz.carray(rootdir=path / name, mode='r') 53 | issame = np.load(path / '{}_list.npy'.format(name)) 54 | return carray, issame 55 | 56 | 57 | def get_val_data(data_path): 58 | agedb_30, agedb_30_issame = get_val_pair(data_path, 'agedb_30') 59 | cfp_fp, cfp_fp_issame = get_val_pair(data_path, 'cfp_fp') 60 | lfw, lfw_issame = get_val_pair(data_path, 'lfw') 61 | return agedb_30, cfp_fp, lfw, agedb_30_issame, cfp_fp_issame, lfw_issame 62 | 63 | 64 | class train_dataset(Dataset): 65 | def __init__(self, imgs_bcolz, label_bcolz, h_flip=True): 66 | self.imgs = bcolz.carray(rootdir=imgs_bcolz) 67 | self.labels = bcolz.carray(rootdir=label_bcolz) 68 | self.h_flip = h_flip 69 | self.length = len(self.imgs) - 1 70 | if h_flip: 71 | self.transform = trans.Compose([ 72 | trans.ToPILImage(), 73 | trans.RandomHorizontalFlip(), 74 | trans.ToTensor(), 75 | trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 76 | ]) 77 | self.class_num = self.labels[-1] + 1 78 | 79 | def __len__(self): 80 | return self.length 81 | 82 | def __getitem__(self, index): 83 | img = torch.tensor(self.imgs[index + 1], dtype=torch.float) 84 | label = torch.tensor(self.labels[index + 1], dtype=torch.long) 85 | if self.h_flip: 86 | img = de_preprocess(img) 87 | img = self.transform(img) 88 | return img, label 89 | -------------------------------------------------------------------------------- /InsightFace/models/train.py: -------------------------------------------------------------------------------- 1 | from data.config import get_config 2 | from models.Learner import face_learner 3 | import argparse 4 | 5 | # python train.py -net mobilefacenet -b 200 -w 4 6 | 7 | if __name__ == '__main__': 8 | parser = argparse.ArgumentParser(description='for face verification') 9 | parser.add_argument("-e", "--epochs", help="training epochs", default=20, type=int) 10 | parser.add_argument("-net", "--net_mode", help="which network, [ir, ir_se, mobilefacenet]",default='ir_se', type=str) 11 | parser.add_argument("-depth", "--net_depth", help="how many layers [50,100,152]", 
default=50, type=int) 12 | parser.add_argument('-lr','--lr',help='learning rate',default=1e-3, type=float) 13 | parser.add_argument("-b", "--batch_size", help="batch_size", default=96, type=int) 14 | parser.add_argument("-w", "--num_workers", help="workers number", default=3, type=int) 15 | parser.add_argument("-d", "--data_mode", help="use which database, [vgg, ms1m, emore, concat]",default='emore', type=str) 16 | args = parser.parse_args() 17 | 18 | conf = get_config() 19 | 20 | if args.net_mode == 'mobilefacenet': 21 | conf.use_mobilfacenet = True 22 | else: 23 | conf.net_mode = args.net_mode 24 | conf.net_depth = args.net_depth 25 | 26 | conf.lr = args.lr 27 | conf.batch_size = args.batch_size 28 | conf.num_workers = args.num_workers 29 | conf.data_mode = args.data_mode 30 | learner = face_learner(conf) 31 | learner.train(conf, args.epochs) -------------------------------------------------------------------------------- /InsightFace/utils.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import numpy as np 3 | import io, cv2 4 | import torch 5 | from models.model import l2_norm 6 | import matplotlib.pyplot as plt 7 | plt.switch_backend('agg') 8 | 9 | 10 | def faces_preprocessing(faces, device): 11 | faces = faces.permute(0, 3, 1, 2).float() 12 | faces = faces.div(255).to(device) 13 | mu = torch.as_tensor([.5, .5, .5], dtype=faces.dtype, device=device) 14 | faces[:].sub_(mu[:, None, None]).div_(mu[:, None, None]) 15 | return faces 16 | 17 | 18 | def separate_bn_paras(modules): 19 | if not isinstance(modules, list): 20 | modules = [*modules.modules()] 21 | paras_only_bn = [] 22 | paras_wo_bn = [] 23 | for layer in modules: 24 | if 'model' in str(layer.__class__): 25 | continue 26 | if 'container' in str(layer.__class__): 27 | continue 28 | else: 29 | if 'batchnorm' in str(layer.__class__): 30 | paras_only_bn.extend([*layer.parameters()]) 31 | else: 32 | paras_wo_bn.extend([*layer.parameters()]) 33 | return paras_only_bn, paras_wo_bn 34 | 35 | 36 | def update_facebank(conf, model, detector, tta=True): 37 | model.eval() 38 | faces_embs = torch.empty(0).to(conf.device) 39 | names = np.array(['Unknown']) 40 | for path in conf.facebank_path.iterdir(): 41 | if path.is_file(): 42 | continue 43 | faces = [] 44 | for img_path in path.iterdir(): 45 | face = cv2.imread(str(img_path)) 46 | if face.shape[:2] != (112, 112): # if img be not face 47 | face = detector.detect_align(face)[0] 48 | cv2.imwrite(img_path, face) 49 | else: 50 | face = torch.tensor(face).unsqueeze(0) 51 | faces.append(face) 52 | 53 | faces = torch.cat(faces) 54 | if len(faces.shape) <= 3: 55 | continue 56 | 57 | with torch.no_grad(): 58 | faces = faces_preprocessing(faces, device=conf.device) 59 | if tta: 60 | face_emb = model(faces) 61 | hflip_emb = model(faces.flip(-1)) # image horizontal flip 62 | face_embs = l2_norm(face_emb + hflip_emb) 63 | else: 64 | face_embs = model(faces) 65 | 66 | faces_embs = torch.cat((faces_embs, face_embs.mean(0, keepdim=True))) 67 | names = np.append(names, path.name) 68 | 69 | torch.save(faces_embs, conf.facebank_path/'facebank.pth') 70 | np.save(conf.facebank_path/'names', names) 71 | print('from recognizer: facebank updated') 72 | return faces_embs, names 73 | 74 | 75 | def load_facebank(conf): 76 | embs = torch.load(conf.facebank_path/'facebank.pth') 77 | names = np.load(conf.facebank_path/'names.npy') 78 | print('from recognizer: facebank loaded') 79 | return embs, names 80 | 81 | 82 | def get_time(): 83 | return 
(str(datetime.now())[:-10]).replace(' ', '-').replace(':', '-') 84 | 85 | 86 | def gen_plot(fpr, tpr): 87 | """Create a pyplot plot and save to buffer.""" 88 | plt.figure() 89 | plt.xlabel("FPR", fontsize=14) 90 | plt.ylabel("TPR", fontsize=14) 91 | plt.title("ROC Curve", fontsize=14) 92 | plt.plot(fpr, tpr, linewidth=2) 93 | buf = io.BytesIO() 94 | plt.savefig(buf, format='jpeg') 95 | buf.seek(0) 96 | plt.close() 97 | return buf 98 | 99 | 100 | def draw(bbox, name, frame): 101 | frame = cv2.rectangle(frame, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 0, 255), 6) 102 | frame = cv2.putText(frame, name, (bbox[0], bbox[1]), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 4, cv2.LINE_AA) 103 | return frame 104 | 105 | 106 | def special_draw(img, box, landmarks, name, score): 107 | """draw a bounding box, landmarks, score bar and name label on the image""" 108 | color = (148, 133, 0) 109 | tl = round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line thickness 110 | c1, c2 = (box[0], box[1]), (box[2], box[3]) 111 | # draw bounding box 112 | cv2.rectangle(img, c1, c2, color, thickness=tl) 113 | # draw landmark 114 | for land in landmarks: 115 | cv2.circle(img, tuple(land.int().tolist()), 3, color, -1) 116 | # draw score 117 | score = 100-(score*100/1.4) 118 | score = 0 if score < 0 else score 119 | bar = (box[3] + 2) - (box[1] - 2) 120 | score_final = bar - (score*bar/100) 121 | cv2.rectangle(img, (box[2] + 1, box[1] - 2 + score_final), (box[2] + (tl+5), box[3] + 2), color, -1) 122 | # draw label 123 | tf = max(tl - 1, 1) # font thickness 124 | t_size = cv2.getTextSize(name, 0, fontScale=tl / 3, thickness=tf)[0] 125 | c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3 126 | cv2.rectangle(img, c1, c2, color, -1) # filled 127 | cv2.putText(img, name, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA) -------------------------------------------------------------------------------- /InsightFace/version: -------------------------------------------------------------------------------- 1 | find any person add to image on after mean on all of image embs 2 | verify 1 frame in 2 frame camera show 3 | inanity top_k and keep_top_k in RetinaFace -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Morphling - Text to human pipeline. 2 | [![](http://img.youtube.com/vi/VxrtbWqwyUk/0.jpg)](http://www.youtube.com/watch?v=VxrtbWqwyUk "Creation of this tool") 3 | 4 | A Machine Learning pipeline that goes from text to fully animated, voiced faces! 5 | 6 | # Supporting this project 7 | If you'd like to support my work, please consider the following: 8 | 9 | If you appreciate my work, please send crypto donations in Eth, Bnb, Matic, Avax, etc. on any chain to: 0xe0a09b49721FBD8B23c67a3a9fdE44be4412B8fD 10 | 11 | # Video Tool Use Guide 12 | 13 | TODO 14 | 15 | # Requirements 16 | 17 | ## Main Requirements 18 | CUDA-enabled GPU 19 | 20 | Python 3+ 21 | 22 | FFMPEG (ensure it is on the PATH) 23 | 24 | ## PIP Modules 25 | ```bash 26 | pip3 install -r requirements.txt 27 | ``` 28 | 29 | ## Usage 30 | ```bash 31 | python3 interface.py 32 | ``` 33 | # Downloads 34 | A complete downloadable version of this project can be found on Google Drive: https://drive.google.com/drive/folders/1uNL7wzCTbG7opHDvW61ly_N8Oa4LTSxy?usp=sharing 35 | 36 | ## Models 37 | 38 | ### Drive download 39 | All models can be found on Google Drive.
They should be placed in the models folder: https://drive.google.com/drive/u/2/folders/1LVEHsGlU5yipw6boRv1ocYfWcPaCi1ox 40 | 41 | ### Separate Downloads 42 | These should be placed in the 'models' folder: 43 | Vox-adv-cpk.pth.tar OR vox-cpk.pth.tar from: https://drive.google.com/drive/folders/1PyQJmkdCsAkOYwUyaj_l-l0as-iLDgeH or https://yadi.sk/d/lEw8uRm140L_eQ 44 | Alternatively, you can use the Avatarify version from https://www.dropbox.com/s/t7h24l6wx9vreto/vox-adv-cpk.pth.tar?dl=0 or https://yadi.sk/d/M0FWpz2ExBfgAA or https://drive.google.com/file/d/1coUCdyRXDbpWnEkA99NLNY60mb9dQ_n3/view 45 | libritts - https://drive.google.com/uc?id=1KhJcPawFgmfvwV7tQAOeC253rYstLrs8 46 | waveglow - https://drive.google.com/uc?id=1rpK8CzAAirq9sWZhe9nlfvxMF1dRgFbF 47 | lipgan - https://drive.google.com/uc?id=1DtXY5Ei_V6QjrLwfe7YDrmbSCDu6iru1 48 | face detector: http://dlib.net/files/mmod_human_face_detector.dat.bz2 49 | 50 | ## This work has adapted the following notebooks: 51 | 1) Flowtron - https://colab.research.google.com/github/tugstugi/dl-colab-notebooks/blob/master/notebooks/NVidia_Flowtron_Waveglow.ipynb 52 | 2) LipGan - https://colab.research.google.com/github/tugstugi/dl-colab-notebooks/blob/master/notebooks/LipGAN.ipynb#scrollTo=ktXeABjLYb70 53 | 3) First order motion - https://colab.research.google.com/github/AliaksandrSiarohin/first-order-model/blob/master/demo.ipynb 54 | 55 | ## This project makes use of several technologies 56 | Each of these uses a different license. Be sure to check them out for more info. 57 | 1) Flowtron/Waveglow - https://github.com/NVIDIA/flowtron 58 | 2) LipGan - https://github.com/Rudrabha/LipGAN 59 | 3) First Order Motion - https://github.com/AliaksandrSiarohin/first-order-model 60 | 4) ISR - https://idealo.github.io/image-super-resolution/ 61 | 5) Stylegan2-Pytorch - https://github.com/Tetratrio/stylegan2_pytorch 62 | 63 | # License 64 | This repo is a melting pot of different licenses, including GPL and MIT licenses. 65 | 66 | My own contributions are under the 'WTF Public License' 67 | 68 | 69 | ``` 70 | DO WHAT THE F*CK YOU WANT TO PUBLIC LICENCE 71 | Version 3.1, July 2019 72 | 73 | by Sam Hocevar 74 | theiostream 75 | dtf 76 | 77 | DO WHAT THE F*CK YOU WANT TO PUBLIC LICENCE 78 | TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION 79 | 80 | 0. You just DO WHAT THE F*CK YOU WANT TO.
81 | 82 | 83 | ``` 84 | 85 | 86 | 87 | 88 | -------------------------------------------------------------------------------- /Retinaface/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/singularitai/Morphling/e7a3af969123c0d3c0f3c6f1036a97e9be0b289c/Retinaface/__init__.py -------------------------------------------------------------------------------- /Retinaface/from_camera.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from time import time 3 | from Retinaface.Retinaface import FaceDetector 4 | 5 | detector = FaceDetector(name='mobilenet', weight_path='Retinaface/weights/mobilenet.pth', device='cuda', face_size=(224, 224)) 6 | 7 | cap = cv2.VideoCapture(0) 8 | cap.set(3, 1280) 9 | cap.set(4, 720) 10 | while True: 11 | _, frame = cap.read() 12 | 13 | tic = time() 14 | # boxes, scores, landmarks = detector.detect_faces(frame) 15 | faces, boxes, scores, landmarks = detector.detect_align(frame) 16 | print('forward time: ', time() - tic) 17 | if len(faces.shape) > 1: 18 | for i, f in enumerate(faces.cpu().numpy()): 19 | cv2.imshow(f'align_{i}', f) 20 | 21 | for b in boxes: 22 | cv2.rectangle(frame, (int(b[0]), int(b[1])), (int(b[2]), int(b[3])), (255, 0, 0), 3) 23 | for p in landmarks: 24 | for i in range(5): 25 | cv2.circle(frame, (p[i][0], p[i][1]), 3, (0, 255, 0), -1) 26 | 27 | cv2.imshow('frame', frame) 28 | if cv2.waitKey(1) & 0xFF == ord('q'): 29 | break 30 | 31 | cap.release() 32 | cv2.destroyAllWindows() 33 | -------------------------------------------------------------------------------- /Retinaface/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/singularitai/Morphling/e7a3af969123c0d3c0f3c6f1036a97e9be0b289c/Retinaface/models/__init__.py -------------------------------------------------------------------------------- /Retinaface/models/net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def conv_bn(inp, oup, stride=1, leaky=0): 7 | return nn.Sequential( 8 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 9 | nn.BatchNorm2d(oup), 10 | nn.LeakyReLU(negative_slope=leaky, inplace=True) 11 | ) 12 | 13 | 14 | def conv_bn_no_relu(inp, oup, stride): 15 | return nn.Sequential( 16 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 17 | nn.BatchNorm2d(oup), 18 | ) 19 | 20 | 21 | def conv_bn1X1(inp, oup, stride, leaky=0): 22 | return nn.Sequential( 23 | nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), 24 | nn.BatchNorm2d(oup), 25 | nn.LeakyReLU(negative_slope=leaky, inplace=True) 26 | ) 27 | 28 | 29 | def conv_dw(inp, oup, stride, leaky=0.1): 30 | return nn.Sequential( 31 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), 32 | nn.BatchNorm2d(inp), 33 | nn.LeakyReLU(negative_slope=leaky, inplace=True), 34 | 35 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 36 | nn.BatchNorm2d(oup), 37 | nn.LeakyReLU(negative_slope=leaky, inplace=True), 38 | ) 39 | 40 | 41 | class SSH(nn.Module): 42 | def __init__(self, in_channel, out_channel): 43 | super(SSH, self).__init__() 44 | assert out_channel % 4 == 0 45 | leaky = 0 46 | if (out_channel <= 64): 47 | leaky = 0.1 48 | self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1) 49 | 50 | self.conv5X5_1 = conv_bn(in_channel, out_channel // 4, stride=1, leaky=leaky) 51 | self.conv5X5_2 = 
conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1) 52 | 53 | self.conv7X7_2 = conv_bn(out_channel // 4, out_channel // 4, stride=1, leaky=leaky) 54 | self.conv7x7_3 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1) 55 | 56 | def forward(self, input): 57 | conv3X3 = self.conv3X3(input) 58 | 59 | conv5X5_1 = self.conv5X5_1(input) 60 | conv5X5 = self.conv5X5_2(conv5X5_1) 61 | 62 | conv7X7_2 = self.conv7X7_2(conv5X5_1) 63 | conv7X7 = self.conv7x7_3(conv7X7_2) 64 | 65 | out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1) 66 | out = F.relu(out) 67 | return out 68 | 69 | 70 | class FPN(nn.Module): 71 | def __init__(self, in_channels_list, out_channels): 72 | super(FPN, self).__init__() 73 | leaky = 0 74 | if (out_channels <= 64): 75 | leaky = 0.1 76 | self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride=1, leaky=leaky) 77 | self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride=1, leaky=leaky) 78 | self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride=1, leaky=leaky) 79 | 80 | self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky) 81 | self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky) 82 | 83 | def forward(self, input): 84 | # names = list(input.keys()) 85 | input = list(input.values()) 86 | 87 | output1 = self.output1(input[0]) 88 | output2 = self.output2(input[1]) 89 | output3 = self.output3(input[2]) 90 | 91 | up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode="nearest") 92 | output2 = output2 + up3 93 | output2 = self.merge2(output2) 94 | 95 | up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode="nearest") 96 | output1 = output1 + up2 97 | output1 = self.merge1(output1) 98 | 99 | out = [output1, output2, output3] 100 | return out 101 | 102 | 103 | class MobileNetV1(nn.Module): 104 | def __init__(self): 105 | super(MobileNetV1, self).__init__() 106 | self.stage1 = nn.Sequential( 107 | conv_bn(3, 8, 2, leaky=0.1), # 3 108 | conv_dw(8, 16, 1), # 7 109 | conv_dw(16, 32, 2), # 11 110 | conv_dw(32, 32, 1), # 19 111 | conv_dw(32, 64, 2), # 27 112 | conv_dw(64, 64, 1), # 43 113 | ) 114 | self.stage2 = nn.Sequential( 115 | conv_dw(64, 128, 2), # 43 + 16 = 59 116 | conv_dw(128, 128, 1), # 59 + 32 = 91 117 | conv_dw(128, 128, 1), # 91 + 32 = 123 118 | conv_dw(128, 128, 1), # 123 + 32 = 155 119 | conv_dw(128, 128, 1), # 155 + 32 = 187 120 | conv_dw(128, 128, 1), # 187 + 32 = 219 121 | ) 122 | self.stage3 = nn.Sequential( 123 | conv_dw(128, 256, 2), # 219 +3 2 = 241 124 | conv_dw(256, 256, 1), # 241 + 64 = 301 125 | ) 126 | self.avg = nn.AdaptiveAvgPool2d((1, 1)) 127 | self.fc = nn.Linear(256, 1000) 128 | 129 | def forward(self, x): 130 | x = self.stage1(x) 131 | x = self.stage2(x) 132 | x = self.stage3(x) 133 | x = self.avg(x) 134 | # x = self.model(x) 135 | x = x.view(-1, 256) 136 | x = self.fc(x) 137 | return x 138 | -------------------------------------------------------------------------------- /Retinaface/models/retinaface.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision.models._utils as _utils 5 | 6 | from ..models.net import FPN as FPN 7 | from ..models.net import MobileNetV1 as MobileNetV1 8 | from ..models.net import SSH as SSH 9 | 10 | 11 | class ClassHead(nn.Module): 12 | def __init__(self, inchannels=512, num_anchors=3): 13 | super(ClassHead, self).__init__() 14 | self.num_anchors = num_anchors 15 | self.conv1x1 = 
nn.Conv2d(inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0) 16 | 17 | def forward(self, x): 18 | out = self.conv1x1(x) 19 | out = out.permute(0, 2, 3, 1).contiguous() 20 | 21 | return out.view(out.shape[0], -1, 2) 22 | 23 | 24 | class BboxHead(nn.Module): 25 | def __init__(self, inchannels=512, num_anchors=3): 26 | super(BboxHead, self).__init__() 27 | self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0) 28 | 29 | def forward(self, x): 30 | out = self.conv1x1(x) 31 | out = out.permute(0, 2, 3, 1).contiguous() 32 | 33 | return out.view(out.shape[0], -1, 4) 34 | 35 | 36 | class LandmarkHead(nn.Module): 37 | def __init__(self, inchannels=512, num_anchors=3): 38 | super(LandmarkHead, self).__init__() 39 | self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0) 40 | 41 | def forward(self, x): 42 | out = self.conv1x1(x) 43 | out = out.permute(0, 2, 3, 1).contiguous() 44 | 45 | return out.view(out.shape[0], -1, 10) 46 | 47 | 48 | class RetinaFace(nn.Module): 49 | def __init__(self, cfg=None, phase='train'): 50 | """ 51 | :param cfg: Network related settings. 52 | :param phase: train or test. 53 | """ 54 | super(RetinaFace, self).__init__() 55 | self.phase = phase 56 | # backbone = MobileNetV1() 57 | if cfg['name'] == 'mobilenet0.25': 58 | backbone = MobileNetV1() 59 | elif cfg['name'] == 'Resnet50': 60 | import torchvision.models as models 61 | backbone = models.resnet50(pretrained=cfg['pretrain']) 62 | 63 | self.body = _utils.IntermediateLayerGetter(backbone, cfg['return_layers']) 64 | in_channels_stage2 = cfg['in_channel'] 65 | in_channels_list = [ 66 | in_channels_stage2 * 2, 67 | in_channels_stage2 * 4, 68 | in_channels_stage2 * 8, 69 | ] 70 | out_channels = cfg['out_channel'] 71 | self.fpn = FPN(in_channels_list, out_channels) 72 | self.ssh1 = SSH(out_channels, out_channels) 73 | self.ssh2 = SSH(out_channels, out_channels) 74 | self.ssh3 = SSH(out_channels, out_channels) 75 | 76 | self.ClassHead = self._make_class_head(fpn_num=3, inchannels=cfg['out_channel']) 77 | self.BboxHead = self._make_bbox_head(fpn_num=3, inchannels=cfg['out_channel']) 78 | self.LandmarkHead = self._make_landmark_head(fpn_num=3, inchannels=cfg['out_channel']) 79 | 80 | def _make_class_head(self, fpn_num=3, inchannels=64, anchor_num=2): 81 | classhead = nn.ModuleList() 82 | for i in range(fpn_num): 83 | classhead.append(ClassHead(inchannels, anchor_num)) 84 | return classhead 85 | 86 | def _make_bbox_head(self, fpn_num=3, inchannels=64, anchor_num=2): 87 | bboxhead = nn.ModuleList() 88 | for i in range(fpn_num): 89 | bboxhead.append(BboxHead(inchannels, anchor_num)) 90 | return bboxhead 91 | 92 | def _make_landmark_head(self, fpn_num=3, inchannels=64, anchor_num=2): 93 | landmarkhead = nn.ModuleList() 94 | for i in range(fpn_num): 95 | landmarkhead.append(LandmarkHead(inchannels, anchor_num)) 96 | return landmarkhead 97 | 98 | def forward(self, inputs): 99 | out = self.body(inputs) 100 | 101 | # FPN 102 | fpn = self.fpn(out) 103 | 104 | # SSH 105 | feature1 = self.ssh1(fpn[0]) 106 | feature2 = self.ssh2(fpn[1]) 107 | feature3 = self.ssh3(fpn[2]) 108 | features = [feature1, feature2, feature3] 109 | 110 | bbox_regressions = torch.cat([self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1) 111 | classifications = torch.cat([self.ClassHead[i](feature) for i, feature in enumerate(features)], dim=1) 112 | ldm_regressions = torch.cat([self.LandmarkHead[i](feature) for i, feature in 
enumerate(features)], dim=1) 113 | 114 | if self.phase == 'train': 115 | output = (bbox_regressions, classifications, ldm_regressions) 116 | else: 117 | output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions) 118 | return output 119 | -------------------------------------------------------------------------------- /Retinaface/models/slim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models.detection.backbone_utils as backbone_utils 4 | import torchvision.models._utils as _utils 5 | import torch.nn.functional as F 6 | from collections import OrderedDict 7 | 8 | def conv_bn(inp, oup, stride = 1): 9 | return nn.Sequential( 10 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 11 | nn.BatchNorm2d(oup), 12 | nn.ReLU(inplace=True) 13 | ) 14 | 15 | def depth_conv2d(inp, oup, kernel=1, stride=1, pad=0): 16 | return nn.Sequential( 17 | nn.Conv2d(inp, inp, kernel_size = kernel, stride = stride, padding=pad, groups=inp), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(inp, oup, kernel_size=1) 20 | ) 21 | 22 | def conv_dw(inp, oup, stride): 23 | return nn.Sequential( 24 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), 25 | nn.BatchNorm2d(inp), 26 | nn.ReLU(inplace=True), 27 | 28 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 29 | nn.BatchNorm2d(oup), 30 | nn.ReLU(inplace=True) 31 | ) 32 | 33 | class Slim(nn.Module): 34 | def __init__(self, cfg = None, phase = 'train'): 35 | """ 36 | :param cfg: Network related settings. 37 | :param phase: train or test. 38 | """ 39 | super(Slim, self).__init__() 40 | self.phase = phase 41 | self.num_classes = 2 42 | 43 | self.conv1 = conv_bn(3, 16, 2) 44 | self.conv2 = conv_dw(16, 32, 1) 45 | self.conv3 = conv_dw(32, 32, 2) 46 | self.conv4 = conv_dw(32, 32, 1) 47 | self.conv5 = conv_dw(32, 64, 2) 48 | self.conv6 = conv_dw(64, 64, 1) 49 | self.conv7 = conv_dw(64, 64, 1) 50 | self.conv8 = conv_dw(64, 64, 1) 51 | 52 | self.conv9 = conv_dw(64, 128, 2) 53 | self.conv10 = conv_dw(128, 128, 1) 54 | self.conv11 = conv_dw(128, 128, 1) 55 | 56 | self.conv12 = conv_dw(128, 256, 2) 57 | self.conv13 = conv_dw(256, 256, 1) 58 | 59 | self.conv14 = nn.Sequential( 60 | nn.Conv2d(in_channels=256, out_channels=64, kernel_size=1), 61 | nn.ReLU(inplace=True), 62 | depth_conv2d(64, 256, kernel=3, stride=2, pad=1), 63 | nn.ReLU(inplace=True) 64 | ) 65 | self.loc, self.conf, self.landm = self.multibox(self.num_classes); 66 | 67 | def multibox(self, num_classes): 68 | loc_layers = [] 69 | conf_layers = [] 70 | landm_layers = [] 71 | loc_layers += [depth_conv2d(64, 3 * 4, kernel=3, pad=1)] 72 | conf_layers += [depth_conv2d(64, 3 * num_classes, kernel=3, pad=1)] 73 | landm_layers += [depth_conv2d(64, 3 * 10, kernel=3, pad=1)] 74 | 75 | loc_layers += [depth_conv2d(128, 2 * 4, kernel=3, pad=1)] 76 | conf_layers += [depth_conv2d(128, 2 * num_classes, kernel=3, pad=1)] 77 | landm_layers += [depth_conv2d(128, 2 * 10, kernel=3, pad=1)] 78 | 79 | loc_layers += [depth_conv2d(256, 2 * 4, kernel=3, pad=1)] 80 | conf_layers += [depth_conv2d(256, 2 * num_classes, kernel=3, pad=1)] 81 | landm_layers += [depth_conv2d(256, 2 * 10, kernel=3, pad=1)] 82 | 83 | loc_layers += [nn.Conv2d(256, 3 * 4, kernel_size=3, padding=1)] 84 | conf_layers += [nn.Conv2d(256, 3 * num_classes, kernel_size=3, padding=1)] 85 | landm_layers += [nn.Conv2d(256, 3 * 10, kernel_size=3, padding=1)] 86 | return nn.Sequential(*loc_layers), nn.Sequential(*conf_layers), nn.Sequential(*landm_layers) 87 | 88 | 89 
| def forward(self,inputs): 90 | detections = list() 91 | loc = list() 92 | conf = list() 93 | landm = list() 94 | 95 | x1 = self.conv1(inputs) 96 | x2 = self.conv2(x1) 97 | x3 = self.conv3(x2) 98 | x4 = self.conv4(x3) 99 | x5 = self.conv5(x4) 100 | x6 = self.conv6(x5) 101 | x7 = self.conv7(x6) 102 | x8 = self.conv8(x7) 103 | detections.append(x8) 104 | 105 | x9 = self.conv9(x8) 106 | x10 = self.conv10(x9) 107 | x11 = self.conv11(x10) 108 | detections.append(x11) 109 | 110 | x12 = self.conv12(x11) 111 | x13 = self.conv13(x12) 112 | detections.append(x13) 113 | 114 | x14= self.conv14(x13) 115 | detections.append(x14) 116 | 117 | for (x, l, c, lam) in zip(detections, self.loc, self.conf, self.landm): 118 | loc.append(l(x).permute(0, 2, 3, 1).contiguous()) 119 | conf.append(c(x).permute(0, 2, 3, 1).contiguous()) 120 | landm.append(lam(x).permute(0, 2, 3, 1).contiguous()) 121 | 122 | bbox_regressions = torch.cat([o.view(o.size(0), -1, 4) for o in loc], 1) 123 | classifications = torch.cat([o.view(o.size(0), -1, 2) for o in conf], 1) 124 | ldm_regressions = torch.cat([o.view(o.size(0), -1, 10) for o in landm], 1) 125 | 126 | 127 | 128 | if self.phase == 'train': 129 | output = (bbox_regressions, classifications, ldm_regressions) 130 | else: 131 | output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions) 132 | return output 133 | -------------------------------------------------------------------------------- /Retinaface/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/singularitai/Morphling/e7a3af969123c0d3c0f3c6f1036a97e9be0b289c/Retinaface/utils/__init__.py -------------------------------------------------------------------------------- /Retinaface/utils/alignment.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Apr 1 15:43:29 2020 4 | @author: Sajjad Ayobbi 5 | """ 6 | import cv2 7 | import numpy as np 8 | from skimage import transform as trans 9 | 10 | # reference facial points, a list of coordinates (x,y) 11 | # REFERENCE_FACIAL_POINTS = [ 12 | # [30.29459953, 51.69630051], 13 | # [65.53179932, 51.50139999], 14 | # [48.02519989, 71.73660278], 15 | # [33.54930115, 92.3655014], 16 | # [62.72990036, 92.20410156] 17 | # ] 18 | 19 | REFERENCE_FACIAL_POINTS = [ 20 | [30.29459953, 51.69630051], 21 | [65.53179932, 51.50139999], 22 | [48.02519989, 71.73660278], 23 | [33.54930115, 87], 24 | [62.72990036, 87] 25 | ] 26 | 27 | DEFAULT_CROP_SIZE = (96, 112) 28 | 29 | 30 | class FaceWarpException(Exception): 31 | def __str__(self): 32 | return 'In File {}:{}'.format( 33 | __file__, super.__str__(self)) 34 | 35 | 36 | def get_reference_facial_points(output_size=(112, 112)): 37 | 38 | tmp_5pts = np.array(REFERENCE_FACIAL_POINTS) 39 | tmp_crop_size = np.array(DEFAULT_CROP_SIZE) 40 | 41 | # size_diff = max(tmp_crop_size) - tmp_crop_size 42 | # tmp_5pts += size_diff / 2 43 | # tmp_crop_size += size_diff 44 | # return tmp_5pts 45 | 46 | x_scale = output_size[0]/tmp_crop_size[0] 47 | y_scale = output_size[1]/tmp_crop_size[1] 48 | tmp_5pts[:, 0] *= x_scale 49 | tmp_5pts[:, 1] *= y_scale 50 | 51 | return tmp_5pts 52 | 53 | -------------------------------------------------------------------------------- /Retinaface/utils/box_utils.py: -------------------------------------------------------------------------------- 1 | from itertools import product as product 2 | from math import ceil 3 | import torch 4 | 5 | 6 | # 
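# --- Editor's sketch (assumption; not part of the original file) ---
# Typical composition of the helpers defined below at inference time, where
# loc/conf/landm are the raw detector outputs for a single image:
#
#   priors = prior_box(cfg_mnet, image_size=(im_h, im_w))          # [num_priors, 4]
#   boxes = decode(loc, priors, cfg_mnet['variance'])              # normalized x1, y1, x2, y2
#   landms = decode_landmark(landm, priors, cfg_mnet['variance'])  # 5 normalized (x, y) points
#   # ...scale boxes/landms back to pixel coordinates, then suppress overlaps:
#   keep = nms(boxes, conf[:, 1], thresh=0.4)                      # indices of kept detections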
Original author: Francisco Massa: 7 | # https://github.com/fmassa/object-detection.torch 8 | # Ported to PyTorch by Max deGroot (02/01/2017) 9 | 10 | def prior_box(cfg, image_size=None): 11 | steps = cfg['steps'] 12 | feature_maps = [[ceil(image_size[0] / step), ceil(image_size[1] / step)] for step in steps] 13 | min_sizes_ = cfg['min_sizes'] 14 | anchors = [] 15 | 16 | for k, f in enumerate(feature_maps): 17 | min_sizes = min_sizes_[k] 18 | for i, j in product(range(f[0]), range(f[1])): 19 | for min_size in min_sizes: 20 | s_kx = min_size / image_size[1] 21 | s_ky = min_size / image_size[0] 22 | dense_cx = [x * steps[k] / image_size[1] for x in [j + 0.5]] 23 | dense_cy = [y * steps[k] / image_size[0] for y in [i + 0.5]] 24 | for cy, cx in product(dense_cy, dense_cx): 25 | anchors += [cx, cy, s_kx, s_ky] 26 | 27 | # back to torch land 28 | output = torch.Tensor(anchors).view(-1, 4) 29 | return output 30 | 31 | 32 | # Adapted from https://github.com/Hakuyume/chainer-ssd 33 | def decode(loc, priors, variances): 34 | """Decode locations from predictions using priors to undo 35 | the encoding we did for offset regression at train time. 36 | Args: 37 | loc (tensor): location predictions for loc layers, 38 | Shape: [num_priors,4] 39 | priors (tensor): Prior boxes in center-offset form. 40 | Shape: [num_priors,4]. 41 | variances: (list[float]) Variances of priorboxes 42 | Return: 43 | decoded bounding box predictions 44 | """ 45 | 46 | boxes = torch.cat(( 47 | priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], 48 | priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) 49 | boxes[:, :2] -= boxes[:, 2:] / 2 50 | boxes[:, 2:] += boxes[:, :2] 51 | return boxes 52 | 53 | 54 | def decode_landmark(pre, priors, variances): 55 | """Decode landm from predictions using priors to undo 56 | the encoding we did for offset regression at train time. 57 | Args: 58 | pre (tensor): landm predictions for loc layers, 59 | Shape: [num_priors,10] 60 | priors (tensor): Prior boxes in center-offset form. 61 | Shape: [num_priors,4]. 
62 | variances: (list[float]) Variances of priorboxes 63 | Return: 64 | decoded landm predictions 65 | """ 66 | landms = torch.cat((priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:], 67 | priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:], 68 | priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:], 69 | priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:], 70 | priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:], 71 | ), dim=1) 72 | return landms 73 | 74 | 75 | def nms(box, scores, thresh): 76 | x1 = box[:, 0] 77 | y1 = box[:, 1] 78 | x2 = box[:, 2] 79 | y2 = box[:, 3] 80 | zero = torch.tensor([0.0]).to(scores.device) 81 | 82 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 83 | order = scores.argsort(descending=True) 84 | 85 | keep = [] 86 | while order.shape[0] > 0: 87 | i = order[0] 88 | keep.append(i) 89 | xx1 = torch.max(x1[i], x1[order[1:]]) 90 | yy1 = torch.max(y1[i], y1[order[1:]]) 91 | xx2 = torch.min(x2[i], x2[order[1:]]) 92 | yy2 = torch.min(y2[i], y2[order[1:]]) 93 | 94 | w = torch.max(zero, xx2 - xx1 + 1) 95 | h = torch.max(zero, yy2 - yy1 + 1) 96 | inter = w * h 97 | ovr = inter / (areas[i] + areas[order[1:]] - inter) 98 | 99 | inds = torch.where(ovr <= thresh)[0] 100 | order = order[inds + 1] 101 | 102 | return keep 103 | -------------------------------------------------------------------------------- /Retinaface/utils/config.py: -------------------------------------------------------------------------------- 1 | # config.py 2 | 3 | cfg_mnet = { 4 | 'name': 'mobilenet0.25', 5 | 'min_sizes': [[16, 32], [64, 128], [256, 512]], 6 | 'steps': [8, 16, 32], 7 | 'variance': [0.1, 0.2], 8 | 'clip': False, 9 | 'loc_weight': 2.0, 10 | 'gpu_train': True, 11 | 'batch_size': 32, 12 | 'ngpu': 1, 13 | 'epoch': 250, 14 | 'decay1': 190, 15 | 'decay2': 220, 16 | 'image_size': 640, 17 | 'pretrain': True, 18 | 'return_layers': {'stage1': 1, 'stage2': 2, 'stage3': 3}, 19 | 'in_channel': 32, 20 | 'out_channel': 64 21 | } 22 | 23 | cfg_re50 = { 24 | 'name': 'Resnet50', 25 | 'min_sizes': [[16, 32], [64, 128], [256, 512]], 26 | 'steps': [8, 16, 32], 27 | 'variance': [0.1, 0.2], 28 | 'clip': False, 29 | 'loc_weight': 2.0, 30 | 'gpu_train': True, 31 | 'batch_size': 24, 32 | 'ngpu': 4, 33 | 'epoch': 100, 34 | 'decay1': 70, 35 | 'decay2': 90, 36 | 'image_size': 840, 37 | 'pretrain': False, 38 | 'return_layers': {'layer2': 1, 'layer3': 2, 'layer4': 3}, 39 | 'in_channel': 256, 40 | 'out_channel': 256 41 | } 42 | 43 | cfg_slim = { 44 | 'name': 'slim', 45 | 'min_sizes': [[10, 16, 24], [32, 48], [64, 96], [128, 192, 256]], 46 | 'steps': [8, 16, 32, 64], 47 | 'variance': [0.1, 0.2], 48 | 'clip': False, 49 | 'loc_weight': 2.0, 50 | 'gpu_train': True, 51 | 'batch_size': 32, 52 | 'ngpu': 1, 53 | 'epoch': 250, 54 | 'decay1': 190, 55 | 'decay2': 220, 56 | 'image_size': 300 57 | } 58 | 59 | cfg_rfb = { 60 | 'name': 'RFB', 61 | 'min_sizes': [[10, 16, 24], [32, 48], [64, 96], [128, 192, 256]], 62 | 'steps': [8, 16, 32, 64], 63 | 'variance': [0.1, 0.2], 64 | 'clip': False, 65 | 'loc_weight': 2.0, 66 | 'gpu_train': True, 67 | 'batch_size': 32, 68 | 'ngpu': 1, 69 | 'epoch': 250, 70 | 'decay1': 190, 71 | 'decay2': 220, 72 | 'image_size': 300 73 | } 74 | 75 | -------------------------------------------------------------------------------- /animate.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | 4 | import torch 5 | from torch.utils.data import DataLoader 6 | 7 | from frames_dataset import PairedDataset 8 | from 
logger_firstorder import Logger, Visualizer 9 | import imageio 10 | from scipy.spatial import ConvexHull 11 | import numpy as np 12 | 13 | from sync_batchnorm import DataParallelWithCallback 14 | 15 | 16 | def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False, 17 | use_relative_movement=False, use_relative_jacobian=False): 18 | if adapt_movement_scale: 19 | source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume 20 | driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume 21 | adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area) 22 | else: 23 | adapt_movement_scale = 1 24 | 25 | kp_new = {k: v for k, v in kp_driving.items()} 26 | 27 | if use_relative_movement: 28 | kp_value_diff = (kp_driving['value'] - kp_driving_initial['value']) 29 | kp_value_diff *= adapt_movement_scale 30 | kp_new['value'] = kp_value_diff + kp_source['value'] 31 | 32 | if use_relative_jacobian: 33 | jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian'])) 34 | kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian']) 35 | 36 | return kp_new 37 | 38 | 39 | def animate(config, generator, kp_detector, checkpoint, log_dir, dataset): 40 | log_dir = os.path.join(log_dir, 'animation') 41 | png_dir = os.path.join(log_dir, 'png') 42 | animate_params = config['animate_params'] 43 | 44 | dataset = PairedDataset(initial_dataset=dataset, number_of_pairs=animate_params['num_pairs']) 45 | dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1) 46 | 47 | if checkpoint is not None: 48 | Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector) 49 | else: 50 | raise AttributeError("Checkpoint should be specified for mode='animate'.") 51 | 52 | if not os.path.exists(log_dir): 53 | os.makedirs(log_dir) 54 | 55 | if not os.path.exists(png_dir): 56 | os.makedirs(png_dir) 57 | 58 | if torch.cuda.is_available(): 59 | generator = DataParallelWithCallback(generator) 60 | kp_detector = DataParallelWithCallback(kp_detector) 61 | 62 | generator.eval() 63 | kp_detector.eval() 64 | 65 | for it, x in tqdm(enumerate(dataloader)): 66 | with torch.no_grad(): 67 | predictions = [] 68 | visualizations = [] 69 | 70 | driving_video = x['driving_video'] 71 | source_frame = x['source_video'][:, :, 0, :, :] 72 | 73 | kp_source = kp_detector(source_frame) 74 | kp_driving_initial = kp_detector(driving_video[:, :, 0]) 75 | 76 | for frame_idx in range(driving_video.shape[2]): 77 | driving_frame = driving_video[:, :, frame_idx] 78 | kp_driving = kp_detector(driving_frame) 79 | kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving, 80 | kp_driving_initial=kp_driving_initial, **animate_params['normalization_params']) 81 | out = generator(source_frame, kp_source=kp_source, kp_driving=kp_norm) 82 | 83 | out['kp_driving'] = kp_driving 84 | out['kp_source'] = kp_source 85 | out['kp_norm'] = kp_norm 86 | 87 | del out['sparse_deformed'] 88 | 89 | predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0]) 90 | 91 | visualization = Visualizer(**config['visualizer_params']).visualize(source=source_frame, 92 | driving=driving_frame, out=out) 93 | visualization = visualization 94 | visualizations.append(visualization) 95 | 96 | predictions = np.concatenate(predictions, axis=1) 97 | result_name = "-".join([x['driving_name'][0], x['source_name'][0]]) 98 | imageio.imsave(os.path.join(png_dir, result_name + '.png'), (255 * predictions).astype(np.uint8)) 99 
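            # Write the visualization frames collected above out as one animated clip; the
            # extension comes from animate_params['format'] ('.mp4' in the provided configs).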
| 100 | image_name = result_name + animate_params['format'] 101 | imageio.mimsave(os.path.join(log_dir, image_name), visualizations) 102 | -------------------------------------------------------------------------------- /application.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sun Jul 12 06:14:38 2020 4 | 5 | @author: Hephyrius 6 | """ 7 | from interface import Ui_MainWindow 8 | 9 | Ui_MainWindow() -------------------------------------------------------------------------------- /audio_hparams.py: -------------------------------------------------------------------------------- 1 | from tensorflow.contrib.training import HParams  # NOTE: assumes a TensorFlow 1.x environment; this HParams container matches the keyword-argument usage below 2 | from glob import glob 3 | import os, pickle 4 | 5 | # Default hyperparameters 6 | hparams = HParams( 7 | num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality 8 | # network 9 | rescale=True, # Whether to rescale audio prior to preprocessing 10 | rescaling_max=0.9, # Rescaling value 11 | 12 | # For cases of OOM (Not really recommended, only use if facing unsolvable OOM errors, 13 | # also consider clipping your samples to smaller chunks) 14 | max_mel_frames=900, 15 | # Only relevant when clip_mels_length = True, please only use after trying output_per_steps=3 16 | # and still getting OOM errors. 17 | 18 | # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction 19 | # It's preferred to set True to use with https://github.com/r9y9/wavenet_vocoder 20 | # Does not work if n_fft is not a multiple of hop_size!! 21 | use_lws=False, 22 | 23 | n_fft=800, # Extra window size is filled with 0 paddings to match this parameter 24 | hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate) 25 | win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate) 26 | sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i <filename>) 27 | 28 | frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5) 29 | 30 | # Mel and Linear spectrograms normalization/scaling and clipping 31 | signal_normalization=True, 32 | # Whether to normalize mel spectrograms to some predefined range (following below parameters) 33 | allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True 34 | symmetric_mels=True, 35 | # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2, 36 | # faster and cleaner convergence) 37 | max_abs_value=4., 38 | # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not 39 | # be too big to avoid gradient explosion, 40 | # not too small for fast convergence) 41 | normalize_for_wavenet=True, 42 | # whether to rescale to [0, 1] for wavenet. (better audio quality) 43 | clip_for_wavenet=True, 44 | # whether to clip [-max, max] before training/synthesizing with wavenet (better audio quality) 45 | 46 | # Contribution by @begeekmyfriend 47 | # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude 48 | # levels. Also allows for better G&L phase reconstruction) 49 | preemphasize=True, # whether to apply filter 50 | preemphasis=0.97, # filter coefficient. 51 | 52 | # Limits 53 | min_level_db=-100, 54 | ref_level_db=20, 55 | fmin=55, 56 | # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To 57 | # test depending on dataset.
Pitch info: male~[65, 260], female~[100, 525]) 58 | fmax=7600, # To be increased/reduced depending on data. 59 | 60 | # Griffin Lim 61 | power=1.5, 62 | # Only used in G&L inversion, usually values between 1.2 and 1.5 are a good choice. 63 | griffin_lim_iters=60, 64 | # Number of G&L iterations, typically 30 is enough but we use 60 to ensure convergence. 65 | ) 66 | 67 | 68 | def hparams_debug_string(): 69 | values = hparams.values() 70 | hp = [" %s: %s" % (name, values[name]) for name in sorted(values) if name != "sentences"] 71 | return "Hyperparameters:\n" + "\n".join(hp) 72 | -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_config": { 3 | "output_directory": "outdir", 4 | "epochs": 10000000, 5 | "learning_rate": 1e-4, 6 | "weight_decay": 1e-6, 7 | "sigma": 1.0, 8 | "iters_per_checkpoint": 5000, 9 | "batch_size": 1, 10 | "seed": 1234, 11 | "checkpoint_path": "", 12 | "ignore_layers": [], 13 | "include_layers": ["speaker", "encoder", "embedding"], 14 | "warmstart_checkpoint_path": "", 15 | "with_tensorboard": true, 16 | "fp16_run": false 17 | }, 18 | "data_config": { 19 | "training_files": "filelists/ljs_audiopaths_text_sid_train_filelist.txt", 20 | "validation_files": "filelists/ljs_audiopaths_text_sid_val_filelist.txt", 21 | "text_cleaners": ["flowtron_cleaners"], 22 | "p_arpabet": 0.5, 23 | "cmudict_path": "data/cmudict_dictionary", 24 | "sampling_rate": 22050, 25 | "filter_length": 1024, 26 | "hop_length": 256, 27 | "win_length": 1024, 28 | "mel_fmin": 0.0, 29 | "mel_fmax": 8000.0, 30 | "max_wav_value": 32768.0 31 | }, 32 | "dist_config": { 33 | "dist_backend": "nccl", 34 | "dist_url": "tcp://localhost:54321" 35 | }, 36 | 37 | "model_config": { 38 | "n_speakers": 1, 39 | "n_speaker_dim": 128, 40 | "n_text": 185, 41 | "n_text_dim": 512, 42 | "n_flows": 2, 43 | "n_mel_channels": 80, 44 | "n_attn_channels": 640, 45 | "n_hidden": 1024, 46 | "n_lstm_layers": 2, 47 | "mel_encoder_n_hidden": 512, 48 | "n_components": 0, 49 | "mean_scale": 0.0, 50 | "fixed_gaussian": true, 51 | "dummy_speaker_embedding": false, 52 | "use_gate_layer": true 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /config/bair-256.yaml: -------------------------------------------------------------------------------- 1 | dataset_params: 2 | root_dir: data/bair 3 | frame_shape: [256, 256, 3] 4 | id_sampling: False 5 | augmentation_params: 6 | flip_param: 7 | horizontal_flip: True 8 | time_flip: True 9 | jitter_param: 10 | brightness: 0.1 11 | contrast: 0.1 12 | saturation: 0.1 13 | hue: 0.1 14 | 15 | 16 | model_params: 17 | common_params: 18 | num_kp: 10 19 | num_channels: 3 20 | estimate_jacobian: True 21 | kp_detector_params: 22 | temperature: 0.1 23 | block_expansion: 32 24 | max_features: 1024 25 | scale_factor: 0.25 26 | num_blocks: 5 27 | generator_params: 28 | block_expansion: 64 29 | max_features: 512 30 | num_down_blocks: 2 31 | num_bottleneck_blocks: 6 32 | estimate_occlusion_map: True 33 | dense_motion_params: 34 | block_expansion: 64 35 | max_features: 1024 36 | num_blocks: 5 37 | scale_factor: 0.25 38 | discriminator_params: 39 | scales: [1] 40 | block_expansion: 32 41 | max_features: 512 42 | num_blocks: 4 43 | sn: True 44 | 45 | train_params: 46 | num_epochs: 20 47 | num_repeats: 1 48 | epoch_milestones: [12, 18] 49 | lr_generator: 2.0e-4 50 | lr_discriminator: 2.0e-4 51 | lr_kp_detector: 2.0e-4 52 | batch_size: 36 
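  # Note (assumption, following the upstream first-order-model trainer): the scales below
  # are the image-pyramid resolutions at which the multi-scale perceptual loss is evaluated.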
53 | scales: [1, 0.5, 0.25, 0.125] 54 | checkpoint_freq: 10 55 | transform_params: 56 | sigma_affine: 0.05 57 | sigma_tps: 0.005 58 | points_tps: 5 59 | loss_weights: 60 | generator_gan: 1 61 | discriminator_gan: 1 62 | feature_matching: [10, 10, 10, 10] 63 | perceptual: [10, 10, 10, 10, 10] 64 | equivariance_value: 10 65 | equivariance_jacobian: 10 66 | 67 | reconstruction_params: 68 | num_videos: 1000 69 | format: '.mp4' 70 | 71 | animate_params: 72 | num_pairs: 50 73 | format: '.mp4' 74 | normalization_params: 75 | adapt_movement_scale: False 76 | use_relative_movement: True 77 | use_relative_jacobian: True 78 | 79 | visualizer_params: 80 | kp_size: 5 81 | draw_border: True 82 | colormap: 'gist_rainbow' 83 | -------------------------------------------------------------------------------- /config/fashion-256.yaml: -------------------------------------------------------------------------------- 1 | dataset_params: 2 | root_dir: data/fashion-png 3 | frame_shape: [256, 256, 3] 4 | id_sampling: False 5 | augmentation_params: 6 | flip_param: 7 | horizontal_flip: True 8 | time_flip: True 9 | jitter_param: 10 | hue: 0.1 11 | 12 | model_params: 13 | common_params: 14 | num_kp: 10 15 | num_channels: 3 16 | estimate_jacobian: True 17 | kp_detector_params: 18 | temperature: 0.1 19 | block_expansion: 32 20 | max_features: 1024 21 | scale_factor: 0.25 22 | num_blocks: 5 23 | generator_params: 24 | block_expansion: 64 25 | max_features: 512 26 | num_down_blocks: 2 27 | num_bottleneck_blocks: 6 28 | estimate_occlusion_map: True 29 | dense_motion_params: 30 | block_expansion: 64 31 | max_features: 1024 32 | num_blocks: 5 33 | scale_factor: 0.25 34 | discriminator_params: 35 | scales: [1] 36 | block_expansion: 32 37 | max_features: 512 38 | num_blocks: 4 39 | 40 | train_params: 41 | num_epochs: 100 42 | num_repeats: 50 43 | epoch_milestones: [60, 90] 44 | lr_generator: 2.0e-4 45 | lr_discriminator: 2.0e-4 46 | lr_kp_detector: 2.0e-4 47 | batch_size: 27 48 | scales: [1, 0.5, 0.25, 0.125] 49 | checkpoint_freq: 50 50 | transform_params: 51 | sigma_affine: 0.05 52 | sigma_tps: 0.005 53 | points_tps: 5 54 | loss_weights: 55 | generator_gan: 1 56 | discriminator_gan: 1 57 | feature_matching: [10, 10, 10, 10] 58 | perceptual: [10, 10, 10, 10, 10] 59 | equivariance_value: 10 60 | equivariance_jacobian: 10 61 | 62 | reconstruction_params: 63 | num_videos: 1000 64 | format: '.mp4' 65 | 66 | animate_params: 67 | num_pairs: 50 68 | format: '.mp4' 69 | normalization_params: 70 | adapt_movement_scale: False 71 | use_relative_movement: True 72 | use_relative_jacobian: True 73 | 74 | visualizer_params: 75 | kp_size: 5 76 | draw_border: True 77 | colormap: 'gist_rainbow' 78 | -------------------------------------------------------------------------------- /config/mgif-256.yaml: -------------------------------------------------------------------------------- 1 | dataset_params: 2 | root_dir: data/moving-gif 3 | frame_shape: [256, 256, 3] 4 | id_sampling: False 5 | augmentation_params: 6 | flip_param: 7 | horizontal_flip: True 8 | time_flip: True 9 | crop_param: 10 | size: [256, 256] 11 | resize_param: 12 | ratio: [0.9, 1.1] 13 | jitter_param: 14 | hue: 0.5 15 | 16 | model_params: 17 | common_params: 18 | num_kp: 10 19 | num_channels: 3 20 | estimate_jacobian: True 21 | kp_detector_params: 22 | temperature: 0.1 23 | block_expansion: 32 24 | max_features: 1024 25 | scale_factor: 0.25 26 | num_blocks: 5 27 | single_jacobian_map: True 28 | generator_params: 29 | block_expansion: 64 30 | max_features: 512 31 | num_down_blocks: 
2 32 | num_bottleneck_blocks: 6 33 | estimate_occlusion_map: True 34 | dense_motion_params: 35 | block_expansion: 64 36 | max_features: 1024 37 | num_blocks: 5 38 | scale_factor: 0.25 39 | discriminator_params: 40 | scales: [1] 41 | block_expansion: 32 42 | max_features: 512 43 | num_blocks: 4 44 | sn: True 45 | 46 | train_params: 47 | num_epochs: 100 48 | num_repeats: 25 49 | epoch_milestones: [60, 90] 50 | lr_generator: 2.0e-4 51 | lr_discriminator: 2.0e-4 52 | lr_kp_detector: 2.0e-4 53 | 54 | batch_size: 36 55 | scales: [1, 0.5, 0.25, 0.125] 56 | checkpoint_freq: 100 57 | transform_params: 58 | sigma_affine: 0.05 59 | sigma_tps: 0.005 60 | points_tps: 5 61 | loss_weights: 62 | generator_gan: 1 63 | discriminator_gan: 1 64 | feature_matching: [10, 10, 10, 10] 65 | perceptual: [10, 10, 10, 10, 10] 66 | equivariance_value: 10 67 | equivariance_jacobian: 10 68 | 69 | reconstruction_params: 70 | num_videos: 1000 71 | format: '.mp4' 72 | 73 | animate_params: 74 | num_pairs: 50 75 | format: '.mp4' 76 | normalization_params: 77 | adapt_movement_scale: False 78 | use_relative_movement: True 79 | use_relative_jacobian: True 80 | 81 | visualizer_params: 82 | kp_size: 5 83 | draw_border: True 84 | colormap: 'gist_rainbow' 85 | -------------------------------------------------------------------------------- /config/nemo-256.yaml: -------------------------------------------------------------------------------- 1 | dataset_params: 2 | root_dir: data/nemo-png 3 | frame_shape: [256, 256, 3] 4 | id_sampling: False 5 | augmentation_params: 6 | flip_param: 7 | horizontal_flip: True 8 | time_flip: True 9 | 10 | model_params: 11 | common_params: 12 | num_kp: 10 13 | num_channels: 3 14 | estimate_jacobian: True 15 | kp_detector_params: 16 | temperature: 0.1 17 | block_expansion: 32 18 | max_features: 1024 19 | scale_factor: 0.25 20 | num_blocks: 5 21 | generator_params: 22 | block_expansion: 64 23 | max_features: 512 24 | num_down_blocks: 2 25 | num_bottleneck_blocks: 6 26 | estimate_occlusion_map: True 27 | dense_motion_params: 28 | block_expansion: 64 29 | max_features: 1024 30 | num_blocks: 5 31 | scale_factor: 0.25 32 | discriminator_params: 33 | scales: [1] 34 | block_expansion: 32 35 | max_features: 512 36 | num_blocks: 4 37 | sn: True 38 | 39 | train_params: 40 | num_epochs: 100 41 | num_repeats: 8 42 | epoch_milestones: [60, 90] 43 | lr_generator: 2.0e-4 44 | lr_discriminator: 2.0e-4 45 | lr_kp_detector: 2.0e-4 46 | batch_size: 36 47 | scales: [1, 0.5, 0.25, 0.125] 48 | checkpoint_freq: 50 49 | transform_params: 50 | sigma_affine: 0.05 51 | sigma_tps: 0.005 52 | points_tps: 5 53 | loss_weights: 54 | generator_gan: 1 55 | discriminator_gan: 1 56 | feature_matching: [10, 10, 10, 10] 57 | perceptual: [10, 10, 10, 10, 10] 58 | equivariance_value: 10 59 | equivariance_jacobian: 10 60 | 61 | reconstruction_params: 62 | num_videos: 1000 63 | format: '.mp4' 64 | 65 | animate_params: 66 | num_pairs: 50 67 | format: '.mp4' 68 | normalization_params: 69 | adapt_movement_scale: False 70 | use_relative_movement: True 71 | use_relative_jacobian: True 72 | 73 | visualizer_params: 74 | kp_size: 5 75 | draw_border: True 76 | colormap: 'gist_rainbow' 77 | -------------------------------------------------------------------------------- /config/vox-256.yaml: -------------------------------------------------------------------------------- 1 | dataset_params: 2 | root_dir: data/vox-png 3 | frame_shape: [256, 256, 3] 4 | id_sampling: True 5 | pairs_list: data/vox256.csv 6 | augmentation_params: 7 | flip_param: 8 | 
horizontal_flip: True 9 | time_flip: True 10 | jitter_param: 11 | brightness: 0.1 12 | contrast: 0.1 13 | saturation: 0.1 14 | hue: 0.1 15 | 16 | 17 | model_params: 18 | common_params: 19 | num_kp: 10 20 | num_channels: 3 21 | estimate_jacobian: True 22 | kp_detector_params: 23 | temperature: 0.1 24 | block_expansion: 32 25 | max_features: 1024 26 | scale_factor: 0.25 27 | num_blocks: 5 28 | generator_params: 29 | block_expansion: 64 30 | max_features: 512 31 | num_down_blocks: 2 32 | num_bottleneck_blocks: 6 33 | estimate_occlusion_map: True 34 | dense_motion_params: 35 | block_expansion: 64 36 | max_features: 1024 37 | num_blocks: 5 38 | scale_factor: 0.25 39 | discriminator_params: 40 | scales: [1] 41 | block_expansion: 32 42 | max_features: 512 43 | num_blocks: 4 44 | sn: True 45 | 46 | train_params: 47 | num_epochs: 100 48 | num_repeats: 75 49 | epoch_milestones: [60, 90] 50 | lr_generator: 2.0e-4 51 | lr_discriminator: 2.0e-4 52 | lr_kp_detector: 2.0e-4 53 | batch_size: 40 54 | scales: [1, 0.5, 0.25, 0.125] 55 | checkpoint_freq: 50 56 | transform_params: 57 | sigma_affine: 0.05 58 | sigma_tps: 0.005 59 | points_tps: 5 60 | loss_weights: 61 | generator_gan: 0 62 | discriminator_gan: 1 63 | feature_matching: [10, 10, 10, 10] 64 | perceptual: [10, 10, 10, 10, 10] 65 | equivariance_value: 10 66 | equivariance_jacobian: 10 67 | 68 | reconstruction_params: 69 | num_videos: 1000 70 | format: '.mp4' 71 | 72 | animate_params: 73 | num_pairs: 50 74 | format: '.mp4' 75 | normalization_params: 76 | adapt_movement_scale: False 77 | use_relative_movement: True 78 | use_relative_jacobian: True 79 | 80 | visualizer_params: 81 | kp_size: 5 82 | draw_border: True 83 | colormap: 'gist_rainbow' 84 | -------------------------------------------------------------------------------- /config/vox-adv-256.yaml: -------------------------------------------------------------------------------- 1 | dataset_params: 2 | root_dir: data/vox-png 3 | frame_shape: [256, 256, 3] 4 | id_sampling: True 5 | pairs_list: data/vox256.csv 6 | augmentation_params: 7 | flip_param: 8 | horizontal_flip: True 9 | time_flip: True 10 | jitter_param: 11 | brightness: 0.1 12 | contrast: 0.1 13 | saturation: 0.1 14 | hue: 0.1 15 | 16 | 17 | model_params: 18 | common_params: 19 | num_kp: 10 20 | num_channels: 3 21 | estimate_jacobian: True 22 | kp_detector_params: 23 | temperature: 0.1 24 | block_expansion: 32 25 | max_features: 1024 26 | scale_factor: 0.25 27 | num_blocks: 5 28 | generator_params: 29 | block_expansion: 64 30 | max_features: 512 31 | num_down_blocks: 2 32 | num_bottleneck_blocks: 6 33 | estimate_occlusion_map: True 34 | dense_motion_params: 35 | block_expansion: 64 36 | max_features: 1024 37 | num_blocks: 5 38 | scale_factor: 0.25 39 | discriminator_params: 40 | scales: [1] 41 | block_expansion: 32 42 | max_features: 512 43 | num_blocks: 4 44 | use_kp: True 45 | 46 | 47 | train_params: 48 | num_epochs: 150 49 | num_repeats: 75 50 | epoch_milestones: [] 51 | lr_generator: 2.0e-4 52 | lr_discriminator: 2.0e-4 53 | lr_kp_detector: 2.0e-4 54 | batch_size: 36 55 | scales: [1, 0.5, 0.25, 0.125] 56 | checkpoint_freq: 50 57 | transform_params: 58 | sigma_affine: 0.05 59 | sigma_tps: 0.005 60 | points_tps: 5 61 | loss_weights: 62 | generator_gan: 1 63 | discriminator_gan: 1 64 | feature_matching: [10, 10, 10, 10] 65 | perceptual: [10, 10, 10, 10, 10] 66 | equivariance_value: 10 67 | equivariance_jacobian: 10 68 | 69 | reconstruction_params: 70 | num_videos: 1000 71 | format: '.mp4' 72 | 73 | animate_params: 74 | num_pairs: 50 75 | 
format: '.mp4' 76 | normalization_params: 77 | adapt_movement_scale: False 78 | use_relative_movement: True 79 | use_relative_jacobian: True 80 | 81 | visualizer_params: 82 | kp_size: 5 83 | draw_border: True 84 | colormap: 'gist_rainbow' 85 | -------------------------------------------------------------------------------- /convert_model.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import copy 3 | import torch 4 | 5 | def _check_model_old_version(model): 6 | if hasattr(model.WN[0], 'res_layers') or hasattr(model.WN[0], 'cond_layers'): 7 | return True 8 | else: 9 | return False 10 | 11 | 12 | def _update_model_res_skip(old_model, new_model): 13 | for idx in range(0, len(new_model.WN)): 14 | wavenet = new_model.WN[idx] 15 | n_channels = wavenet.n_channels 16 | n_layers = wavenet.n_layers 17 | wavenet.res_skip_layers = torch.nn.ModuleList() 18 | for i in range(0, n_layers): 19 | if i < n_layers - 1: 20 | res_skip_channels = 2*n_channels 21 | else: 22 | res_skip_channels = n_channels 23 | res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1) 24 | skip_layer = torch.nn.utils.remove_weight_norm(wavenet.skip_layers[i]) 25 | if i < n_layers - 1: 26 | res_layer = torch.nn.utils.remove_weight_norm(wavenet.res_layers[i]) 27 | res_skip_layer.weight = torch.nn.Parameter(torch.cat([res_layer.weight, skip_layer.weight])) 28 | res_skip_layer.bias = torch.nn.Parameter(torch.cat([res_layer.bias, skip_layer.bias])) 29 | else: 30 | res_skip_layer.weight = torch.nn.Parameter(skip_layer.weight) 31 | res_skip_layer.bias = torch.nn.Parameter(skip_layer.bias) 32 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') 33 | wavenet.res_skip_layers.append(res_skip_layer) 34 | del wavenet.res_layers 35 | del wavenet.skip_layers 36 | 37 | def _update_model_cond(old_model, new_model): 38 | for idx in range(0, len(new_model.WN)): 39 | wavenet = new_model.WN[idx] 40 | n_channels = wavenet.n_channels 41 | n_layers = wavenet.n_layers 42 | n_mel_channels = wavenet.cond_layers[0].weight.shape[1] 43 | cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels*n_layers, 1) 44 | cond_layer_weight = [] 45 | cond_layer_bias = [] 46 | for i in range(0, n_layers): 47 | _cond_layer = torch.nn.utils.remove_weight_norm(wavenet.cond_layers[i]) 48 | cond_layer_weight.append(_cond_layer.weight) 49 | cond_layer_bias.append(_cond_layer.bias) 50 | cond_layer.weight = torch.nn.Parameter(torch.cat(cond_layer_weight)) 51 | cond_layer.bias = torch.nn.Parameter(torch.cat(cond_layer_bias)) 52 | cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') 53 | wavenet.cond_layer = cond_layer 54 | del wavenet.cond_layers 55 | 56 | def update_model(old_model): 57 | if not _check_model_old_version(old_model): 58 | return old_model 59 | new_model = copy.deepcopy(old_model) 60 | if hasattr(old_model.WN[0], 'res_layers'): 61 | _update_model_res_skip(old_model, new_model) 62 | if hasattr(old_model.WN[0], 'cond_layers'): 63 | _update_model_cond(old_model, new_model) 64 | return new_model 65 | 66 | if __name__ == '__main__': 67 | old_model_path = sys.argv[1] 68 | new_model_path = sys.argv[2] 69 | model = torch.load(old_model_path, map_location='cpu') 70 | model['model'] = update_model(model['model']) 71 | torch.save(model, new_model_path) 72 | 73 | -------------------------------------------------------------------------------- /data/bair256.csv: -------------------------------------------------------------------------------- 1 | distance,source,driving,frame 2 
| 0,000054.mp4,000048.mp4,0 3 | 0,000050.mp4,000063.mp4,0 4 | 0,000073.mp4,000007.mp4,0 5 | 0,000021.mp4,000010.mp4,0 6 | 0,000084.mp4,000046.mp4,0 7 | 0,000031.mp4,000102.mp4,0 8 | 0,000029.mp4,000111.mp4,0 9 | 0,000090.mp4,000112.mp4,0 10 | 0,000039.mp4,000010.mp4,0 11 | 0,000008.mp4,000069.mp4,0 12 | 0,000068.mp4,000076.mp4,0 13 | 0,000051.mp4,000052.mp4,0 14 | 0,000022.mp4,000098.mp4,0 15 | 0,000096.mp4,000032.mp4,0 16 | 0,000032.mp4,000099.mp4,0 17 | 0,000006.mp4,000053.mp4,0 18 | 0,000098.mp4,000020.mp4,0 19 | 0,000029.mp4,000066.mp4,0 20 | 0,000022.mp4,000007.mp4,0 21 | 0,000027.mp4,000065.mp4,0 22 | 0,000026.mp4,000059.mp4,0 23 | 0,000015.mp4,000112.mp4,0 24 | 0,000086.mp4,000123.mp4,0 25 | 0,000103.mp4,000052.mp4,0 26 | 0,000123.mp4,000103.mp4,0 27 | 0,000051.mp4,000005.mp4,0 28 | 0,000062.mp4,000125.mp4,0 29 | 0,000126.mp4,000111.mp4,0 30 | 0,000066.mp4,000090.mp4,0 31 | 0,000075.mp4,000106.mp4,0 32 | 0,000020.mp4,000010.mp4,0 33 | 0,000076.mp4,000028.mp4,0 34 | 0,000062.mp4,000002.mp4,0 35 | 0,000095.mp4,000127.mp4,0 36 | 0,000113.mp4,000072.mp4,0 37 | 0,000027.mp4,000104.mp4,0 38 | 0,000054.mp4,000124.mp4,0 39 | 0,000019.mp4,000089.mp4,0 40 | 0,000052.mp4,000072.mp4,0 41 | 0,000108.mp4,000033.mp4,0 42 | 0,000044.mp4,000118.mp4,0 43 | 0,000029.mp4,000086.mp4,0 44 | 0,000068.mp4,000066.mp4,0 45 | 0,000014.mp4,000036.mp4,0 46 | 0,000053.mp4,000071.mp4,0 47 | 0,000022.mp4,000094.mp4,0 48 | 0,000000.mp4,000121.mp4,0 49 | 0,000071.mp4,000079.mp4,0 50 | 0,000127.mp4,000005.mp4,0 51 | 0,000085.mp4,000023.mp4,0 52 | -------------------------------------------------------------------------------- /data/cmudict_dictionary: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/singularitai/Morphling/e7a3af969123c0d3c0f3c6f1036a97e9be0b289c/data/cmudict_dictionary -------------------------------------------------------------------------------- /data/heteronyms: -------------------------------------------------------------------------------- 1 | abject 2 | abrogate 3 | absent 4 | abstract 5 | abuse 6 | ache 7 | Acre 8 | acuminate 9 | addict 10 | address 11 | adduct 12 | Adele 13 | advocate 14 | affect 15 | affiliate 16 | agape 17 | aged 18 | agglomerate 19 | aggregate 20 | agonic 21 | agora 22 | allied 23 | ally 24 | alternate 25 | alum 26 | am 27 | analyses 28 | Andrea 29 | animate 30 | apply 31 | appropriate 32 | approximate 33 | ares 34 | arithmetic 35 | arsenic 36 | articulate 37 | associate 38 | attribute 39 | august 40 | axes 41 | ay 42 | aye 43 | bases 44 | bass 45 | bathed 46 | bested 47 | bifurcate 48 | blessed 49 | blotto 50 | bow 51 | bowed 52 | bowman 53 | brassy 54 | buffet 55 | bustier 56 | carbonate 57 | Celtic 58 | choral 59 | Chumash 60 | close 61 | closer 62 | coax 63 | coincidence 64 | color coordinate 65 | colour coordinate 66 | comber 67 | combine 68 | combs 69 | committee 70 | commune 71 | compact 72 | complex 73 | compound 74 | compress 75 | concert 76 | conduct 77 | confine 78 | confines 79 | conflict 80 | conglomerate 81 | conscript 82 | conserve 83 | consist 84 | console 85 | consort 86 | construct 87 | consult 88 | consummate 89 | content 90 | contest 91 | contract 92 | contracts 93 | contrast 94 | converse 95 | convert 96 | convict 97 | coop 98 | coordinate 99 | covey 100 | crooked 101 | curate 102 | cussed 103 | decollate 104 | decrease 105 | defect 106 | defense 107 | delegate 108 | deliberate 109 | denier 110 | desert 111 | detail 112 | deviate 113 | diagnoses 114 | diffuse 115 | digest 116 | discard 
117 | discharge 118 | discount 119 | do 120 | document 121 | does 122 | dogged 123 | domesticate 124 | Dominican 125 | dove 126 | dr 127 | drawer 128 | duplicate 129 | egress 130 | ejaculate 131 | eject 132 | elaborate 133 | ellipses 134 | email 135 | emu 136 | entrace 137 | entrance 138 | escort 139 | estimate 140 | eta 141 | Etna 142 | evening 143 | excise 144 | excuse 145 | exploit 146 | export 147 | extract 148 | fine 149 | flower 150 | forbear 151 | four-legged 152 | frequent 153 | furrier 154 | gallant 155 | gel 156 | geminate 157 | gillie 158 | glower 159 | Gotham 160 | graduate 161 | haggis 162 | heavy 163 | hinder 164 | house 165 | housewife 166 | impact 167 | imped 168 | implant 169 | implement 170 | import 171 | impress 172 | incense 173 | incline 174 | increase 175 | infix 176 | insert 177 | instar 178 | insult 179 | integral 180 | intercept 181 | interchange 182 | interflow 183 | interleaf 184 | intermediate 185 | intern 186 | interspace 187 | intimate 188 | intrigue 189 | invalid 190 | invert 191 | invite 192 | irony 193 | jagged 194 | Jesses 195 | Julies 196 | kite 197 | laminate 198 | Laos 199 | lather 200 | lead 201 | learned 202 | leasing 203 | lech 204 | legitimate 205 | lied 206 | lima 207 | lipread 208 | live 209 | lower 210 | lunged 211 | maas 212 | Magdalen 213 | manes 214 | mare 215 | marked 216 | merchandise 217 | merlion 218 | minute 219 | misconduct 220 | misled 221 | misprint 222 | mobile 223 | moderate 224 | mong 225 | moped 226 | moth 227 | mouth 228 | mow 229 | mpg 230 | multiply 231 | mush 232 | nana 233 | nice 234 | Nice 235 | number 236 | numerate 237 | nun 238 | object 239 | opiate 240 | ornament 241 | outbox 242 | outcry 243 | outpour 244 | outreach 245 | outride 246 | outright 247 | outside 248 | outwork 249 | overall 250 | overbid 251 | overcall 252 | overcast 253 | overfall 254 | overflow 255 | overhaul 256 | overhead 257 | overlap 258 | overlay 259 | overuse 260 | overweight 261 | overwork 262 | pace 263 | palled 264 | palling 265 | para 266 | pasty 267 | pate 268 | Pauline 269 | pedal 270 | peer 271 | perfect 272 | periodic 273 | permit 274 | pervert 275 | pinta 276 | placer 277 | platy 278 | polish 279 | Polish 280 | poll 281 | pontificate 282 | postulate 283 | pram 284 | prayer 285 | precipitate 286 | predate 287 | predicate 288 | prefix 289 | preposition 290 | present 291 | pretest 292 | primer 293 | proceeds 294 | produce 295 | progress 296 | project 297 | proportionate 298 | prospect 299 | protest 300 | pussy 301 | putter 302 | putting 303 | quite 304 | ragged 305 | raven 306 | re 307 | read 308 | reading 309 | Reading 310 | real 311 | rebel 312 | recall 313 | recap 314 | recitative 315 | recollect 316 | record 317 | recreate 318 | recreation 319 | redress 320 | refill 321 | refund 322 | refuse 323 | reject 324 | relay 325 | remake 326 | repaint 327 | reprint 328 | reread 329 | rerun 330 | resent 331 | reside 332 | resign 333 | respray 334 | resume 335 | retard 336 | retest 337 | retread 338 | rewrite 339 | root 340 | routed 341 | routing 342 | row 343 | rugged 344 | rummy 345 | sais 346 | sake 347 | sambuca 348 | saucier 349 | second 350 | secrete 351 | secreted 352 | secreting 353 | segment 354 | separate 355 | sewer 356 | shirk 357 | shower 358 | sin 359 | skied 360 | slaver 361 | slough 362 | sow 363 | spoof 364 | squid 365 | stingy 366 | subject 367 | subordinate 368 | subvert 369 | supply 370 | supposed 371 | survey 372 | suspect 373 | syringes 374 | tabulate 375 | tales 376 | tarrier 377 | tarry 378 | taxes 379 | taxis 380 | tear 381 | 
Theron 382 | thou 383 | three-legged 384 | tier 385 | tinged 386 | torment 387 | transfer 388 | transform 389 | transplant 390 | transport 391 | transpose 392 | tush 393 | two-legged 394 | unionised 395 | unionized 396 | update 397 | uplift 398 | upset 399 | use 400 | used 401 | vale 402 | violist 403 | viva 404 | ware 405 | whinged 406 | whoop 407 | wicked 408 | wind 409 | windy 410 | wino 411 | won 412 | worsted 413 | wound 414 | -------------------------------------------------------------------------------- /data/taichi-loading/README.md: -------------------------------------------------------------------------------- 1 | # TaiChi dataset 2 | 3 | Scripts for loading the TaiChi dataset. 4 | 5 | We provide only the id of the corresponding video and the bounding box. The following script downloads the videos from YouTube and crops them according to the provided bounding boxes. 6 | 7 | 1) Download youtube-dl: 8 | ``` 9 | wget https://yt-dl.org/downloads/latest/youtube-dl -O youtube-dl 10 | chmod a+rx youtube-dl 11 | ``` 12 | 13 | 2) Run the script to download the videos. There are two formats for storing the videos: a single .mp4 file per clip, or a folder of .png images. While .png images occupy significantly more space, the format is lossless and has better i/o performance when training. 14 | 15 | ``` 16 | python load_videos.py --metadata taichi-metadata.csv --format .mp4 --out_folder taichi --workers 8 17 | ``` 18 | Select the number of workers based on the number of CPUs available. Note that the .png format takes approximately 80GB. 19 | -------------------------------------------------------------------------------- /data/taichi-loading/load_videos.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import imageio 4 | import os 5 | import subprocess 6 | from multiprocessing import Pool 7 | from itertools import cycle 8 | import warnings 9 | import glob 10 | import time 11 | from tqdm import tqdm 12 | from argparse import ArgumentParser 13 | from skimage import img_as_ubyte 14 | from skimage.transform import resize 15 | warnings.filterwarnings("ignore") 16 | 17 | DEVNULL = open(os.devnull, 'wb') 18 | 19 | 20 | def save(path, frames, format): 21 | if format == '.mp4': 22 | imageio.mimsave(path, frames) 23 | elif format == '.png': 24 | if os.path.exists(path): 25 | print ("Warning: skipping video %s" % os.path.basename(path)) 26 | return 27 | else: 28 | os.makedirs(path) 29 | for j, frame in enumerate(frames): 30 | imageio.imsave(os.path.join(path, str(j).zfill(7) + '.png'), frames[j]) 31 | else: 32 | print ("Unknown format %s" % format) 33 | exit() 34 | 35 | 36 | def download(video_id, args): 37 | video_path = os.path.join(args.video_folder, video_id + ".mp4") 38 | subprocess.call([args.youtube, '-f', "''best/mp4''", '--write-auto-sub', '--write-sub', 39 | '--sub-lang', 'en', '--skip-unavailable-fragments', 40 | "https://www.youtube.com/watch?v=" + video_id, "--output", 41 | video_path], stdout=DEVNULL, stderr=DEVNULL) 42 | return video_path 43 | 44 | 45 | def run(data): 46 | video_id, args = data 47 | if not os.path.exists(os.path.join(args.video_folder, video_id.split('#')[0] + '.mp4')): 48 | download(video_id.split('#')[0], args) 49 | 50 | if not os.path.exists(os.path.join(args.video_folder, video_id.split('#')[0] + '.mp4')): 51 | print ('Cannot load video %s, broken link' % video_id.split('#')[0]) 52 | return 53 | reader = imageio.get_reader(os.path.join(args.video_folder, video_id.split('#')[0] + '.mp4')) 54 | fps =
reader.get_meta_data()['fps'] 55 | 56 | df = pd.read_csv(args.metadata) 57 | df = df[df['video_id'] == video_id] 58 | 59 | all_chunks_dict = [{'start': df['start'].iloc[j], 'end': df['end'].iloc[j], 60 | 'bbox': list(map(int, df['bbox'].iloc[j].split('-'))), 'frames':[]} for j in range(df.shape[0])] 61 | ref_fps = df['fps'].iloc[0] 62 | ref_height = df['height'].iloc[0] 63 | ref_width = df['width'].iloc[0] 64 | partition = df['partition'].iloc[0] 65 | try: 66 | for i, frame in enumerate(reader): 67 | for entry in all_chunks_dict: 68 | if (i * ref_fps >= entry['start'] * fps) and (i * ref_fps < entry['end'] * fps): 69 | left, top, right, bot = entry['bbox'] 70 | left = int(left / (ref_width / frame.shape[1])) 71 | top = int(top / (ref_height / frame.shape[0])) 72 | right = int(right / (ref_width / frame.shape[1])) 73 | bot = int(bot / (ref_height / frame.shape[0])) 74 | crop = frame[top:bot, left:right] 75 | if args.image_shape is not None: 76 | crop = img_as_ubyte(resize(crop, args.image_shape, anti_aliasing=True)) 77 | entry['frames'].append(crop) 78 | except imageio.core.format.CannotReadFrameError: 79 | None 80 | 81 | for entry in all_chunks_dict: 82 | first_part = '#'.join(video_id.split('#')[::-1]) 83 | path = first_part + '#' + str(entry['start']).zfill(6) + '#' + str(entry['end']).zfill(6) + '.mp4' 84 | save(os.path.join(args.out_folder, partition, path), entry['frames'], args.format) 85 | 86 | 87 | if __name__ == "__main__": 88 | parser = ArgumentParser() 89 | parser.add_argument("--video_folder", default='youtube-taichi', help='Path to youtube videos') 90 | parser.add_argument("--metadata", default='taichi-metadata-new.csv', help='Path to metadata') 91 | parser.add_argument("--out_folder", default='taichi-png', help='Path to output') 92 | parser.add_argument("--format", default='.png', help='Storing format') 93 | parser.add_argument("--workers", default=1, type=int, help='Number of workers') 94 | parser.add_argument("--youtube", default='./youtube-dl', help='Path to youtube-dl') 95 | 96 | parser.add_argument("--image_shape", default=(256, 256), type=lambda x: tuple(map(int, x.split(','))), 97 | help="Image shape, None for no resize") 98 | 99 | args = parser.parse_args() 100 | if not os.path.exists(args.video_folder): 101 | os.makedirs(args.video_folder) 102 | if not os.path.exists(args.out_folder): 103 | os.makedirs(args.out_folder) 104 | for partition in ['test', 'train']: 105 | if not os.path.exists(os.path.join(args.out_folder, partition)): 106 | os.makedirs(os.path.join(args.out_folder, partition)) 107 | 108 | df = pd.read_csv(args.metadata) 109 | video_ids = set(df['video_id']) 110 | pool = Pool(processes=args.workers) 111 | args_list = cycle([args]) 112 | for chunks_data in tqdm(pool.imap_unordered(run, zip(video_ids, args_list))): 113 | None 114 | -------------------------------------------------------------------------------- /data/taichi256.csv: -------------------------------------------------------------------------------- 1 | distance,source,driving,frame 2 | 3.54437869822485,ab28GAufK8o#000261#000596.mp4,aDyyTMUBoLE#000164#000351.mp4,0 3 | 2.8639053254437887,DMEaUoA8EPE#000028#000354.mp4,0Q914by5A98#010440#010764.mp4,0 4 | 2.153846153846153,L82WHgYRq6I#000021#000479.mp4,0Q914by5A98#010440#010764.mp4,0 5 | 2.8994082840236666,oNkBx4CZuEg#000000#001024.mp4,DMEaUoA8EPE#000028#000354.mp4,0 6 | 3.3905325443786998,ab28GAufK8o#000261#000596.mp4,uEqWZ9S_-Lw#000089#000581.mp4,0 7 | 3.266272189349112,0Q914by5A98#010440#010764.mp4,ab28GAufK8o#000261#000596.mp4,0 8 | 
2.7514792899408294,WlDYrq8K6nk#008186#008512.mp4,OiblkvkAHWM#014331#014459.mp4,0 9 | 3.0177514792899407,oNkBx4CZuEg#001024#002048.mp4,aDyyTMUBoLE#000375#000518.mp4,0 10 | 3.4792899408284064,aDyyTMUBoLE#000164#000351.mp4,w2awOCDRtrc#001729#002009.mp4,0 11 | 2.769230769230769,oNkBx4CZuEg#000000#001024.mp4,L82WHgYRq6I#000021#000479.mp4,0 12 | 3.8047337278106514,ab28GAufK8o#000261#000596.mp4,w2awOCDRtrc#001729#002009.mp4,0 13 | 3.4260355029585763,w2awOCDRtrc#001729#002009.mp4,oNkBx4CZuEg#000000#001024.mp4,0 14 | 3.313609467455621,DMEaUoA8EPE#000028#000354.mp4,WlDYrq8K6nk#005943#006135.mp4,0 15 | 3.8402366863905333,oNkBx4CZuEg#001024#002048.mp4,ab28GAufK8o#000261#000596.mp4,0 16 | 3.3254437869822504,aDyyTMUBoLE#000164#000351.mp4,oNkBx4CZuEg#000000#001024.mp4,0 17 | 1.2485207100591724,0Q914by5A98#010440#010764.mp4,aDyyTMUBoLE#000164#000351.mp4,0 18 | 3.804733727810652,OiblkvkAHWM#006251#006533.mp4,aDyyTMUBoLE#000375#000518.mp4,0 19 | 3.662721893491124,uEqWZ9S_-Lw#000089#000581.mp4,DMEaUoA8EPE#000028#000354.mp4,0 20 | 3.230769230769233,A3ZmT97hAWU#000095#000678.mp4,ab28GAufK8o#000261#000596.mp4,0 21 | 3.3668639053254434,w81Tr0Dp1K8#015329#015485.mp4,WlDYrq8K6nk#008186#008512.mp4,0 22 | 3.313609467455621,WlDYrq8K6nk#005943#006135.mp4,DMEaUoA8EPE#000028#000354.mp4,0 23 | 2.7514792899408294,OiblkvkAHWM#014331#014459.mp4,WlDYrq8K6nk#008186#008512.mp4,0 24 | 1.964497041420118,L82WHgYRq6I#000021#000479.mp4,DMEaUoA8EPE#000028#000354.mp4,0 25 | 3.78698224852071,FBuF0xOal9M#046824#047542.mp4,lCb5w6n8kPs#011879#012014.mp4,0 26 | 3.92307692307692,ab28GAufK8o#000261#000596.mp4,L82WHgYRq6I#000021#000479.mp4,0 27 | 3.8402366863905333,ab28GAufK8o#000261#000596.mp4,oNkBx4CZuEg#001024#002048.mp4,0 28 | 3.828402366863905,ab28GAufK8o#000261#000596.mp4,OiblkvkAHWM#006251#006533.mp4,0 29 | 2.041420118343196,L82WHgYRq6I#000021#000479.mp4,aDyyTMUBoLE#000164#000351.mp4,0 30 | 3.2485207100591724,0Q914by5A98#010440#010764.mp4,w2awOCDRtrc#001729#002009.mp4,0 31 | 3.2485207100591746,oNkBx4CZuEg#000000#001024.mp4,0Q914by5A98#010440#010764.mp4,0 32 | 1.964497041420118,DMEaUoA8EPE#000028#000354.mp4,L82WHgYRq6I#000021#000479.mp4,0 33 | 3.5266272189349115,kgvcI9oe3NI#001578#001763.mp4,lCb5w6n8kPs#004451#004631.mp4,0 34 | 3.005917159763317,A3ZmT97hAWU#000095#000678.mp4,0Q914by5A98#010440#010764.mp4,0 35 | 3.230769230769233,ab28GAufK8o#000261#000596.mp4,A3ZmT97hAWU#000095#000678.mp4,0 36 | 3.5266272189349115,lCb5w6n8kPs#004451#004631.mp4,kgvcI9oe3NI#001578#001763.mp4,0 37 | 2.769230769230769,L82WHgYRq6I#000021#000479.mp4,oNkBx4CZuEg#000000#001024.mp4,0 38 | 3.165680473372782,WlDYrq8K6nk#005943#006135.mp4,w81Tr0Dp1K8#001375#001516.mp4,0 39 | 2.8994082840236666,DMEaUoA8EPE#000028#000354.mp4,oNkBx4CZuEg#000000#001024.mp4,0 40 | 2.4556213017751523,0Q914by5A98#010440#010764.mp4,mndSqTrxpts#000000#000175.mp4,0 41 | 2.201183431952659,A3ZmT97hAWU#000095#000678.mp4,VMSqvTE90hk#007168#007312.mp4,0 42 | 3.8047337278106514,w2awOCDRtrc#001729#002009.mp4,ab28GAufK8o#000261#000596.mp4,0 43 | 3.769230769230769,uEqWZ9S_-Lw#000089#000581.mp4,0Q914by5A98#010440#010764.mp4,0 44 | 3.6568047337278102,A3ZmT97hAWU#000095#000678.mp4,aDyyTMUBoLE#000164#000351.mp4,0 45 | 3.7869822485207107,uEqWZ9S_-Lw#000089#000581.mp4,L82WHgYRq6I#000021#000479.mp4,0 46 | 3.78698224852071,lCb5w6n8kPs#011879#012014.mp4,FBuF0xOal9M#046824#047542.mp4,0 47 | 3.591715976331361,nAQEOC1Z10M#020177#020600.mp4,w81Tr0Dp1K8#004036#004218.mp4,0 48 | 3.8757396449704156,uEqWZ9S_-Lw#000089#000581.mp4,aDyyTMUBoLE#000164#000351.mp4,0 49 | 
2.45562130177515,aDyyTMUBoLE#000164#000351.mp4,DMEaUoA8EPE#000028#000354.mp4,0 50 | 3.5502958579881647,uEqWZ9S_-Lw#000089#000581.mp4,OiblkvkAHWM#006251#006533.mp4,0 51 | 3.7928994082840224,aDyyTMUBoLE#000375#000518.mp4,ab28GAufK8o#000261#000596.mp4,0 52 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | import torch.utils.data 5 | 6 | import layers 7 | from utils import load_wav_to_torch, load_filepaths_and_text 8 | from text import text_to_sequence 9 | 10 | 11 | class TextMelLoader(torch.utils.data.Dataset): 12 | """ 13 | 1) loads audio,text pairs 14 | 2) normalizes text and converts them to sequences of one-hot vectors 15 | 3) computes mel-spectrograms from audio files. 16 | """ 17 | def __init__(self, audiopaths_and_text, hparams): 18 | self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text) 19 | self.text_cleaners = hparams.text_cleaners 20 | self.max_wav_value = hparams.max_wav_value 21 | self.sampling_rate = hparams.sampling_rate 22 | self.load_mel_from_disk = hparams.load_mel_from_disk 23 | self.stft = layers.TacotronSTFT( 24 | hparams.filter_length, hparams.hop_length, hparams.win_length, 25 | hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin, 26 | hparams.mel_fmax) 27 | random.seed(hparams.seed) 28 | random.shuffle(self.audiopaths_and_text) 29 | 30 | def get_mel_text_pair(self, audiopath_and_text): 31 | # separate filename and text 32 | audiopath, text = audiopath_and_text[0], audiopath_and_text[1] 33 | text = self.get_text(text) 34 | mel = self.get_mel(audiopath) 35 | return (text, mel) 36 | 37 | def get_mel(self, filename): 38 | if not self.load_mel_from_disk: 39 | audio, sampling_rate = load_wav_to_torch(filename) 40 | if sampling_rate != self.stft.sampling_rate: 41 | raise ValueError("{} {} SR doesn't match target {} SR".format( 42 | filename, sampling_rate, self.stft.sampling_rate)) 43 | audio_norm = audio / self.max_wav_value 44 | audio_norm = audio_norm.unsqueeze(0) 45 | audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False) 46 | melspec = self.stft.mel_spectrogram(audio_norm) 47 | melspec = torch.squeeze(melspec, 0) 48 | else: 49 | melspec = torch.from_numpy(np.load(filename)) 50 | assert melspec.size(0) == self.stft.n_mel_channels, ( 51 | 'Mel dimension mismatch: given {}, expected {}'.format( 52 | melspec.size(0), self.stft.n_mel_channels)) 53 | 54 | return melspec 55 | 56 | def get_text(self, text): 57 | text_norm = torch.IntTensor(text_to_sequence(text, self.text_cleaners)) 58 | return text_norm 59 | 60 | def __getitem__(self, index): 61 | return self.get_mel_text_pair(self.audiopaths_and_text[index]) 62 | 63 | def __len__(self): 64 | return len(self.audiopaths_and_text) 65 | 66 | 67 | class TextMelCollate(): 68 | """ Zero-pads model inputs and targets based on number of frames per step 69 | """ 70 | def __init__(self, n_frames_per_step): 71 | self.n_frames_per_step = n_frames_per_step 72 | 73 | def __call__(self, batch): 74 | """Collates a training batch from normalized text and mel-spectrogram 75 | PARAMS 76 | ------ 77 | batch: [text_normalized, mel_normalized] 78 | """ 79 | # Right zero-pad all one-hot text sequences to max input length 80 | input_lengths, ids_sorted_decreasing = torch.sort( 81 | torch.LongTensor([len(x[0]) for x in batch]), 82 | dim=0, descending=True) 83 | max_input_len = input_lengths[0] 84 | 85 | text_padded =
torch.LongTensor(len(batch), max_input_len) 86 | text_padded.zero_() 87 | for i in range(len(ids_sorted_decreasing)): 88 | text = batch[ids_sorted_decreasing[i]][0] 89 | text_padded[i, :text.size(0)] = text 90 | 91 | # Right zero-pad mel-spec 92 | num_mels = batch[0][1].size(0) 93 | max_target_len = max([x[1].size(1) for x in batch]) 94 | if max_target_len % self.n_frames_per_step != 0: 95 | max_target_len += self.n_frames_per_step - max_target_len % self.n_frames_per_step 96 | assert max_target_len % self.n_frames_per_step == 0 97 | 98 | # include mel padded and gate padded 99 | mel_padded = torch.FloatTensor(len(batch), num_mels, max_target_len) 100 | mel_padded.zero_() 101 | gate_padded = torch.FloatTensor(len(batch), max_target_len) 102 | gate_padded.zero_() 103 | output_lengths = torch.LongTensor(len(batch)) 104 | for i in range(len(ids_sorted_decreasing)): 105 | mel = batch[ids_sorted_decreasing[i]][1] 106 | mel_padded[i, :, :mel.size(1)] = mel 107 | gate_padded[i, mel.size(1)-1:] = 1 108 | output_lengths[i] = mel.size(1) 109 | 110 | return text_padded, input_lengths, mel_padded, gate_padded, \ 111 | output_lengths 112 | -------------------------------------------------------------------------------- /denoiser.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('tacotron2') 3 | import torch 4 | from layers import STFT 5 | 6 | 7 | class Denoiser(torch.nn.Module): 8 | """ Removes model bias from audio produced with waveglow """ 9 | 10 | def __init__(self, waveglow, filter_length=1024, n_overlap=4, 11 | win_length=1024, mode='zeros'): 12 | super(Denoiser, self).__init__() 13 | self.stft = STFT(filter_length=filter_length, 14 | hop_length=int(filter_length/n_overlap), 15 | win_length=win_length).cuda() 16 | if mode == 'zeros': 17 | mel_input = torch.zeros( 18 | (1, 80, 88), 19 | dtype=waveglow.upsample.weight.dtype, 20 | device=waveglow.upsample.weight.device) 21 | elif mode == 'normal': 22 | mel_input = torch.randn( 23 | (1, 80, 88), 24 | dtype=waveglow.upsample.weight.dtype, 25 | device=waveglow.upsample.weight.device) 26 | else: 27 | raise Exception("Mode {} if not supported".format(mode)) 28 | 29 | with torch.no_grad(): 30 | bias_audio = waveglow.infer(mel_input, sigma=0.0).float() 31 | bias_spec, _ = self.stft.transform(bias_audio) 32 | 33 | self.register_buffer('bias_spec', bias_spec[:, :, 0][:, :, None]) 34 | 35 | def forward(self, audio, strength=0.1): 36 | audio_spec, audio_angles = self.stft.transform(audio.cuda().float()) 37 | audio_spec_denoised = audio_spec - self.bias_spec * strength 38 | audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0) 39 | audio_denoised = self.stft.inverse(audio_spec_denoised, audio_angles) 40 | return audio_denoised 41 | -------------------------------------------------------------------------------- /discriminator.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras.models import load_model 3 | import numpy as np 4 | from tensorflow.keras.optimizers import Adam 5 | from tensorflow.keras.models import Model 6 | from tensorflow.keras.layers import Dense, Conv2D, Conv3D, BatchNormalization, Activation, \ 7 | Concatenate, AvgPool2D, Input, MaxPool2D, UpSampling2D, Add, \ 8 | ZeroPadding2D, ZeroPadding3D, Lambda, Reshape, Flatten, LeakyReLU 9 | from keras_contrib.layers import InstanceNormalization 10 | from tensorflow.keras.callbacks import ModelCheckpoint 11 | from tensorflow.keras import 
backend as K 12 | import keras 13 | import cv2 14 | import os 15 | import librosa 16 | import scipy 17 | from tensorflow.keras.utils import plot_model 18 | 19 | from tensorflow.keras.utils import multi_gpu_model 20 | from tensorflow.keras import backend as K 21 | 22 | class ModelMGPU(Model): 23 | def __init__(self, ser_model, gpus): 24 | pmodel = multi_gpu_model(ser_model, gpus) 25 | self.__dict__.update(pmodel.__dict__) 26 | self._smodel = ser_model 27 | 28 | def __getattribute__(self, attrname): 29 | '''Override load and save methods to be used from the serial-model. The 30 | serial-model holds references to the weights in the multi-gpu model. 31 | ''' 32 | # return Model.__getattribute__(self, attrname) 33 | if 'load' in attrname or 'save' in attrname: 34 | return getattr(self._smodel, attrname) 35 | 36 | return super(ModelMGPU, self).__getattribute__(attrname) 37 | 38 | def contrastive_loss(y_true, y_pred): 39 | margin = 1. 40 | loss = (1. - y_true) * K.square(y_pred) + y_true * K.square(K.maximum(0., margin - y_pred)) 41 | return K.mean(loss) 42 | 43 | def conv_block(x, num_filters, kernel_size=3, strides=2, padding='same'): 44 | x = Conv2D(filters=num_filters, kernel_size= kernel_size, 45 | strides=strides, padding=padding)(x) 46 | x = InstanceNormalization()(x) 47 | x = LeakyReLU(alpha=.2)(x) 48 | return x 49 | 50 | def create_model(args, mel_step_size): 51 | ############# encoder for face/identity 52 | input_face = Input(shape=(args.img_size, args.img_size, 3), name="input_face_disc") 53 | 54 | x = conv_block(input_face, 64, 7) 55 | x = conv_block(x, 128, 5) 56 | x = conv_block(x, 256, 3) 57 | x = conv_block(x, 512, 3) 58 | x = conv_block(x, 512, 3) 59 | x = Conv2D(filters=512, kernel_size=3, strides=1, padding="valid")(x) 60 | face_embedding = Flatten() (x) 61 | 62 | ############# encoder for audio 63 | input_audio = Input(shape=(80, mel_step_size, 1), name="input_audio") 64 | 65 | x = conv_block(input_audio, 32, strides=1) 66 | x = conv_block(x, 64, strides=3) #27X9 67 | x = conv_block(x, 128, strides=(3, 1)) #9X9 68 | x = conv_block(x, 256, strides=3) #3X3 69 | x = conv_block(x, 512, strides=1, padding='valid') #1X1 70 | x = conv_block(x, 512, 1, strides=1) 71 | 72 | audio_embedding = Flatten() (x) 73 | 74 | # L2-normalize before taking L2 distance 75 | l2_normalize = Lambda(lambda x: K.l2_normalize(x, axis=1)) 76 | face_embedding = l2_normalize(face_embedding) 77 | audio_embedding = l2_normalize(audio_embedding) 78 | 79 | d = Lambda(lambda x: K.sqrt(K.sum(K.square(x[0] - x[1]), axis=1, keepdims=True))) ([face_embedding, 80 | audio_embedding]) 81 | 82 | model = Model(inputs=[input_face, input_audio], outputs=[d]) 83 | 84 | 85 | 86 | if args.n_gpu > 1: 87 | model = ModelMGPU(model , args.n_gpu) 88 | 89 | model.compile(loss=contrastive_loss, optimizer=Adam(lr=args.lr)) 90 | 91 | return model 92 | 93 | if __name__ == '__main__': 94 | model = create_model() 95 | #plot_model(model, to_file='model.png', show_shapes=True) 96 | -------------------------------------------------------------------------------- /filelists/README.md: -------------------------------------------------------------------------------- 1 | Place the LRS2 dataset filelists here -------------------------------------------------------------------------------- /flowtron_logger.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # 3 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 
4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | ############################################################################### 17 | import random 18 | import torch 19 | from tensorboardX import SummaryWriter 20 | from flowtron_plotting_utils import plot_alignment_to_numpy, plot_gate_outputs_to_numpy 21 | 22 | 23 | class FlowtronLogger(SummaryWriter): 24 | def __init__(self, logdir): 25 | super(FlowtronLogger, self).__init__(logdir) 26 | 27 | def log_training(self, loss, learning_rate, iteration): 28 | self.add_scalar("training.loss", loss, iteration) 29 | self.add_scalar("learning.rate", learning_rate, iteration) 30 | 31 | def log_validation(self, loss, attns, gate_pred, gate_out, iteration): 32 | self.add_scalar("validation.loss", loss, iteration) 33 | 34 | idx = random.randint(0, len(gate_out) - 1) 35 | for i in range(len(attns)): 36 | self.add_image( 37 | 'attention_weights_{}'.format(i), 38 | plot_alignment_to_numpy(attns[i][idx].data.cpu().numpy().T), 39 | iteration, 40 | dataformats='HWC') 41 | 42 | if gate_pred is not None: 43 | gate_pred = gate_pred.transpose(0, 1)[:, :, 0] 44 | self.add_image( 45 | "gate", 46 | plot_gate_outputs_to_numpy( 47 | gate_out[idx].data.cpu().numpy(), 48 | torch.sigmoid(gate_pred[idx]).data.cpu().numpy()), 49 | iteration, dataformats='HWC') 50 | -------------------------------------------------------------------------------- /flowtron_plotting_utils.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # 3 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | ############################################################################### 17 | import matplotlib 18 | matplotlib.use("Agg") 19 | import matplotlib.pylab as plt 20 | import numpy as np 21 | 22 | 23 | def save_figure_to_numpy(fig): 24 | # save it to a numpy array. 
25 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 26 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 27 | return data 28 | 29 | 30 | def plot_alignment_to_numpy(alignment, info=None): 31 | fig, ax = plt.subplots(figsize=(6, 4)) 32 | im = ax.imshow(alignment, aspect='auto', origin='lower', 33 | interpolation='none') 34 | fig.colorbar(im, ax=ax) 35 | xlabel = 'Decoder timestep' 36 | if info is not None: 37 | xlabel += '\n\n' + info 38 | plt.xlabel(xlabel) 39 | plt.ylabel('Encoder timestep') 40 | plt.tight_layout() 41 | 42 | fig.canvas.draw() 43 | data = save_figure_to_numpy(fig) 44 | plt.close() 45 | return data 46 | 47 | 48 | def plot_gate_outputs_to_numpy(gate_targets, gate_outputs): 49 | fig, ax = plt.subplots(figsize=(12, 3)) 50 | ax.scatter(range(len(gate_targets)), gate_targets, alpha=0.5, 51 | color='green', marker='+', s=1, label='target') 52 | ax.scatter(range(len(gate_outputs)), gate_outputs, alpha=0.5, 53 | color='red', marker='.', s=1, label='predicted') 54 | 55 | plt.xlabel("Frames (Green target, Red predicted)") 56 | plt.ylabel("Gate State") 57 | plt.tight_layout() 58 | 59 | fig.canvas.draw() 60 | data = save_figure_to_numpy(fig) 61 | plt.close() 62 | return data 63 | -------------------------------------------------------------------------------- /generate_image.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import argparse 3 | import os 4 | from PIL import Image 5 | import numpy as np 6 | import torch 7 | import random as r 8 | import stylegan2 9 | import matplotlib.pyplot as plt 10 | from stylegan2 import utils 11 | 12 | def generate_images(G, args): 13 | latent_size, label_size = G.latent_size, G.label_size 14 | device = torch.device('cpu') 15 | G.to(device) 16 | if args['truncation_psi'] != 1.0: 17 | G.set_truncation(truncation_psi=args['truncation_psi']) 18 | 19 | noise_reference = G.static_noise() 20 | 21 | def get_batch(seeds): 22 | latents = [] 23 | labels = [] 24 | 25 | noise_tensors = [[] for _ in noise_reference] 26 | for seed in seeds: 27 | rnd = np.random.RandomState(seed) 28 | latents.append(torch.from_numpy(rnd.randn(latent_size))) 29 | 30 | for i, ref in enumerate(noise_reference): 31 | noise_tensors[i].append(torch.from_numpy(rnd.randn(*ref.size()[1:]))) 32 | if label_size: 33 | labels.append(torch.tensor([rnd.randint(0, label_size)])) 34 | latents = torch.stack(latents, dim=0).to(device=device, dtype=torch.float32) 35 | if labels: 36 | labels = torch.cat(labels, dim=0).to(device=device, dtype=torch.int64) 37 | else: 38 | labels = None 39 | 40 | noise_tensors = [ 41 | torch.stack(noise, dim=0).to(device=device, dtype=torch.float32) 42 | for noise in noise_tensors 43 | ] 44 | 45 | return latents, labels, noise_tensors 46 | 47 | for i in range(0, len(args['seed'])): 48 | latents, labels, noise_tensors = get_batch(args['seed'][i: i + 1]) 49 | if noise_tensors is not None: 50 | G.static_noise(noise_tensors=noise_tensors) 51 | with torch.no_grad(): 52 | generated = G(latents, labels=labels) 53 | images = utils.tensor_to_PIL( 54 | generated, pixel_min=-1.0, pixel_max=1.0) 55 | for seed, img in zip(args['seed'][i: i + 1], images): 56 | img.save(f"{args['output']}temp.png") 57 | 58 | 59 | def main(): 60 | 61 | args ={'output':'', 62 | 'network':'models/stylegan_Gs.pth', 63 | 'seed': [int(r.uniform(0, (2**32 -1)))], 64 | 'truncation_psi':r.uniform(0.7, 1.1)} 65 | 66 | G = stylegan2.models.load(args['network']) 67 | generate_images(G, args) 68 | 
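Because generate_images() reseeds NumPy per image, the same integer seed always reproduces the same latent, per-layer noise, and output file. A minimal usage sketch of that reproducibility, assuming it is run from the repository root with models/stylegan_Gs.pth present (the fixed seed 1234 and truncation value here are purely illustrative, unlike the random ones chosen in main()):

import stylegan2
from generate_image import generate_images

# Fixed seed and truncation: running this twice writes identical temp.png files.
args = {'output': '',
        'network': 'models/stylegan_Gs.pth',
        'seed': [1234],
        'truncation_psi': 0.9}

G = stylegan2.models.load(args['network'])
generate_images(G, args)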
-------------------------------------------------------------------------------- /generation_resources/0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/singularitai/Morphling/e7a3af969123c0d3c0f3c6f1036a97e9be0b289c/generation_resources/0.gif -------------------------------------------------------------------------------- /generation_resources/1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/singularitai/Morphling/e7a3af969123c0d3c0f3c6f1036a97e9be0b289c/generation_resources/1.gif -------------------------------------------------------------------------------- /generation_resources/10.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/singularitai/Morphling/e7a3af969123c0d3c0f3c6f1036a97e9be0b289c/generation_resources/10.gif -------------------------------------------------------------------------------- /generation_resources/11.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/singularitai/Morphling/e7a3af969123c0d3c0f3c6f1036a97e9be0b289c/generation_resources/11.gif -------------------------------------------------------------------------------- /generation_resources/12.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/singularitai/Morphling/e7a3af969123c0d3c0f3c6f1036a97e9be0b289c/generation_resources/12.gif -------------------------------------------------------------------------------- /generation_resources/2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/singularitai/Morphling/e7a3af969123c0d3c0f3c6f1036a97e9be0b289c/generation_resources/2.gif -------------------------------------------------------------------------------- /generation_resources/3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/singularitai/Morphling/e7a3af969123c0d3c0f3c6f1036a97e9be0b289c/generation_resources/3.gif -------------------------------------------------------------------------------- /generation_resources/4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/singularitai/Morphling/e7a3af969123c0d3c0f3c6f1036a97e9be0b289c/generation_resources/4.gif -------------------------------------------------------------------------------- /generation_resources/5.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/singularitai/Morphling/e7a3af969123c0d3c0f3c6f1036a97e9be0b289c/generation_resources/5.gif -------------------------------------------------------------------------------- /generation_resources/6.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/singularitai/Morphling/e7a3af969123c0d3c0f3c6f1036a97e9be0b289c/generation_resources/6.gif -------------------------------------------------------------------------------- /generation_resources/7.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/singularitai/Morphling/e7a3af969123c0d3c0f3c6f1036a97e9be0b289c/generation_resources/7.gif 
-------------------------------------------------------------------------------- /generation_resources/8.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/singularitai/Morphling/e7a3af969123c0d3c0f3c6f1036a97e9be0b289c/generation_resources/8.gif -------------------------------------------------------------------------------- /generation_resources/9.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/singularitai/Morphling/e7a3af969123c0d3c0f3c6f1036a97e9be0b289c/generation_resources/9.gif -------------------------------------------------------------------------------- /hparams.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from text import symbols 3 | 4 | 5 | def create_hparams(hparams_string=None, verbose=False): 6 | """Create model hyperparameters. Parse nondefault from given string.""" 7 | 8 | hparams = tf.contrib.training.HParams( 9 | ################################ 10 | # Experiment Parameters # 11 | ################################ 12 | epochs=500, 13 | iters_per_checkpoint=1000, 14 | seed=1234, 15 | dynamic_loss_scaling=True, 16 | fp16_run=False, 17 | distributed_run=False, 18 | dist_backend="nccl", 19 | dist_url="tcp://localhost:54321", 20 | cudnn_enabled=True, 21 | cudnn_benchmark=False, 22 | ignore_layers=['embedding.weight'], 23 | 24 | ################################ 25 | # Data Parameters # 26 | ################################ 27 | load_mel_from_disk=False, 28 | training_files='filelists/ljs_audio_text_train_filelist.txt', 29 | validation_files='filelists/ljs_audio_text_val_filelist.txt', 30 | text_cleaners=['english_cleaners'], 31 | 32 | ################################ 33 | # Audio Parameters # 34 | ################################ 35 | max_wav_value=32768.0, 36 | sampling_rate=22050, 37 | filter_length=1024, 38 | hop_length=256, 39 | win_length=1024, 40 | n_mel_channels=80, 41 | mel_fmin=0.0, 42 | mel_fmax=8000.0, 43 | 44 | ################################ 45 | # Model Parameters # 46 | ################################ 47 | n_symbols=len(symbols), 48 | symbols_embedding_dim=512, 49 | 50 | # Encoder parameters 51 | encoder_kernel_size=5, 52 | encoder_n_convolutions=3, 53 | encoder_embedding_dim=512, 54 | 55 | # Decoder parameters 56 | n_frames_per_step=1, # currently only 1 is supported 57 | decoder_rnn_dim=1024, 58 | prenet_dim=256, 59 | max_decoder_steps=1000, 60 | gate_threshold=0.5, 61 | p_attention_dropout=0.1, 62 | p_decoder_dropout=0.1, 63 | 64 | # Attention parameters 65 | attention_rnn_dim=1024, 66 | attention_dim=128, 67 | 68 | # Location Layer parameters 69 | attention_location_n_filters=32, 70 | attention_location_kernel_size=31, 71 | 72 | # Mel-post processing network parameters 73 | postnet_embedding_dim=512, 74 | postnet_kernel_size=5, 75 | postnet_n_convolutions=5, 76 | 77 | ################################ 78 | # Optimization Hyperparameters # 79 | ################################ 80 | use_saved_learning_rate=False, 81 | learning_rate=1e-3, 82 | weight_decay=1e-6, 83 | grad_clip_thresh=1.0, 84 | batch_size=64, 85 | mask_padding=True # set model's padded outputs to padded values 86 | ) 87 | 88 | if hparams_string: 89 | tf.logging.info('Parsing command line hparams: %s', hparams_string) 90 | hparams.parse(hparams_string) 91 | 92 | if verbose: 93 | tf.logging.info('Final parsed hparams: %s', hparams.values()) 94 | 95 | return hparams 96 | 
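create_hparams() above depends on tf.contrib.training.HParams, which exists only in TensorFlow 1.x, while the pinned requirements list tensorflow 2.2.0; under TF 2.x the call raises AttributeError. A minimal stand-in sketch, which is an assumption and not part of the repository, keeping the same attribute-style access for a subset of the fields:

import types


def create_hparams_namespace(overrides=None):
    # Plain-namespace sketch replacing tf.contrib.training.HParams:
    # attribute access (hparams.sampling_rate, ...) without any TF dependency.
    hparams = types.SimpleNamespace(
        epochs=500,
        seed=1234,
        sampling_rate=22050,
        filter_length=1024,
        hop_length=256,
        win_length=1024,
        n_mel_channels=80,
        mel_fmin=0.0,
        mel_fmax=8000.0,
        batch_size=64,
    )
    for key, value in (overrides or {}).items():
        setattr(hparams, key, value)  # mirrors HParams.parse() overrides
    return hparams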
-------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # 3 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | ############################################################################### 17 | import matplotlib 18 | matplotlib.use("Agg") 19 | import matplotlib.pylab as plt 20 | 21 | import os 22 | import argparse 23 | import json 24 | import sys 25 | import numpy as np 26 | import torch 27 | 28 | 29 | from flowtron import Flowtron 30 | from torch.utils.data import DataLoader 31 | from data import Data 32 | from train import update_params 33 | 34 | sys.path.insert(0, "tacotron2") 35 | sys.path.insert(0, "tacotron2/waveglow") 36 | from glow import WaveGlow 37 | from scipy.io.wavfile import write 38 | 39 | 40 | def infer(flowtron_path, waveglow_path, output_dir, text, speaker_id, n_frames, 41 | sigma, gate_threshold, seed): 42 | torch.manual_seed(seed) 43 | torch.cuda.manual_seed(seed) 44 | 45 | # load waveglow 46 | waveglow = torch.load(waveglow_path)['model'].cuda().eval() 47 | waveglow.cuda().half() 48 | for k in waveglow.convinv: 49 | k.float() 50 | waveglow.eval() 51 | 52 | # load flowtron 53 | model = Flowtron(**model_config).cuda() 54 | state_dict = torch.load(flowtron_path, map_location='cpu')['state_dict'] 55 | model.load_state_dict(state_dict) 56 | model.eval() 57 | print("Loaded checkpoint '{}')" .format(flowtron_path)) 58 | 59 | ignore_keys = ['training_files', 'validation_files'] 60 | trainset = Data( 61 | data_config['training_files'], 62 | **dict((k, v) for k, v in data_config.items() if k not in ignore_keys)) 63 | speaker_vecs = trainset.get_speaker_id(speaker_id).cuda() 64 | text = trainset.get_text(text).cuda() 65 | speaker_vecs = speaker_vecs[None] 66 | text = text[None] 67 | 68 | with torch.no_grad(): 69 | residual = torch.cuda.FloatTensor(1, 80, n_frames).normal_() * sigma 70 | mels, attentions = model.infer( 71 | residual, speaker_vecs, text, gate_threshold=gate_threshold) 72 | 73 | for k in range(len(attentions)): 74 | attention = torch.cat(attentions[k]).cpu().numpy() 75 | fig, axes = plt.subplots(1, 2, figsize=(16, 4)) 76 | axes[0].imshow(mels[0].cpu().numpy(), origin='bottom', aspect='auto') 77 | axes[1].imshow(attention[:, 0].transpose(), origin='bottom', aspect='auto') 78 | fig.savefig(os.path.join(output_dir, 'sid{}_sigma{}_attnlayer{}.png'.format(speaker_id, sigma, k))) 79 | plt.close("all") 80 | 81 | audio = waveglow.infer(mels.half(), sigma=0.8).float() 82 | audio = audio.cpu().numpy()[0] 83 | # normalize audio for now 84 | audio = audio / np.abs(audio).max() 85 | print(audio.shape) 86 | 87 | write(os.path.join(output_dir, 'sid{}_sigma{}.wav'.format(speaker_id, sigma)), 88 | data_config['sampling_rate'], audio) 89 | 90 | 91 | if __name__ == 
"__main__": 92 | parser = argparse.ArgumentParser() 93 | parser.add_argument('-c', '--config', type=str, 94 | help='JSON file for configuration') 95 | parser.add_argument('-p', '--params', nargs='+', default=[]) 96 | parser.add_argument('-f', '--flowtron_path', 97 | help='Path to flowtron state dict', type=str) 98 | parser.add_argument('-w', '--waveglow_path', 99 | help='Path to waveglow state dict', type=str) 100 | parser.add_argument('-t', '--text', help='Text to synthesize', type=str) 101 | parser.add_argument('-i', '--id', help='Speaker id', type=int) 102 | parser.add_argument('-n', '--n_frames', help='Number of frames', 103 | default=400, type=int) 104 | parser.add_argument('-o', "--output_dir", default="results/") 105 | parser.add_argument("-s", "--sigma", default=0.5, type=float) 106 | parser.add_argument("-g", "--gate", default=0.5, type=float) 107 | parser.add_argument("--seed", default=1234, type=int) 108 | args = parser.parse_args() 109 | 110 | # Parse configs. Globals nicer in this case 111 | with open(args.config) as f: 112 | data = f.read() 113 | 114 | global config 115 | config = json.loads(data) 116 | update_params(config, args.params) 117 | 118 | data_config = config["data_config"] 119 | global model_config 120 | model_config = config["model_config"] 121 | 122 | # Make directory if it doesn't exist 123 | if not os.path.isdir(args.output_dir): 124 | os.makedirs(args.output_dir) 125 | os.chmod(args.output_dir, 0o775) 126 | 127 | torch.backends.cudnn.enabled = True 128 | torch.backends.cudnn.benchmark = False 129 | infer(args.flowtron_path, args.waveglow_path, args.output_dir, args.text, 130 | args.id, args.n_frames, args.sigma, args.gate, args.seed) 131 | -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from librosa.filters import mel as librosa_mel_fn 3 | from audio_processing import dynamic_range_compression 4 | from audio_processing import dynamic_range_decompression 5 | from stft import STFT 6 | 7 | 8 | class LinearNorm(torch.nn.Module): 9 | def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): 10 | super(LinearNorm, self).__init__() 11 | self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) 12 | 13 | torch.nn.init.xavier_uniform_( 14 | self.linear_layer.weight, 15 | gain=torch.nn.init.calculate_gain(w_init_gain)) 16 | 17 | def forward(self, x): 18 | return self.linear_layer(x) 19 | 20 | 21 | class ConvNorm(torch.nn.Module): 22 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, 23 | padding=None, dilation=1, bias=True, w_init_gain='linear'): 24 | super(ConvNorm, self).__init__() 25 | if padding is None: 26 | assert(kernel_size % 2 == 1) 27 | padding = int(dilation * (kernel_size - 1) / 2) 28 | 29 | self.conv = torch.nn.Conv1d(in_channels, out_channels, 30 | kernel_size=kernel_size, stride=stride, 31 | padding=padding, dilation=dilation, 32 | bias=bias) 33 | 34 | torch.nn.init.xavier_uniform_( 35 | self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) 36 | 37 | def forward(self, signal): 38 | conv_signal = self.conv(signal) 39 | return conv_signal 40 | 41 | 42 | class TacotronSTFT(torch.nn.Module): 43 | def __init__(self, filter_length=1024, hop_length=256, win_length=1024, 44 | n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0, 45 | mel_fmax=8000.0): 46 | super(TacotronSTFT, self).__init__() 47 | self.n_mel_channels = n_mel_channels 48 | self.sampling_rate = 
sampling_rate 49 | self.stft_fn = STFT(filter_length, hop_length, win_length) 50 | mel_basis = librosa_mel_fn( 51 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax) 52 | mel_basis = torch.from_numpy(mel_basis).float() 53 | self.register_buffer('mel_basis', mel_basis) 54 | 55 | def spectral_normalize(self, magnitudes): 56 | output = dynamic_range_compression(magnitudes) 57 | return output 58 | 59 | def spectral_de_normalize(self, magnitudes): 60 | output = dynamic_range_decompression(magnitudes) 61 | return output 62 | 63 | def mel_spectrogram(self, y): 64 | """Computes mel-spectrograms from a batch of waves 65 | PARAMS 66 | ------ 67 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] 68 | 69 | RETURNS 70 | ------- 71 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) 72 | """ 73 | assert(torch.min(y.data) >= -1) 74 | assert(torch.max(y.data) <= 1) 75 | 76 | magnitudes, phases = self.stft_fn.transform(y) 77 | magnitudes = magnitudes.data 78 | mel_output = torch.matmul(self.mel_basis, magnitudes) 79 | mel_output = self.spectral_normalize(mel_output) 80 | return mel_output 81 | -------------------------------------------------------------------------------- /log_file: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/singularitai/Morphling/e7a3af969123c0d3c0f3c6f1036a97e9be0b289c/log_file -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from torch.utils.tensorboard import SummaryWriter 4 | from plotting_utils import plot_alignment_to_numpy, plot_spectrogram_to_numpy 5 | from plotting_utils import plot_gate_outputs_to_numpy 6 | 7 | 8 | class Tacotron2Logger(SummaryWriter): 9 | def __init__(self, logdir): 10 | super(Tacotron2Logger, self).__init__(logdir) 11 | 12 | def log_training(self, reduced_loss, grad_norm, learning_rate, duration, 13 | iteration): 14 | self.add_scalar("training.loss", reduced_loss, iteration) 15 | self.add_scalar("grad.norm", grad_norm, iteration) 16 | self.add_scalar("learning.rate", learning_rate, iteration) 17 | self.add_scalar("duration", duration, iteration) 18 | 19 | def log_validation(self, reduced_loss, model, y, y_pred, iteration): 20 | self.add_scalar("validation.loss", reduced_loss, iteration) 21 | _, mel_outputs, gate_outputs, alignments = y_pred 22 | mel_targets, gate_targets = y 23 | 24 | # plot distribution of parameters 25 | for tag, value in model.named_parameters(): 26 | tag = tag.replace('.', '/') 27 | self.add_histogram(tag, value.data.cpu().numpy(), iteration) 28 | 29 | # plot alignment, mel target and predicted, gate target and predicted 30 | idx = random.randint(0, alignments.size(0) - 1) 31 | self.add_image( 32 | "alignment", 33 | plot_alignment_to_numpy(alignments[idx].data.cpu().numpy().T), 34 | iteration, dataformats='HWC') 35 | self.add_image( 36 | "mel_target", 37 | plot_spectrogram_to_numpy(mel_targets[idx].data.cpu().numpy()), 38 | iteration, dataformats='HWC') 39 | self.add_image( 40 | "mel_predicted", 41 | plot_spectrogram_to_numpy(mel_outputs[idx].data.cpu().numpy()), 42 | iteration, dataformats='HWC') 43 | self.add_image( 44 | "gate", 45 | plot_gate_outputs_to_numpy( 46 | gate_targets[idx].data.cpu().numpy(), 47 | torch.sigmoid(gate_outputs[idx]).data.cpu().numpy()), 48 | iteration, dataformats='HWC') 49 | 
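Tacotron2Logger above is a thin SummaryWriter subclass; training code only has to construct it with a log directory and push scalars each iteration. A minimal usage sketch (the directory name and scalar values are illustrative only):

from logger import Tacotron2Logger

# Scalars appear in TensorBoard under training.loss, grad.norm,
# learning.rate and duration at the given iteration.
logger = Tacotron2Logger('logs/tacotron2_run')
logger.log_training(reduced_loss=0.42, grad_norm=1.3,
                    learning_rate=1e-3, duration=0.8, iteration=100)
logger.close()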
-------------------------------------------------------------------------------- /loss_function.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class Tacotron2Loss(nn.Module): 5 | def __init__(self): 6 | super(Tacotron2Loss, self).__init__() 7 | 8 | def forward(self, model_output, targets): 9 | mel_target, gate_target = targets[0], targets[1] 10 | mel_target.requires_grad = False 11 | gate_target.requires_grad = False 12 | gate_target = gate_target.view(-1, 1) 13 | 14 | mel_out, mel_out_postnet, gate_out, _ = model_output 15 | gate_out = gate_out.view(-1, 1) 16 | mel_loss = nn.MSELoss()(mel_out, mel_target) + \ 17 | nn.MSELoss()(mel_out_postnet, mel_target) 18 | gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target) 19 | return mel_loss + gate_loss 20 | -------------------------------------------------------------------------------- /loss_scaler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class LossScaler: 4 | 5 | def __init__(self, scale=1): 6 | self.cur_scale = scale 7 | 8 | # `params` is a list / generator of torch.Variable 9 | def has_overflow(self, params): 10 | return False 11 | 12 | # `x` is a torch.Tensor 13 | def _has_inf_or_nan(x): 14 | return False 15 | 16 | # `overflow` is boolean indicating whether we overflowed in gradient 17 | def update_scale(self, overflow): 18 | pass 19 | 20 | @property 21 | def loss_scale(self): 22 | return self.cur_scale 23 | 24 | def scale_gradient(self, module, grad_in, grad_out): 25 | return tuple(self.loss_scale * g for g in grad_in) 26 | 27 | def backward(self, loss): 28 | scaled_loss = loss*self.loss_scale 29 | scaled_loss.backward() 30 | 31 | class DynamicLossScaler: 32 | 33 | def __init__(self, 34 | init_scale=2**32, 35 | scale_factor=2., 36 | scale_window=1000): 37 | self.cur_scale = init_scale 38 | self.cur_iter = 0 39 | self.last_overflow_iter = -1 40 | self.scale_factor = scale_factor 41 | self.scale_window = scale_window 42 | 43 | # `params` is a list / generator of torch.Variable 44 | def has_overflow(self, params): 45 | # return False 46 | for p in params: 47 | if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data): 48 | return True 49 | 50 | return False 51 | 52 | # `x` is a torch.Tensor 53 | def _has_inf_or_nan(x): 54 | cpu_sum = float(x.float().sum()) 55 | if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: 56 | return True 57 | return False 58 | 59 | # `overflow` is boolean indicating whether we overflowed in gradient 60 | def update_scale(self, overflow): 61 | if overflow: 62 | #self.cur_scale /= self.scale_factor 63 | self.cur_scale = max(self.cur_scale/self.scale_factor, 1) 64 | self.last_overflow_iter = self.cur_iter 65 | else: 66 | if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0: 67 | self.cur_scale *= self.scale_factor 68 | # self.cur_scale = 1 69 | self.cur_iter += 1 70 | 71 | @property 72 | def loss_scale(self): 73 | return self.cur_scale 74 | 75 | def scale_gradient(self, module, grad_in, grad_out): 76 | return tuple(self.loss_scale * g for g in grad_in) 77 | 78 | def backward(self, loss): 79 | scaled_loss = loss*self.loss_scale 80 | scaled_loss.backward() 81 | 82 | ############################################################## 83 | # Example usage below here -- assuming it's in a separate file 84 | ############################################################## 85 | if __name__ == "__main__": 86 | import torch 87 | from 
torch.autograd import Variable 88 | from dynamic_loss_scaler import DynamicLossScaler 89 | 90 | # N is batch size; D_in is input dimension; 91 | # H is hidden dimension; D_out is output dimension. 92 | N, D_in, H, D_out = 64, 1000, 100, 10 93 | 94 | # Create random Tensors to hold inputs and outputs, and wrap them in Variables. 95 | x = Variable(torch.randn(N, D_in), requires_grad=False) 96 | y = Variable(torch.randn(N, D_out), requires_grad=False) 97 | 98 | w1 = Variable(torch.randn(D_in, H), requires_grad=True) 99 | w2 = Variable(torch.randn(H, D_out), requires_grad=True) 100 | parameters = [w1, w2] 101 | 102 | learning_rate = 1e-6 103 | optimizer = torch.optim.SGD(parameters, lr=learning_rate) 104 | loss_scaler = DynamicLossScaler() 105 | 106 | for t in range(500): 107 | y_pred = x.mm(w1).clamp(min=0).mm(w2) 108 | loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale 109 | print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale)) 110 | print('Iter {} scaled loss: {}'.format(t, loss.data[0])) 111 | print('Iter {} unscaled loss: {}'.format(t, loss.data[0] / loss_scaler.loss_scale)) 112 | 113 | # Run backprop 114 | optimizer.zero_grad() 115 | loss.backward() 116 | 117 | # Check for overflow 118 | has_overflow = DynamicLossScaler.has_overflow(parameters) 119 | 120 | # If no overflow, unscale grad and update as usual 121 | if not has_overflow: 122 | for param in parameters: 123 | param.grad.data.mul_(1. / loss_scaler.loss_scale) 124 | optimizer.step() 125 | # Otherwise, don't do anything -- ie, skip iteration 126 | else: 127 | print('OVERFLOW!') 128 | 129 | # Update loss scale for next iteration 130 | loss_scaler.update_scale(has_overflow) 131 | 132 | -------------------------------------------------------------------------------- /modules/dense_motion.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | import torch 4 | from modules.util import Hourglass, AntiAliasInterpolation2d, make_coordinate_grid, kp2gaussian 5 | 6 | 7 | class DenseMotionNetwork(nn.Module): 8 | """ 9 | Module that predicting a dense motion from sparse motion representation given by kp_source and kp_driving 10 | """ 11 | 12 | def __init__(self, block_expansion, num_blocks, max_features, num_kp, num_channels, estimate_occlusion_map=False, 13 | scale_factor=1, kp_variance=0.01): 14 | super(DenseMotionNetwork, self).__init__() 15 | self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp + 1) * (num_channels + 1), 16 | max_features=max_features, num_blocks=num_blocks) 17 | 18 | self.mask = nn.Conv2d(self.hourglass.out_filters, num_kp + 1, kernel_size=(7, 7), padding=(3, 3)) 19 | 20 | if estimate_occlusion_map: 21 | self.occlusion = nn.Conv2d(self.hourglass.out_filters, 1, kernel_size=(7, 7), padding=(3, 3)) 22 | else: 23 | self.occlusion = None 24 | 25 | self.num_kp = num_kp 26 | self.scale_factor = scale_factor 27 | self.kp_variance = kp_variance 28 | 29 | if self.scale_factor != 1: 30 | self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor) 31 | 32 | def create_heatmap_representations(self, source_image, kp_driving, kp_source): 33 | """ 34 | Eq 6. 
in the paper H_k(z) 35 | """ 36 | spatial_size = source_image.shape[2:] 37 | gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=self.kp_variance) 38 | gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=self.kp_variance) 39 | heatmap = gaussian_driving - gaussian_source 40 | 41 | #adding background feature 42 | zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1]).type(heatmap.type()) 43 | heatmap = torch.cat([zeros, heatmap], dim=1) 44 | heatmap = heatmap.unsqueeze(2) 45 | return heatmap 46 | 47 | def create_sparse_motions(self, source_image, kp_driving, kp_source): 48 | """ 49 | Eq 4. in the paper T_{s<-d}(z) 50 | """ 51 | bs, _, h, w = source_image.shape 52 | identity_grid = make_coordinate_grid((h, w), type=kp_source['value'].type()) 53 | identity_grid = identity_grid.view(1, 1, h, w, 2) 54 | coordinate_grid = identity_grid - kp_driving['value'].view(bs, self.num_kp, 1, 1, 2) 55 | if 'jacobian' in kp_driving: 56 | jacobian = torch.matmul(kp_source['jacobian'], torch.inverse(kp_driving['jacobian'])) 57 | jacobian = jacobian.unsqueeze(-3).unsqueeze(-3) 58 | jacobian = jacobian.repeat(1, 1, h, w, 1, 1) 59 | coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1)) 60 | coordinate_grid = coordinate_grid.squeeze(-1) 61 | 62 | driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.num_kp, 1, 1, 2) 63 | 64 | #adding background feature 65 | identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1) 66 | sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1) 67 | return sparse_motions 68 | 69 | def create_deformed_source_image(self, source_image, sparse_motions): 70 | """ 71 | Eq 7. in the paper \hat{T}_{s<-d}(z) 72 | """ 73 | bs, _, h, w = source_image.shape 74 | source_repeat = source_image.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp + 1, 1, 1, 1, 1) 75 | source_repeat = source_repeat.view(bs * (self.num_kp + 1), -1, h, w) 76 | sparse_motions = sparse_motions.view((bs * (self.num_kp + 1), h, w, -1)) 77 | sparse_deformed = F.grid_sample(source_repeat, sparse_motions) 78 | sparse_deformed = sparse_deformed.view((bs, self.num_kp + 1, -1, h, w)) 79 | return sparse_deformed 80 | 81 | def forward(self, source_image, kp_driving, kp_source): 82 | if self.scale_factor != 1: 83 | source_image = self.down(source_image) 84 | 85 | bs, _, h, w = source_image.shape 86 | 87 | out_dict = dict() 88 | heatmap_representation = self.create_heatmap_representations(source_image, kp_driving, kp_source) 89 | sparse_motion = self.create_sparse_motions(source_image, kp_driving, kp_source) 90 | deformed_source = self.create_deformed_source_image(source_image, sparse_motion) 91 | out_dict['sparse_deformed'] = deformed_source 92 | 93 | input = torch.cat([heatmap_representation, deformed_source], dim=2) 94 | input = input.view(bs, -1, h, w) 95 | 96 | prediction = self.hourglass(input) 97 | 98 | mask = self.mask(prediction) 99 | mask = F.softmax(mask, dim=1) 100 | out_dict['mask'] = mask 101 | mask = mask.unsqueeze(2) 102 | sparse_motion = sparse_motion.permute(0, 1, 4, 2, 3) 103 | deformation = (sparse_motion * mask).sum(dim=1) 104 | deformation = deformation.permute(0, 2, 3, 1) 105 | 106 | out_dict['deformation'] = deformation 107 | 108 | # Sec. 
3.2 in the paper 109 | if self.occlusion: 110 | occlusion_map = torch.sigmoid(self.occlusion(prediction)) 111 | out_dict['occlusion_map'] = occlusion_map 112 | 113 | return out_dict 114 | -------------------------------------------------------------------------------- /modules/discriminator.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | from modules.util import kp2gaussian 4 | import torch 5 | 6 | 7 | class DownBlock2d(nn.Module): 8 | """ 9 | Simple block for processing video (encoder). 10 | """ 11 | 12 | def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False): 13 | super(DownBlock2d, self).__init__() 14 | self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size) 15 | 16 | if sn: 17 | self.conv = nn.utils.spectral_norm(self.conv) 18 | 19 | if norm: 20 | self.norm = nn.InstanceNorm2d(out_features, affine=True) 21 | else: 22 | self.norm = None 23 | self.pool = pool 24 | 25 | def forward(self, x): 26 | out = x 27 | out = self.conv(out) 28 | if self.norm: 29 | out = self.norm(out) 30 | out = F.leaky_relu(out, 0.2) 31 | if self.pool: 32 | out = F.avg_pool2d(out, (2, 2)) 33 | return out 34 | 35 | 36 | class Discriminator(nn.Module): 37 | """ 38 | Discriminator similar to Pix2Pix 39 | """ 40 | 41 | def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512, 42 | sn=False, use_kp=False, num_kp=10, kp_variance=0.01, **kwargs): 43 | super(Discriminator, self).__init__() 44 | 45 | down_blocks = [] 46 | for i in range(num_blocks): 47 | down_blocks.append( 48 | DownBlock2d(num_channels + num_kp * use_kp if i == 0 else min(max_features, block_expansion * (2 ** i)), 49 | min(max_features, block_expansion * (2 ** (i + 1))), 50 | norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn)) 51 | 52 | self.down_blocks = nn.ModuleList(down_blocks) 53 | self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1) 54 | if sn: 55 | self.conv = nn.utils.spectral_norm(self.conv) 56 | self.use_kp = use_kp 57 | self.kp_variance = kp_variance 58 | 59 | def forward(self, x, kp=None): 60 | feature_maps = [] 61 | out = x 62 | if self.use_kp: 63 | heatmap = kp2gaussian(kp, x.shape[2:], self.kp_variance) 64 | out = torch.cat([out, heatmap], dim=1) 65 | 66 | for down_block in self.down_blocks: 67 | feature_maps.append(down_block(out)) 68 | out = feature_maps[-1] 69 | prediction_map = self.conv(out) 70 | 71 | return feature_maps, prediction_map 72 | 73 | 74 | class MultiScaleDiscriminator(nn.Module): 75 | """ 76 | Multi-scale (scale) discriminator 77 | """ 78 | 79 | def __init__(self, scales=(), **kwargs): 80 | super(MultiScaleDiscriminator, self).__init__() 81 | self.scales = scales 82 | discs = {} 83 | for scale in scales: 84 | discs[str(scale).replace('.', '-')] = Discriminator(**kwargs) 85 | self.discs = nn.ModuleDict(discs) 86 | 87 | def forward(self, x, kp=None): 88 | out_dict = {} 89 | for scale, disc in self.discs.items(): 90 | scale = str(scale).replace('-', '.') 91 | key = 'prediction_' + scale 92 | feature_maps, prediction_map = disc(x[key], kp) 93 | out_dict['feature_maps_' + scale] = feature_maps 94 | out_dict['prediction_map_' + scale] = prediction_map 95 | return out_dict 96 | -------------------------------------------------------------------------------- /modules/generator.py: -------------------------------------------------------------------------------- 1 | 
import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d 5 | from modules.dense_motion import DenseMotionNetwork 6 | 7 | 8 | class OcclusionAwareGenerator(nn.Module): 9 | """ 10 | Generator that given source image and and keypoints try to transform image according to movement trajectories 11 | induced by keypoints. Generator follows Johnson architecture. 12 | """ 13 | 14 | def __init__(self, num_channels, num_kp, block_expansion, max_features, num_down_blocks, 15 | num_bottleneck_blocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False): 16 | super(OcclusionAwareGenerator, self).__init__() 17 | 18 | if dense_motion_params is not None: 19 | self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, num_channels=num_channels, 20 | estimate_occlusion_map=estimate_occlusion_map, 21 | **dense_motion_params) 22 | else: 23 | self.dense_motion_network = None 24 | 25 | self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3)) 26 | 27 | down_blocks = [] 28 | for i in range(num_down_blocks): 29 | in_features = min(max_features, block_expansion * (2 ** i)) 30 | out_features = min(max_features, block_expansion * (2 ** (i + 1))) 31 | down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) 32 | self.down_blocks = nn.ModuleList(down_blocks) 33 | 34 | up_blocks = [] 35 | for i in range(num_down_blocks): 36 | in_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i))) 37 | out_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i - 1))) 38 | up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) 39 | self.up_blocks = nn.ModuleList(up_blocks) 40 | 41 | self.bottleneck = torch.nn.Sequential() 42 | in_features = min(max_features, block_expansion * (2 ** num_down_blocks)) 43 | for i in range(num_bottleneck_blocks): 44 | self.bottleneck.add_module('r' + str(i), ResBlock2d(in_features, kernel_size=(3, 3), padding=(1, 1))) 45 | 46 | self.final = nn.Conv2d(block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3)) 47 | self.estimate_occlusion_map = estimate_occlusion_map 48 | self.num_channels = num_channels 49 | 50 | def deform_input(self, inp, deformation): 51 | _, h_old, w_old, _ = deformation.shape 52 | _, _, h, w = inp.shape 53 | if h_old != h or w_old != w: 54 | deformation = deformation.permute(0, 3, 1, 2) 55 | deformation = F.interpolate(deformation, size=(h, w), mode='bilinear') 56 | deformation = deformation.permute(0, 2, 3, 1) 57 | return F.grid_sample(inp, deformation) 58 | 59 | def forward(self, source_image, kp_driving, kp_source): 60 | # Encoding (downsampling) part 61 | out = self.first(source_image) 62 | for i in range(len(self.down_blocks)): 63 | out = self.down_blocks[i](out) 64 | 65 | # Transforming feature representation according to deformation and occlusion 66 | output_dict = {} 67 | if self.dense_motion_network is not None: 68 | dense_motion = self.dense_motion_network(source_image=source_image, kp_driving=kp_driving, 69 | kp_source=kp_source) 70 | output_dict['mask'] = dense_motion['mask'] 71 | output_dict['sparse_deformed'] = dense_motion['sparse_deformed'] 72 | 73 | if 'occlusion_map' in dense_motion: 74 | occlusion_map = dense_motion['occlusion_map'] 75 | output_dict['occlusion_map'] = occlusion_map 76 | else: 77 | occlusion_map = None 78 | deformation = dense_motion['deformation'] 79 | out = 
self.deform_input(out, deformation) 80 | 81 | if occlusion_map is not None: 82 | if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]: 83 | occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear') 84 | out = out * occlusion_map 85 | 86 | output_dict["deformed"] = self.deform_input(source_image, deformation) 87 | 88 | # Decoding part 89 | out = self.bottleneck(out) 90 | for i in range(len(self.up_blocks)): 91 | out = self.up_blocks[i](out) 92 | out = self.final(out) 93 | out = F.sigmoid(out) 94 | 95 | output_dict["prediction"] = out 96 | 97 | return output_dict 98 | -------------------------------------------------------------------------------- /modules/keypoint_detector.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | import torch.nn.functional as F 4 | from modules.util import Hourglass, make_coordinate_grid, AntiAliasInterpolation2d 5 | 6 | 7 | class KPDetector(nn.Module): 8 | """ 9 | Detecting a keypoints. Return keypoint position and jacobian near each keypoint. 10 | """ 11 | 12 | def __init__(self, block_expansion, num_kp, num_channels, max_features, 13 | num_blocks, temperature, estimate_jacobian=False, scale_factor=1, 14 | single_jacobian_map=False, pad=0): 15 | super(KPDetector, self).__init__() 16 | 17 | self.predictor = Hourglass(block_expansion, in_features=num_channels, 18 | max_features=max_features, num_blocks=num_blocks) 19 | 20 | self.kp = nn.Conv2d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=(7, 7), 21 | padding=pad) 22 | 23 | if estimate_jacobian: 24 | self.num_jacobian_maps = 1 if single_jacobian_map else num_kp 25 | self.jacobian = nn.Conv2d(in_channels=self.predictor.out_filters, 26 | out_channels=4 * self.num_jacobian_maps, kernel_size=(7, 7), padding=pad) 27 | self.jacobian.weight.data.zero_() 28 | self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float)) 29 | else: 30 | self.jacobian = None 31 | 32 | self.temperature = temperature 33 | self.scale_factor = scale_factor 34 | if self.scale_factor != 1: 35 | self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor) 36 | 37 | def gaussian2kp(self, heatmap): 38 | """ 39 | Extract the mean and from a heatmap 40 | """ 41 | shape = heatmap.shape 42 | heatmap = heatmap.unsqueeze(-1) 43 | grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0) 44 | value = (heatmap * grid).sum(dim=(2, 3)) 45 | kp = {'value': value} 46 | 47 | return kp 48 | 49 | def forward(self, x): 50 | if self.scale_factor != 1: 51 | x = self.down(x) 52 | 53 | feature_map = self.predictor(x) 54 | prediction = self.kp(feature_map) 55 | 56 | final_shape = prediction.shape 57 | heatmap = prediction.view(final_shape[0], final_shape[1], -1) 58 | heatmap = F.softmax(heatmap / self.temperature, dim=2) 59 | heatmap = heatmap.view(*final_shape) 60 | 61 | out = self.gaussian2kp(heatmap) 62 | 63 | if self.jacobian is not None: 64 | jacobian_map = self.jacobian(feature_map) 65 | jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2], 66 | final_shape[3]) 67 | heatmap = heatmap.unsqueeze(2) 68 | 69 | jacobian = heatmap * jacobian_map 70 | jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1) 71 | jacobian = jacobian.sum(dim=-1) 72 | jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2) 73 | out['jacobian'] = jacobian 74 | 75 | return out 76 | 
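KPDetector.gaussian2kp() above is a soft-argmax: the softmax-normalized heatmap acts as weights over a coordinate grid, giving a differentiable keypoint position. A tiny 1-D illustration of that step (the values are made up; the real module uses 2-D grids from make_coordinate_grid):

import torch
import torch.nn.functional as F

# 1-D soft-argmax: a sharp softmax over the heatmap, then a weighted
# sum of grid coordinates, yields the (differentiable) peak position.
logits = torch.tensor([0.0, 1.0, 4.0, 1.0, 0.0])        # peak at index 2
grid = torch.linspace(-1.0, 1.0, steps=logits.numel())  # [-1, 1] like make_coordinate_grid
heatmap = F.softmax(logits / 0.1, dim=0)                # temperature as in KPDetector
position = (heatmap * grid).sum()
print(position.item())                                  # ~0.0, the centre of the grid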
-------------------------------------------------------------------------------- /multiproc.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import sys 4 | import subprocess 5 | 6 | argslist = list(sys.argv)[1:] 7 | num_gpus = torch.cuda.device_count() 8 | argslist.append('--n_gpus={}'.format(num_gpus)) 9 | workers = [] 10 | job_id = time.strftime("%Y_%m_%d-%H%M%S") 11 | argslist.append("--group_name=group_{}".format(job_id)) 12 | 13 | for i in range(num_gpus): 14 | argslist.append('--rank={}'.format(i)) 15 | stdout = None if i == 0 else open("logs/{}_GPU_{}.log".format(job_id, i), 16 | "w") 17 | print(argslist) 18 | p = subprocess.Popen([str(sys.executable)]+argslist, stdout=stdout) 19 | workers.append(p) 20 | argslist = argslist[:-1] 21 | 22 | for p in workers: 23 | p.wait() 24 | -------------------------------------------------------------------------------- /output/audio/test.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/singularitai/Morphling/e7a3af969123c0d3c0f3c6f1036a97e9be0b289c/output/audio/test.wav -------------------------------------------------------------------------------- /plotting_utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use("Agg") 3 | import matplotlib.pylab as plt 4 | import numpy as np 5 | 6 | 7 | def save_figure_to_numpy(fig): 8 | # save it to a numpy array. 9 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 10 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 11 | return data 12 | 13 | 14 | def plot_alignment_to_numpy(alignment, info=None): 15 | fig, ax = plt.subplots(figsize=(6, 4)) 16 | im = ax.imshow(alignment, aspect='auto', origin='lower', 17 | interpolation='none') 18 | fig.colorbar(im, ax=ax) 19 | xlabel = 'Decoder timestep' 20 | if info is not None: 21 | xlabel += '\n\n' + info 22 | plt.xlabel(xlabel) 23 | plt.ylabel('Encoder timestep') 24 | plt.tight_layout() 25 | 26 | fig.canvas.draw() 27 | data = save_figure_to_numpy(fig) 28 | plt.close() 29 | return data 30 | 31 | 32 | def plot_spectrogram_to_numpy(spectrogram): 33 | fig, ax = plt.subplots(figsize=(12, 3)) 34 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 35 | interpolation='none') 36 | plt.colorbar(im, ax=ax) 37 | plt.xlabel("Frames") 38 | plt.ylabel("Channels") 39 | plt.tight_layout() 40 | 41 | fig.canvas.draw() 42 | data = save_figure_to_numpy(fig) 43 | plt.close() 44 | return data 45 | 46 | 47 | def plot_gate_outputs_to_numpy(gate_targets, gate_outputs): 48 | fig, ax = plt.subplots(figsize=(12, 3)) 49 | ax.scatter(range(len(gate_targets)), gate_targets, alpha=0.5, 50 | color='green', marker='+', s=1, label='target') 51 | ax.scatter(range(len(gate_outputs)), gate_outputs, alpha=0.5, 52 | color='red', marker='.', s=1, label='predicted') 53 | 54 | plt.xlabel("Frames (Green target, Red predicted)") 55 | plt.ylabel("Gate State") 56 | plt.tight_layout() 57 | 58 | fig.canvas.draw() 59 | data = save_figure_to_numpy(fig) 60 | plt.close() 61 | return data 62 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | if sys.version_info < (3, 2): 4 | raise Exception("Must be using >= Python 3.2") 5 | 6 | import multiprocessing as mp 7 | from concurrent.futures 
import ThreadPoolExecutor, as_completed 8 | from os import listdir, path 9 | import numpy as np 10 | import argparse, os, cv2, traceback, subprocess 11 | from tqdm import tqdm 12 | import dlib, audio 13 | 14 | detector = dlib.get_frontal_face_detector() 15 | 16 | def rect_to_bb(rect): 17 | x = rect.left() 18 | y = rect.top() 19 | w = rect.right() - x 20 | h = rect.bottom() - y 21 | return (x, y, w, h) 22 | 23 | def calcMaxArea(rects): 24 | max_cords = (-1,-1,-1,-1) 25 | max_area = 0 26 | max_rect = None 27 | for i in range(len(rects)): 28 | cur_rect = rects[i] 29 | (x,y,w,h) = rect_to_bb(cur_rect) 30 | if w*h > max_area: 31 | max_area = w*h 32 | max_cords = (x,y,w,h) 33 | max_rect = cur_rect 34 | return max_cords, max_rect 35 | 36 | def face_detect(image): 37 | gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 38 | rects = detector(gray, 1) 39 | (x, y, w, h), max_rect = calcMaxArea(rects) 40 | if x == -1: 41 | return None, False 42 | faceAligned = image[y:y+h, x:x+w] 43 | if 0 in faceAligned.shape: return None, False 44 | return faceAligned, True 45 | 46 | step_size_in_ms = 40 47 | window_size = 350 48 | mfcc_chunk_size = window_size // 10 49 | mfcc_step_size = 4 50 | fps = 25 51 | video_step_size_in_ms = mfcc_step_size * 10 # for 25 fps video 52 | sr = 16000 53 | 54 | template = 'ffmpeg -loglevel panic -y -i {} -ar {} {}' 55 | 56 | def process_video_file(vfile, args, split): 57 | video_stream = cv2.VideoCapture(vfile) 58 | frames = [] 59 | while 1: 60 | still_reading, frame = video_stream.read() 61 | if not still_reading: 62 | video_stream.release() 63 | break 64 | frames.append(frame) 65 | mid_frames = [] 66 | ss = 0. 67 | es = (ss + (window_size / 1000.)) 68 | 69 | while int(es * fps) <= len(frames): 70 | mid_second = (ss + es) / 2. 71 | mid_frames.append(frames[int(mid_second * fps)]) 72 | 73 | ss += (video_step_size_in_ms / 1000.) 
74 | es = (ss + (window_size / 1000.)) 75 | 76 | 77 | dst_subdir = path.join(vfile.split('/')[-2], vfile.split('/')[-1].split('.')[0]) 78 | fulldir = path.join(args.final_data_root, split, dst_subdir) 79 | os.makedirs(fulldir, exist_ok=True) 80 | wavpath = path.join(fulldir, 'audio.wav') 81 | 82 | command = template.format(vfile, sr, wavpath) 83 | subprocess.call(command, shell=True) 84 | 85 | specpath = path.join(fulldir, 'mels.npz') 86 | 87 | if path.isfile(wavpath): 88 | wav = audio.load_wav(wavpath, sr) 89 | 90 | spec = audio.melspectrogram(wav) 91 | np.savez_compressed(specpath, spec=spec) 92 | else: 93 | return 94 | 95 | for i, f in enumerate(mid_frames): 96 | face, valid_frame = face_detect(f) 97 | 98 | if not valid_frame: 99 | continue 100 | 101 | resized_face = cv2.resize(face, (args.img_size, args.img_size)) 102 | 103 | cv2.imwrite(path.join(fulldir, '{}.jpg'.format(i)), resized_face) 104 | 105 | def mp_handler(job): 106 | vfile, args, split = job 107 | try: 108 | process_video_file(vfile, args, split) 109 | except: 110 | traceback.print_exc() 111 | 112 | def dump_split(args): 113 | print('Started processing for {} with {} CPU cores'.format(args.split, args.num_workers)) 114 | 115 | filelist = [path.join(args.videos_data_root, ('pretrain' if args.split == 'pretrain' else 'main'), 116 | '{}.mp4'.format(line.strip())) \ 117 | for line in open(path.join(args.filelists, '{}.txt'.format(args.split))).readlines()] 118 | 119 | jobs = [(vfile, args, ('pretrain' if args.split == 'pretrain' else 'main')) for vfile in filelist] 120 | p = ThreadPoolExecutor(args.num_workers) 121 | futures = [p.submit(mp_handler, j) for j in jobs] 122 | _ = [r.result() for r in tqdm(as_completed(futures), total=len(futures))] 123 | 124 | parser = argparse.ArgumentParser() 125 | parser.add_argument('--split', help='LRS2 dataset split to preprocess', default='train') 126 | parser.add_argument('--num_workers', help='Number of workers to run in parallel', default=10, type=int) 127 | parser.add_argument('--filelists', help='List of train, val, test, pretrain files', default='./filelists/') 128 | parser.add_argument("--videos_data_root", help="Root folder of LRS", required=True) 129 | 130 | parser.add_argument("--final_data_root", help="Folder where preprocessed files will reside", 131 | required=True) 132 | 133 | ### hyperparams #### 134 | parser.add_argument("--img_size", help="Square face image to resize to", default=96, type=int) 135 | 136 | args = parser.parse_args() 137 | 138 | if __name__ == '__main__': 139 | dump_split(args) 140 | -------------------------------------------------------------------------------- /reconstruction.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import torch 4 | from torch.utils.data import DataLoader 5 | from logger_firstorder import Logger, Visualizer 6 | import numpy as np 7 | import imageio 8 | from sync_batchnorm import DataParallelWithCallback 9 | 10 | 11 | def reconstruction(config, generator, kp_detector, checkpoint, log_dir, dataset): 12 | png_dir = os.path.join(log_dir, 'reconstruction/png') 13 | log_dir = os.path.join(log_dir, 'reconstruction') 14 | 15 | if checkpoint is not None: 16 | Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector) 17 | else: 18 | raise AttributeError("Checkpoint should be specified for mode='reconstruction'.") 19 | dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1) 20 | 21 | if not os.path.exists(log_dir): 22 | 
os.makedirs(log_dir) 23 | 24 | if not os.path.exists(png_dir): 25 | os.makedirs(png_dir) 26 | 27 | loss_list = [] 28 | if torch.cuda.is_available(): 29 | generator = DataParallelWithCallback(generator) 30 | kp_detector = DataParallelWithCallback(kp_detector) 31 | 32 | generator.eval() 33 | kp_detector.eval() 34 | 35 | for it, x in tqdm(enumerate(dataloader)): 36 | if config['reconstruction_params']['num_videos'] is not None: 37 | if it > config['reconstruction_params']['num_videos']: 38 | break 39 | with torch.no_grad(): 40 | predictions = [] 41 | visualizations = [] 42 | if torch.cuda.is_available(): 43 | x['video'] = x['video'].cuda() 44 | kp_source = kp_detector(x['video'][:, :, 0]) 45 | for frame_idx in range(x['video'].shape[2]): 46 | source = x['video'][:, :, 0] 47 | driving = x['video'][:, :, frame_idx] 48 | kp_driving = kp_detector(driving) 49 | out = generator(source, kp_source=kp_source, kp_driving=kp_driving) 50 | out['kp_source'] = kp_source 51 | out['kp_driving'] = kp_driving 52 | del out['sparse_deformed'] 53 | predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0]) 54 | 55 | visualization = Visualizer(**config['visualizer_params']).visualize(source=source, 56 | driving=driving, out=out) 57 | visualizations.append(visualization) 58 | 59 | loss_list.append(torch.abs(out['prediction'] - driving).mean().cpu().numpy()) 60 | 61 | predictions = np.concatenate(predictions, axis=1) 62 | imageio.imsave(os.path.join(png_dir, x['name'][0] + '.png'), (255 * predictions).astype(np.uint8)) 63 | 64 | image_name = x['name'][0] + config['reconstruction_params']['format'] 65 | imageio.mimsave(os.path.join(log_dir, image_name), visualizations) 66 | 67 | print("Reconstruction loss: %s" % np.mean(loss_list)) 68 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Package Version 2 | ------------------------ ------------ 3 | absl-py 0.11.0 4 | appdirs 1.4.4 5 | astunparse 1.6.3 6 | audioread 2.1.9 7 | cachetools 4.2.0 8 | certifi 2020.12.5 9 | cffi 1.14.4 10 | chardet 4.0.0 11 | cmake 3.18.4.post1 12 | cycler 0.10.0 13 | decorator 4.4.2 14 | dlib 19.21.1 15 | dotmap 1.3.23 16 | easygui 0.98.1 17 | gast 0.3.3 18 | google-auth 1.24.0 19 | google-auth-oauthlib 0.4.2 20 | google-pasta 0.2.0 21 | grpcio 1.34.0 22 | h5py 2.10.0 23 | idna 2.10 24 | imageio 2.9.0 25 | ISR 2.2.0 26 | joblib 1.0.0 27 | Keras 2.4.3 28 | keras-contrib 2.0.8 29 | Keras-Preprocessing 1.1.2 30 | kiwisolver 1.3.1 31 | librosa 0.8.0 32 | llvmlite 0.35.0 33 | Markdown 3.3.3 34 | matplotlib 3.3.3 35 | networkx 2.5 36 | numba 0.52.0 37 | numpy 1.19.4 38 | oauthlib 3.1.0 39 | opencv-python 4.4.0.46 40 | opt-einsum 3.3.0 41 | packaging 20.8 42 | Pillow 8.0.1 43 | pip 20.1.1 44 | pooch 1.3.0 45 | protobuf 3.14.0 46 | pyasn1 0.4.8 47 | pyasn1-modules 0.2.8 48 | pycparser 2.20 49 | pyparsing 2.4.7 50 | PyQt5 5.15.2 51 | PyQt5-sip 12.8.1 52 | python-dateutil 2.8.1 53 | PyWavelets 1.1.1 54 | PyYAML 5.3.1 55 | requests 2.25.1 56 | requests-oauthlib 1.3.0 57 | resampy 0.2.2 58 | rsa 4.6 59 | scikit-image 0.18.1 60 | scikit-learn 0.24.0 61 | scipy 1.4.1 62 | setuptools 47.1.0 63 | six 1.15.0 64 | SoundFile 0.10.3.post1 65 | tensorboard 2.2.2 66 | tensorboard-plugin-wit 1.7.0 67 | tensorflow 2.2.0 68 | tensorflow-estimator 2.2.0 69 | tensorflow-gpu 2.2.0 70 | tensorflow-gpu-estimator 2.2.0 71 | termcolor 1.1.0 72 | threadpoolctl 2.1.0 73 | tifffile 2020.12.8 74 | torch 1.7.1+cu101 
75 | torchvision 0.8.2+cu101 76 | tqdm 4.55.0 77 | typing-extensions 3.7.4.3 78 | urllib3 1.26.2 79 | Werkzeug 1.0.1 80 | wheel 0.36.2 81 | wrapt 1.12.1 -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | 3 | matplotlib.use('Agg') 4 | 5 | import os, sys 6 | import yaml 7 | from argparse import ArgumentParser 8 | from time import gmtime, strftime 9 | from shutil import copy 10 | 11 | from frames_dataset import FramesDataset 12 | 13 | from modules.generator import OcclusionAwareGenerator 14 | from modules.discriminator import MultiScaleDiscriminator 15 | from modules.keypoint_detector import KPDetector 16 | 17 | import torch 18 | 19 | from train_firstorder import train 20 | from reconstruction import reconstruction 21 | from animate import animate 22 | 23 | if __name__ == "__main__": 24 | 25 | if sys.version_info[0] < 3: 26 | raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7") 27 | 28 | parser = ArgumentParser() 29 | parser.add_argument("--config", required=True, help="path to config") 30 | parser.add_argument("--mode", default="train", choices=["train", "reconstruction", "animate"]) 31 | parser.add_argument("--log_dir", default='log', help="path to log into") 32 | parser.add_argument("--checkpoint", default=None, help="path to checkpoint to restore") 33 | parser.add_argument("--device_ids", default="0", type=lambda x: list(map(int, x.split(','))), 34 | help="Names of the devices comma separated.") 35 | parser.add_argument("--verbose", dest="verbose", action="store_true", help="Print model architecture") 36 | parser.set_defaults(verbose=False) 37 | 38 | opt = parser.parse_args() 39 | with open(opt.config) as f: 40 | config = yaml.load(f) 41 | 42 | if opt.checkpoint is not None: 43 | log_dir = os.path.join(*os.path.split(opt.checkpoint)[:-1]) 44 | else: 45 | log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0]) 46 | log_dir += ' ' + strftime("%d_%m_%y_%H.%M.%S", gmtime()) 47 | 48 | generator = OcclusionAwareGenerator(**config['model_params']['generator_params'], 49 | **config['model_params']['common_params']) 50 | 51 | if torch.cuda.is_available(): 52 | generator.to(opt.device_ids[0]) 53 | if opt.verbose: 54 | print(generator) 55 | 56 | discriminator = MultiScaleDiscriminator(**config['model_params']['discriminator_params'], 57 | **config['model_params']['common_params']) 58 | if torch.cuda.is_available(): 59 | discriminator.to(opt.device_ids[0]) 60 | if opt.verbose: 61 | print(discriminator) 62 | 63 | kp_detector = KPDetector(**config['model_params']['kp_detector_params'], 64 | **config['model_params']['common_params']) 65 | 66 | if torch.cuda.is_available(): 67 | kp_detector.to(opt.device_ids[0]) 68 | 69 | if opt.verbose: 70 | print(kp_detector) 71 | 72 | dataset = FramesDataset(is_train=(opt.mode == 'train'), **config['dataset_params']) 73 | 74 | if not os.path.exists(log_dir): 75 | os.makedirs(log_dir) 76 | if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))): 77 | copy(opt.config, log_dir) 78 | 79 | if opt.mode == 'train': 80 | print("Training...") 81 | train(config, generator, discriminator, kp_detector, opt.checkpoint, log_dir, dataset, opt.device_ids) 82 | elif opt.mode == 'reconstruction': 83 | print("Reconstruction...") 84 | reconstruction(config, generator, kp_detector, opt.checkpoint, log_dir, dataset) 85 | elif opt.mode == 'animate': 86 | 
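# The 'animate' branch below hands off to animate(), imported from animate.py at the
# top of this script, mirroring how the 'reconstruction' branch delegates to reconstruction().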
print("Animate...") 87 | animate(config, generator, kp_detector, opt.checkpoint, log_dir, dataset) 88 | -------------------------------------------------------------------------------- /stylegan2/__init__.py: -------------------------------------------------------------------------------- 1 | from . import external_models 2 | from . import metrics 3 | from . import models 4 | from . import project 5 | from . import train 6 | -------------------------------------------------------------------------------- /stylegan2/external_models/__init__.py: -------------------------------------------------------------------------------- 1 | from . import inception 2 | from . import lpips 3 | -------------------------------------------------------------------------------- /stylegan2/external_models/lpips.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code adapted from https://github.com/richzhang/PerceptualSimilarity 3 | 4 | Original License: 5 | Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang 6 | All rights reserved. 7 | 8 | Redistribution and use in source and binary forms, with or without 9 | modification, are permitted provided that the following conditions are met: 10 | 11 | * Redistributions of source code must retain the above copyright notice, this 12 | list of conditions and the following disclaimer. 13 | 14 | * Redistributions in binary form must reproduce the above copyright notice, 15 | this list of conditions and the following disclaimer in the documentation 16 | and/or other materials provided with the distribution. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
28 | """ 29 | import torch 30 | from torch import nn 31 | import torchvision 32 | 33 | 34 | class LPIPS_VGG16(nn.Module): 35 | _FEATURE_IDX = [0, 4, 9, 16, 23, 30] 36 | _LINEAR_WEIGHTS_URL = 'https://github.com/richzhang/PerceptualSimilarity' + \ 37 | '/blob/master/models/weights/v0.1/vgg.pth?raw=true' 38 | 39 | def __init__(self, pixel_min=-1, pixel_max=1): 40 | super(LPIPS_VGG16, self).__init__() 41 | features = torchvision.models.vgg16(pretrained=True).features 42 | self.slices = nn.ModuleList() 43 | linear_weights = torch.utils.model_zoo.load_url(self._LINEAR_WEIGHTS_URL) 44 | for i in range(1, len(self._FEATURE_IDX)): 45 | idx_range = range(self._FEATURE_IDX[i - 1], self._FEATURE_IDX[i]) 46 | self.slices.append(nn.Sequential(*[features[j] for j in idx_range])) 47 | self.linear_layers = nn.ModuleList() 48 | for weight in torch.utils.model_zoo.load_url(self._LINEAR_WEIGHTS_URL).values(): 49 | weight = weight.view(1, -1) 50 | linear = nn.Linear(weight.size(1), 1, bias=False) 51 | linear.weight.data.copy_(weight) 52 | self.linear_layers.append(linear) 53 | self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188]).view(1, -1, 1, 1)) 54 | self.register_buffer('scale', torch.Tensor([.458,.448,.450]).view(1, -1, 1, 1)) 55 | self.pixel_min = pixel_min 56 | self.pixel_max = pixel_max 57 | self.requires_grad_(False) 58 | self.eval() 59 | 60 | def _scale(self, x): 61 | if self.pixel_min != -1 or self.pixel_max != 1: 62 | x = (2*x - self.pixel_min - self.pixel_max) \ 63 | / (self.pixel_max - self.pixel_min) 64 | return (x - self.shift) / self.scale 65 | 66 | @staticmethod 67 | def _normalize_tensor(feature_maps, eps=1e-8): 68 | rnorm = torch.rsqrt(torch.sum(feature_maps ** 2, dim=1, keepdim=True) + eps) 69 | return feature_maps * rnorm 70 | 71 | def forward(self, x0, x1, eps=1e-8): 72 | x0, x1 = self._scale(x0), self._scale(x1) 73 | dist = 0 74 | for slice, linear in zip(self.slices, self.linear_layers): 75 | x0, x1 = slice(x0), slice(x1) 76 | _x0, _x1 = self._normalize_tensor(x0, eps), self._normalize_tensor(x1, eps) 77 | dist += linear(torch.mean((_x0 - _x1) ** 2, dim=[-1, -2])) 78 | return dist.view(-1) 79 | -------------------------------------------------------------------------------- /stylegan2/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from . import fid 2 | from . import ppl 3 | -------------------------------------------------------------------------------- /sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback 13 | -------------------------------------------------------------------------------- /sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | 59 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 60 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 61 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 62 | and passed to a registered callback. 63 | - After receiving the messages, the master device should gather the information and determine to message passed 64 | back to each slave devices. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def __getstate__(self): 79 | return {'master_callback': self._master_callback} 80 | 81 | def __setstate__(self, state): 82 | self.__init__(state['master_callback']) 83 | 84 | def register_slave(self, identifier): 85 | """ 86 | Register an slave device. 87 | 88 | Args: 89 | identifier: an identifier, usually is the device id. 90 | 91 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 92 | 93 | """ 94 | if self._activated: 95 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 96 | self._activated = False 97 | self._registry.clear() 98 | future = FutureResult() 99 | self._registry[identifier] = _MasterRegistry(future) 100 | return SlavePipe(identifier, self._queue, future) 101 | 102 | def run_master(self, master_msg): 103 | """ 104 | Main entry for the master device in each forward pass. 105 | The messages were first collected from each devices (including the master device), and then 106 | an callback will be invoked to compute the message to be sent back to each devices 107 | (including the master device). 108 | 109 | Args: 110 | master_msg: the message that the master want to send to itself. This will be placed as the first 111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 
112 | 113 | Returns: the message to be sent back to the master device. 114 | 115 | """ 116 | self._activated = True 117 | 118 | intermediates = [(0, master_msg)] 119 | for i in range(self.nr_slaves): 120 | intermediates.append(self._queue.get()) 121 | 122 | results = self._master_callback(intermediates) 123 | assert results[0][0] == 0, 'The first result should belongs to the master.' 124 | 125 | for i, res in results: 126 | if i == 0: 127 | continue 128 | self._registry[i].result.put(res) 129 | 130 | for i in range(self.nr_slaves): 131 | assert self._queue.get() is True 132 | 133 | return results[0][1] 134 | 135 | @property 136 | def nr_slaves(self): 137 | return len(self._registry) 138 | -------------------------------------------------------------------------------- /sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphism, we assign each sub-module with a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 38 | of any slave copies. 39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have customized `DataParallel` implementation. 
74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /temp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/singularitai/Morphling/e7a3af969123c0d3c0f3c6f1036a97e9be0b289c/temp.png -------------------------------------------------------------------------------- /text/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 Keith Ito 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 
20 | -------------------------------------------------------------------------------- /text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | import re 3 | from text import cleaners 4 | from text.symbols import symbols 5 | from text.symbols import _punctuation as punctuation_symbols 6 | 7 | # Mappings from symbol to numeric ID and vice versa: 8 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 9 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 10 | 11 | # Regular expression matching text enclosed in curly braces: 12 | _curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)') 13 | 14 | # for arpabet with apostrophe 15 | _apostrophe = re.compile(r"(?=\S*['])([a-zA-Z'-]+)") 16 | 17 | def text_to_sequence(text): 18 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 19 | 20 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded 21 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." 22 | 23 | Args: 24 | text: string to convert to a sequence 25 | cleaner_names: names of the cleaner functions to run the text through 26 | 27 | Returns: 28 | List of integers corresponding to the symbols in the text 29 | ''' 30 | sequence = [] 31 | 32 | # Check for curly braces and treat their contents as ARPAbet: 33 | while len(text): 34 | m = _curly_re.match(text) 35 | if not m: 36 | sequence += _symbols_to_sequence(text) 37 | break 38 | sequence += _symbols_to_sequence(m.group(1)) 39 | sequence += _arpabet_to_sequence(m.group(2)) 40 | text = m.group(3) 41 | 42 | return sequence 43 | 44 | 45 | def sequence_to_text(sequence): 46 | '''Converts a sequence of IDs back to a string''' 47 | result = '' 48 | for symbol_id in sequence: 49 | if symbol_id in _id_to_symbol: 50 | s = _id_to_symbol[symbol_id] 51 | # Enclose ARPAbet back in curly braces: 52 | if len(s) > 1 and s[0] == '@': 53 | s = '{%s}' % s[1:] 54 | result += s 55 | return result.replace('}{', ' ') 56 | 57 | 58 | def _clean_text(text, cleaner_names): 59 | for name in cleaner_names: 60 | cleaner = getattr(cleaners, name) 61 | if not cleaner: 62 | raise Exception('Unknown cleaner: %s' % name) 63 | text = cleaner(text) 64 | 65 | return text 66 | 67 | 68 | def _symbols_to_sequence(symbols): 69 | return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)] 70 | 71 | 72 | def _arpabet_to_sequence(text): 73 | return _symbols_to_sequence(['@' + s for s in text.split()]) 74 | 75 | 76 | def _should_keep_symbol(s): 77 | return s in _symbol_to_id and s is not '_' and s is not '~' 78 | 79 | 80 | def get_arpabet(word, cmudict, index=0): 81 | re_start_punc = r"\A\W+" 82 | re_end_punc = r"\W+\Z" 83 | 84 | start_symbols = re.findall(re_start_punc, word) 85 | if len(start_symbols): 86 | start_symbols = start_symbols[0] 87 | word = word[len(start_symbols):] 88 | else: 89 | start_symbols = '' 90 | 91 | end_symbols = re.findall(re_end_punc, word) 92 | if len(end_symbols): 93 | end_symbols = end_symbols[0] 94 | word = word[:-len(end_symbols)] 95 | else: 96 | end_symbols = '' 97 | 98 | arpabet_suffix = '' 99 | if _apostrophe.match(word) is not None and word.lower() != "it's" and word.lower()[-1] == 's': 100 | word = word[:-2] 101 | arpabet_suffix = ' Z' 102 | arpabet = None if word.lower() in HETERONYMS else cmudict.lookup(word) 103 | 104 | if arpabet is not None: 105 | return start_symbols + '{%s}' % (arpabet[index] + arpabet_suffix) + end_symbols 106 | else: 107 | return start_symbols + 
word + end_symbols 108 | 109 | 110 | def files_to_list(filename): 111 | """ 112 | Takes a text file of filenames and makes a list of filenames 113 | """ 114 | with open(filename, encoding='utf-8') as f: 115 | files = f.readlines() 116 | 117 | files = [f.rstrip() for f in files] 118 | return files 119 | 120 | HETERONYMS = set(files_to_list('data/heteronyms')) 121 | -------------------------------------------------------------------------------- /text/acronyms.py: -------------------------------------------------------------------------------- 1 | import re 2 | from .cmudict import CMUDict 3 | 4 | _letter_to_arpabet = { 5 | 'A': 'EY1', 6 | 'B': 'B IY1', 7 | 'C': 'S IY1', 8 | 'D': 'D IY1', 9 | 'E': 'IY1', 10 | 'F': 'EH1 F', 11 | 'G': 'JH IY1', 12 | 'H': 'EY1 CH', 13 | 'I': 'AY1', 14 | 'J': 'JH EY1', 15 | 'K': 'K EY1', 16 | 'L': 'EH1 L', 17 | 'M': 'EH1 M', 18 | 'N': 'EH1 N', 19 | 'O': 'OW1', 20 | 'P': 'P IY1', 21 | 'Q': 'K Y UW1', 22 | 'R': 'AA1 R', 23 | 'S': 'EH1 S', 24 | 'T': 'T IY1', 25 | 'U': 'Y UW1', 26 | 'V': 'V IY1', 27 | 'X': 'EH1 K S', 28 | 'Y': 'W AY1', 29 | 'W': 'D AH1 B AH0 L Y UW0', 30 | 'Z': 'Z IY1', 31 | 's': 'Z' 32 | } 33 | 34 | # must ignore roman numerals 35 | _acronym_re = re.compile(r'([A-Z][A-Z]+)s?|([A-Z]\.([A-Z]\.)+s?)') 36 | cmudict = CMUDict('data/cmudict_dictionary', keep_ambiguous=False) 37 | 38 | 39 | def _expand_acronyms(m, add_spaces=True): 40 | acronym = m.group(0) 41 | 42 | # remove dots if they exist 43 | acronym = re.sub('\.', '', acronym) 44 | 45 | acronym = "".join(acronym.split()) 46 | arpabet = cmudict.lookup(acronym) 47 | 48 | if arpabet is None: 49 | acronym = list(acronym) 50 | arpabet = ["{" + _letter_to_arpabet[letter] + "}" for letter in acronym] 51 | # temporary fix 52 | if arpabet[-1] == '{Z}' and len(arpabet) > 1: 53 | arpabet[-2] = arpabet[-2][:-1] + ' ' + arpabet[-1][1:] 54 | del arpabet[-1] 55 | 56 | arpabet = ' '.join(arpabet) 57 | else: 58 | arpabet = "{" + arpabet[0] + "}" 59 | 60 | return arpabet 61 | 62 | 63 | def normalize_acronyms(text): 64 | text = re.sub(_acronym_re, _expand_acronyms, text) 65 | return text 66 | -------------------------------------------------------------------------------- /text/cleaners.py: -------------------------------------------------------------------------------- 1 | """ adapted from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Cleaners are transformations that run over the input text at both training and eval time. 5 | 6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 8 | 1. "english_cleaners" for English text 9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 12 | the symbols in symbols.py to match your data). 13 | ''' 14 | 15 | import re 16 | from unidecode import unidecode 17 | from .numbers import normalize_numbers 18 | from .acronyms import normalize_acronyms 19 | from .datestime import normalize_datestime 20 | 21 | 22 | # Regular expression matching whitespace: 23 | _whitespace_re = re.compile(r'\s+') 24 | 25 | # List of (regular expression, replacement) pairs for abbreviations: 26 | _abbreviations = [(re.compile('\\b%s\\.' 
% x[0], re.IGNORECASE), x[1]) for x in [ 27 | ('mrs', 'misess'), 28 | ('ms', 'miss'), 29 | ('mr', 'mister'), 30 | ('dr', 'doctor'), 31 | ('st', 'saint'), 32 | ('co', 'company'), 33 | ('jr', 'junior'), 34 | ('maj', 'major'), 35 | ('gen', 'general'), 36 | ('drs', 'doctors'), 37 | ('rev', 'reverend'), 38 | ('lt', 'lieutenant'), 39 | ('hon', 'honorable'), 40 | ('sgt', 'sergeant'), 41 | ('capt', 'captain'), 42 | ('esq', 'esquire'), 43 | ('ltd', 'limited'), 44 | ('col', 'colonel'), 45 | ('ft', 'fort'), 46 | ]] 47 | 48 | _safe_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [ 49 | ('no', 'number'), 50 | ]] 51 | 52 | 53 | 54 | def expand_abbreviations(text): 55 | for regex, replacement in _abbreviations: 56 | text = re.sub(regex, replacement, text) 57 | return text 58 | 59 | def expand_safe_abbreviations(text): 60 | for regex, replacement in _safe_abbreviations: 61 | text = re.sub(regex, replacement, text) 62 | return text 63 | 64 | def expand_numbers(text): 65 | return normalize_numbers(text) 66 | 67 | 68 | def expand_acronyms(text): 69 | return normalize_acronyms(text) 70 | 71 | 72 | def expand_datestime(text): 73 | return normalize_datestime(text) 74 | 75 | 76 | def lowercase(text): 77 | return text.lower() 78 | 79 | 80 | def collapse_whitespace(text): 81 | return re.sub(_whitespace_re, ' ', text) 82 | 83 | 84 | def separate_acronyms(text): 85 | text = re.sub(r"([0-9]+)([a-zA-Z]+)", r"\1 \2", text) 86 | text = re.sub(r"([a-zA-Z]+)([0-9]+)", r"\1 \2", text) 87 | return text 88 | 89 | 90 | def remove_hyphens(text): 91 | text = re.sub(r'(?<=\w)(-)(?=\w)', ' ', text) 92 | return text 93 | 94 | 95 | def convert_to_ascii(text): 96 | return unidecode(text) 97 | 98 | 99 | def basic_cleaners(text): 100 | '''Basic pipeline that collapses whitespace without transliteration.''' 101 | text = lowercase(text) 102 | text = collapse_whitespace(text) 103 | return text 104 | 105 | 106 | def transliteration_cleaners(text): 107 | '''Pipeline for non-English text that transliterates to ASCII.''' 108 | text = convert_to_ascii(text) 109 | text = lowercase(text) 110 | text = collapse_whitespace(text) 111 | return text 112 | 113 | 114 | def flowtron_cleaners(text): 115 | text = collapse_whitespace(text) 116 | text = remove_hyphens(text) 117 | text = expand_datestime(text) 118 | text = expand_numbers(text) 119 | text = expand_safe_abbreviations(text) 120 | text = expand_acronyms(text) 121 | return text 122 | 123 | 124 | def english_cleaners(text): 125 | '''Pipeline for English text, with number and abbreviation expansion.''' 126 | text = convert_to_ascii(text) 127 | text = lowercase(text) 128 | text = expand_numbers(text) 129 | text = expand_abbreviations(text) 130 | text = collapse_whitespace(text) 131 | return text 132 | -------------------------------------------------------------------------------- /text/cmudict.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | 5 | 6 | valid_symbols = [ 7 | 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2', 8 | 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2', 9 | 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY', 10 | 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1', 11 | 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 12 | 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 
'UH2', 'UW', 13 | 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH' 14 | ] 15 | 16 | _valid_symbol_set = set(valid_symbols) 17 | 18 | 19 | class CMUDict: 20 | '''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict''' 21 | def __init__(self, file_or_path, keep_ambiguous=True): 22 | if isinstance(file_or_path, str): 23 | with open(file_or_path, encoding='latin-1') as f: 24 | entries = _parse_cmudict(f) 25 | else: 26 | entries = _parse_cmudict(file_or_path) 27 | if not keep_ambiguous: 28 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1} 29 | self._entries = entries 30 | 31 | 32 | def __len__(self): 33 | return len(self._entries) 34 | 35 | 36 | def lookup(self, word): 37 | '''Returns list of ARPAbet pronunciations of the given word.''' 38 | return self._entries.get(word.upper()) 39 | 40 | 41 | 42 | _alt_re = re.compile(r'\([0-9]+\)') 43 | 44 | 45 | def _parse_cmudict(file): 46 | cmudict = {} 47 | for line in file: 48 | if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"): 49 | parts = line.split(' ') 50 | word = re.sub(_alt_re, '', parts[0]) 51 | pronunciation = _get_pronunciation(parts[1]) 52 | if pronunciation: 53 | if word in cmudict: 54 | cmudict[word].append(pronunciation) 55 | else: 56 | cmudict[word] = [pronunciation] 57 | return cmudict 58 | 59 | 60 | def _get_pronunciation(s): 61 | parts = s.strip().split(' ') 62 | for part in parts: 63 | if part not in _valid_symbol_set: 64 | return None 65 | return ' '.join(parts) 66 | -------------------------------------------------------------------------------- /text/datestime.py: -------------------------------------------------------------------------------- 1 | import re 2 | _ampm_re = re.compile(r'([0-9]|0[0-9]|1[0-9]|2[0-3]):?([0-5][0-9])?\s*([AaPp][Mm]\b)') 3 | 4 | 5 | def _expand_ampm(m): 6 | matches = list(m.groups(0)) 7 | txt = matches[0] 8 | if matches[1] == 0 or matches[1] == '0' or matches[1] == '00': 9 | pass 10 | else: 11 | txt += ' ' + matches[1] 12 | 13 | if matches[2][0] == 'a': 14 | txt += ' AM' 15 | elif matches[2][0] == 'p': 16 | txt += ' PM' 17 | 18 | return txt 19 | 20 | 21 | def normalize_datestime(text): 22 | text = re.sub(_ampm_re, _expand_ampm, text) 23 | text = re.sub(r"([0-9]|0[0-9]|1[0-9]|2[0-3]):([0-5][0-9])?", r"\1 \2", text) 24 | return text 25 | -------------------------------------------------------------------------------- /text/numbers.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import inflect 4 | import re 5 | _large_numbers = '(trillion|billion|million|thousand|hundred)' 6 | _measurements = '(f|c|k|d)' 7 | _measurements_key = {'f': 'fahrenheit', 'c': 'celsius', 'k': 'thousand', 'd': 'd'} 8 | _inflect = inflect.engine() 9 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') 10 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') 11 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') 12 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+[ ]?{}?)'.format(_large_numbers), re.IGNORECASE) 13 | _measurement_re = re.compile(r'([0-9\.\,]*[0-9]+(\s)?{}\b)'.format(_measurements), re.IGNORECASE) 14 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') 15 | _number_re = re.compile(r"[0-9]+'s|[0-9]+") 16 | 17 | def _remove_commas(m): 18 | return m.group(1).replace(',', '') 19 | 20 | 21 | def _expand_decimal_point(m): 22 | return m.group(1).replace('.', ' point ') 23 | 24 | 25 | def _expand_dollars(m): 26 | match = m.group(1) 27 | 28 | # check for million, 
billion, etc... 29 | parts = match.split(' ') 30 | if len(parts) == 2 and len(parts[1]) > 0 and parts[1] in _large_numbers: 31 | return "{} {} {} ".format(parts[0], parts[1], 'dollars') 32 | 33 | parts = parts[0].split('.') 34 | if len(parts) > 2: 35 | return match + " dollars" # Unexpected format 36 | dollars = int(parts[0]) if parts[0] else 0 37 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 38 | if dollars and cents: 39 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 40 | cent_unit = 'cent' if cents == 1 else 'cents' 41 | return "{} {}, {} {} ".format( 42 | _inflect.number_to_words(dollars), dollar_unit, 43 | _inflect.number_to_words(cents), cent_unit) 44 | elif dollars: 45 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 46 | return "{} {} ".format(_inflect.number_to_words(dollars), dollar_unit) 47 | elif cents: 48 | cent_unit = 'cent' if cents == 1 else 'cents' 49 | return "{} {} ".format(_inflect.number_to_words(cents), cent_unit) 50 | else: 51 | return 'zero dollars' 52 | 53 | 54 | def _expand_ordinal(m): 55 | return _inflect.number_to_words(m.group(0)) 56 | 57 | 58 | def _expand_measurement(m): 59 | _, number, measurement = re.split('(\d+(?:\.\d+)?)', m.group(0)) 60 | number = _inflect.number_to_words(number) 61 | measurement = "".join(measurement.split()) 62 | measurement = _measurements_key[measurement.lower()] 63 | return "{} {}".format(number, measurement) 64 | 65 | 66 | def _expand_number(m): 67 | _, number, suffix = re.split(r"(\d+(?:'\d+)?)", m.group(0)) 68 | num = int(number) 69 | if num > 1000 and num < 3000: 70 | if num == 2000: 71 | text = 'two thousand' 72 | elif num > 2000 and num < 2010: 73 | text = 'two thousand ' + _inflect.number_to_words(num % 100) 74 | elif num % 100 == 0: 75 | text = _inflect.number_to_words(num // 100) + ' hundred' 76 | else: 77 | num = _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') 78 | num = re.sub(r'-', ' ', num) 79 | text = num 80 | else: 81 | num = _inflect.number_to_words(num, andword='') 82 | num = re.sub(r'-', ' ', num) 83 | num = re.sub(r',', '', num) 84 | text = num 85 | 86 | if suffix == "'s" and text[-1] == 'y': 87 | text = text[:-1] + 'ies' 88 | 89 | return text 90 | 91 | 92 | def normalize_numbers(text): 93 | text = re.sub(_comma_number_re, _remove_commas, text) 94 | text = re.sub(_pounds_re, r'\1 pounds', text) 95 | text = re.sub(_dollars_re, _expand_dollars, text) 96 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 97 | text = re.sub(_ordinal_re, _expand_ordinal, text) 98 | text = re.sub(_measurement_re, _expand_measurement, text) 99 | text = re.sub(_number_re, _expand_number, text) 100 | return text 101 | -------------------------------------------------------------------------------- /text/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Defines the set of symbols used in text input to the model. 5 | 6 | The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. ''' 7 | from text import cmudict 8 | 9 | _punctuation = '!\'",.:;? 
' 10 | _math = '#%&*+-/[]()' 11 | _special = '_@©°½—₩€$' 12 | _accented = 'áçéêëñöøćž' 13 | _numbers = '0123456789' 14 | _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' 15 | 16 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): 17 | _arpabet = ['@' + s for s in cmudict.valid_symbols] 18 | 19 | # Export all symbols: 20 | symbols = list(_punctuation + _math + _special + _accented + _numbers + _letters) + _arpabet 21 | -------------------------------------------------------------------------------- /train_firstorder.py: -------------------------------------------------------------------------------- 1 | from tqdm import trange 2 | import torch 3 | 4 | from torch.utils.data import DataLoader 5 | 6 | from logger_firstorder import Logger 7 | from modules.model import GeneratorFullModel, DiscriminatorFullModel 8 | 9 | from torch.optim.lr_scheduler import MultiStepLR 10 | 11 | from sync_batchnorm import DataParallelWithCallback 12 | 13 | from frames_dataset import DatasetRepeater 14 | 15 | 16 | def train(config, generator, discriminator, kp_detector, checkpoint, log_dir, dataset, device_ids): 17 | train_params = config['train_params'] 18 | 19 | optimizer_generator = torch.optim.Adam(generator.parameters(), lr=train_params['lr_generator'], betas=(0.5, 0.999)) 20 | optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=train_params['lr_discriminator'], betas=(0.5, 0.999)) 21 | optimizer_kp_detector = torch.optim.Adam(kp_detector.parameters(), lr=train_params['lr_kp_detector'], betas=(0.5, 0.999)) 22 | 23 | if checkpoint is not None: 24 | start_epoch = Logger.load_cpk(checkpoint, generator, discriminator, kp_detector, 25 | optimizer_generator, optimizer_discriminator, 26 | None if train_params['lr_kp_detector'] == 0 else optimizer_kp_detector) 27 | else: 28 | start_epoch = 0 29 | 30 | scheduler_generator = MultiStepLR(optimizer_generator, train_params['epoch_milestones'], gamma=0.1, 31 | last_epoch=start_epoch - 1) 32 | scheduler_discriminator = MultiStepLR(optimizer_discriminator, train_params['epoch_milestones'], gamma=0.1, 33 | last_epoch=start_epoch - 1) 34 | scheduler_kp_detector = MultiStepLR(optimizer_kp_detector, train_params['epoch_milestones'], gamma=0.1, 35 | last_epoch=-1 + start_epoch * (train_params['lr_kp_detector'] != 0)) 36 | 37 | if 'num_repeats' in train_params or train_params['num_repeats'] != 1: 38 | dataset = DatasetRepeater(dataset, train_params['num_repeats']) 39 | dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=True, num_workers=6, drop_last=True) 40 | 41 | generator_full = GeneratorFullModel(kp_detector, generator, discriminator, train_params) 42 | discriminator_full = DiscriminatorFullModel(kp_detector, generator, discriminator, train_params) 43 | 44 | if torch.cuda.is_available(): 45 | generator_full = DataParallelWithCallback(generator_full, device_ids=device_ids) 46 | discriminator_full = DataParallelWithCallback(discriminator_full, device_ids=device_ids) 47 | 48 | with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'], checkpoint_freq=train_params['checkpoint_freq']) as logger: 49 | for epoch in trange(start_epoch, train_params['num_epochs']): 50 | for x in dataloader: 51 | losses_generator, generated = generator_full(x) 52 | 53 | loss_values = [val.mean() for val in losses_generator.values()] 54 | loss = sum(loss_values) 55 | 56 | loss.backward() 57 | optimizer_generator.step() 58 | optimizer_generator.zero_grad() 59 | 
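# The single loss.backward() above fills gradients for every module wrapped by
# GeneratorFullModel (both the generator and the keypoint detector), so the
# keypoint-detector optimizer can step here without a separate backward pass.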
optimizer_kp_detector.step() 60 | optimizer_kp_detector.zero_grad() 61 | 62 | if train_params['loss_weights']['generator_gan'] != 0: 63 | optimizer_discriminator.zero_grad() 64 | losses_discriminator = discriminator_full(x, generated) 65 | loss_values = [val.mean() for val in losses_discriminator.values()] 66 | loss = sum(loss_values) 67 | 68 | loss.backward() 69 | optimizer_discriminator.step() 70 | optimizer_discriminator.zero_grad() 71 | else: 72 | losses_discriminator = {} 73 | 74 | losses_generator.update(losses_discriminator) 75 | losses = {key: value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()} 76 | logger.log_iter(losses=losses) 77 | 78 | scheduler_generator.step() 79 | scheduler_discriminator.step() 80 | scheduler_kp_detector.step() 81 | 82 | logger.log_epoch(epoch, {'generator': generator, 83 | 'discriminator': discriminator, 84 | 'kp_detector': kp_detector, 85 | 'optimizer_generator': optimizer_generator, 86 | 'optimizer_discriminator': optimizer_discriminator, 87 | 'optimizer_kp_detector': optimizer_kp_detector}, inp=x, out=generated) 88 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.io.wavfile import read 3 | import torch 4 | 5 | 6 | def get_mask_from_lengths(lengths): 7 | max_len = torch.max(lengths).item() 8 | ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len)) 9 | mask = (ids < lengths.unsqueeze(1)).bool() 10 | return mask 11 | 12 | 13 | def load_wav_to_torch(full_path): 14 | sampling_rate, data = read(full_path) 15 | return torch.FloatTensor(data.astype(np.float32)), sampling_rate 16 | 17 | 18 | def load_filepaths_and_text(filename, split="|"): 19 | with open(filename, encoding='utf-8') as f: 20 | filepaths_and_text = [line.strip().split(split) for line in f] 21 | return filepaths_and_text 22 | 23 | 24 | def to_gpu(x): 25 | x = x.contiguous() 26 | 27 | if torch.cuda.is_available(): 28 | x = x.cuda(non_blocking=True) 29 | return torch.autograd.Variable(x) 30 | --------------------------------------------------------------------------------
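The helpers in utils.py are shared by the Flowtron/Tacotron-style training and inference scripts. Below is a minimal usage sketch, not part of the repository: the filelist path is only illustrative, and a CUDA device is assumed because get_mask_from_lengths allocates a torch.cuda.LongTensor.

import torch
from utils import get_mask_from_lengths, load_filepaths_and_text

# Parse a pipe-separated filelist into [audio_path, transcript, ...] entries.
entries = load_filepaths_and_text('filelists/example_filelist.txt')  # illustrative path
print(len(entries), entries[0])

# Build a boolean padding mask for a batch of variable-length sequences:
# mask[i, j] is True while position j lies inside sequence i, False in the padded tail.
lengths = torch.tensor([5, 3, 4]).cuda()
mask = get_mask_from_lengths(lengths)  # shape (3, 5)
print(mask)

to_gpu applies the same pattern to arbitrary tensors, moving them onto the GPU with non_blocking=True only when CUDA is available.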