├── .gitattributes ├── .gitignore ├── LICENSE ├── MultiSpecific.ipynb ├── README.md ├── SimSwap colab.ipynb ├── cog.yaml ├── crop_224 ├── 1_source.jpg ├── 2.jpg ├── 6.jpg ├── cage.jpg ├── dnl.jpg ├── ds.jpg ├── gdg.jpg ├── gy.jpg ├── hzc.jpg ├── hzxc.jpg ├── james.jpg ├── jl.jpg ├── lcw.jpg ├── ljm.jpg ├── ljm2.jpg ├── ljm3.jpg ├── mars2.jpg ├── mouth_open.jpg ├── mtdm.jpg ├── trump.jpg ├── wlh.jpg ├── zjl.jpg ├── zrf.jpg └── zxy.jpg ├── data └── data_loader_Swapping.py ├── demo_file ├── Iron_man.jpg ├── multi_people.jpg ├── multi_people_1080p.mp4 ├── multispecific │ ├── DST_01.jpg │ ├── DST_02.jpg │ ├── DST_03.jpg │ ├── SRC_01.png │ ├── SRC_02.png │ └── SRC_03.png ├── specific1.png ├── specific2.png └── specific3.png ├── docs ├── css │ ├── bootstrap-theme.min.css │ ├── bootstrap.min.css │ ├── ie10-viewport-bug-workaround.css │ ├── jumbotron.css │ └── page.css ├── favicon.ico ├── fonts │ ├── glyphicons-halflings-regular.eot │ ├── glyphicons-halflings-regular.svg │ ├── glyphicons-halflings-regular.ttf │ ├── glyphicons-halflings-regular.woff │ └── glyphicons-halflings-regular.woff2 ├── guidance │ ├── preparation.md │ └── usage.md ├── img │ ├── LRGT_201110059_201110091.webp │ ├── anni.webp │ ├── chenglong.webp │ ├── girl2-RGB.png │ ├── girl2.gif │ ├── id │ │ ├── Iron_man.jpg │ │ ├── anni.jpg │ │ ├── chenglong.jpg │ │ ├── wuyifan.png │ │ ├── zhoujielun.jpg │ │ └── zhuyin.jpg │ ├── logo.png │ ├── logo1.png │ ├── logo2.png │ ├── mama_mask_short.webp │ ├── mama_mask_wuyifan_short.webp │ ├── multi_face_comparison.png │ ├── new.gif │ ├── nrsig.png │ ├── result_whole_swap_multispecific_512.jpg │ ├── results1.PNG │ ├── title.png │ ├── total.PNG │ ├── vggface2_hq_compare.png │ ├── video.webp │ ├── zhoujielun.webp │ └── zhuyin.webp ├── index.html └── js │ ├── bootstrap.min.js │ ├── ie-emulation-modes-warning.js │ ├── ie10-viewport-bug-workaround.js │ ├── vendor │ └── jquery.min.js │ └── which-image.js ├── download-weights.sh ├── insightface_func ├── __init__.py ├── face_detect_crop_multi.py ├── face_detect_crop_single.py └── utils │ └── face_align_ffhqandnewarc.py ├── models ├── __init__.py ├── arcface_models.py ├── base_model.py ├── config.py ├── fs_model.py ├── fs_networks.py ├── fs_networks_512.py ├── fs_networks_fix.py ├── models.py ├── networks.py ├── pix2pixHD_model.py ├── projected_model.py ├── projectionhead.py └── ui_model.py ├── options ├── base_options.py ├── test_options.py └── train_options.py ├── output └── result.jpg ├── parsing_model ├── model.py └── resnet.py ├── pg_modules ├── blocks.py ├── diffaug.py ├── projected_discriminator.py └── projector.py ├── predict.py ├── simswaplogo ├── simswaplogo.png └── socialbook_logo.2020.357eed90add7705e54a8.svg ├── test_one_image.py ├── test_video_swap_multispecific.py ├── test_video_swapmulti.py ├── test_video_swapsingle.py ├── test_video_swapspecific.py ├── test_wholeimage_swap_multispecific.py ├── test_wholeimage_swapmulti.py ├── test_wholeimage_swapsingle.py ├── test_wholeimage_swapspecific.py ├── train.ipynb ├── train.py └── util ├── add_watermark.py ├── html.py ├── image_pool.py ├── json_config.py ├── logo_class.py ├── norm.py ├── plot.py ├── reverse2original.py ├── save_heatmap.py ├── util.py ├── videoswap.py ├── videoswap_multispecific.py ├── videoswap_specific.py └── visualizer.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- 
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | docs/ppt/ 132 | checkpoints/ 133 | *.tar 134 | *.patch 135 | *.zip 136 | *.avi 137 | *.pdf 138 | *.pptx 139 | 140 | *.pth 141 | *.onnx 142 | wandb/ 143 | temp_results/ 144 | output/*.* -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | build: 2 | gpu: true 3 | python_version: "3.8" 4 | system_packages: 5 | - "libgl1-mesa-glx" 6 | - "libglib2.0-0" 7 | python_packages: 8 | - "imageio==2.9.0" 9 | - "torch==1.8.0" 10 | - "torchvision==0.9.0" 11 | - "numpy==1.21.1" 12 | - "insightface==0.2.1" 13 | - "ipython==7.21.0" 14 | - "Pillow==8.3.1" 15 | - "opencv-python==4.5.3.56" 16 | - "Fraction==1.5.1" 17 | - "onnxruntime-gpu==1.8.1" 18 | - "moviepy==1.0.3" 19 | 20 | predict: "predict.py:Predictor" 21 | -------------------------------------------------------------------------------- /crop_224/1_source.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/1_source.jpg -------------------------------------------------------------------------------- /crop_224/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/2.jpg -------------------------------------------------------------------------------- /crop_224/6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/6.jpg -------------------------------------------------------------------------------- /crop_224/cage.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/cage.jpg -------------------------------------------------------------------------------- /crop_224/dnl.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/dnl.jpg -------------------------------------------------------------------------------- /crop_224/ds.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/ds.jpg -------------------------------------------------------------------------------- /crop_224/gdg.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/gdg.jpg -------------------------------------------------------------------------------- 
/crop_224/gy.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/gy.jpg -------------------------------------------------------------------------------- /crop_224/hzc.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/hzc.jpg -------------------------------------------------------------------------------- /crop_224/hzxc.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/hzxc.jpg -------------------------------------------------------------------------------- /crop_224/james.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/james.jpg -------------------------------------------------------------------------------- /crop_224/jl.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/jl.jpg -------------------------------------------------------------------------------- /crop_224/lcw.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/lcw.jpg -------------------------------------------------------------------------------- /crop_224/ljm.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/ljm.jpg -------------------------------------------------------------------------------- /crop_224/ljm2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/ljm2.jpg -------------------------------------------------------------------------------- /crop_224/ljm3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/ljm3.jpg -------------------------------------------------------------------------------- /crop_224/mars2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/mars2.jpg -------------------------------------------------------------------------------- /crop_224/mouth_open.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/mouth_open.jpg -------------------------------------------------------------------------------- /crop_224/mtdm.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/mtdm.jpg -------------------------------------------------------------------------------- /crop_224/trump.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/trump.jpg -------------------------------------------------------------------------------- /crop_224/wlh.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/wlh.jpg -------------------------------------------------------------------------------- /crop_224/zjl.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/zjl.jpg -------------------------------------------------------------------------------- /crop_224/zrf.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/zrf.jpg -------------------------------------------------------------------------------- /crop_224/zxy.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/zxy.jpg -------------------------------------------------------------------------------- /data/data_loader_Swapping.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import torch 4 | import random 5 | from PIL import Image 6 | from torch.utils import data 7 | from torchvision import transforms as T 8 | 9 | class data_prefetcher(): 10 | def __init__(self, loader): 11 | self.loader = loader 12 | self.dataiter = iter(loader) 13 | self.stream = torch.cuda.Stream() 14 | self.mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(1,3,1,1) 15 | self.std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(1,3,1,1) 16 | # With Amp, it isn't necessary to manually convert data to half. 
17 | # if args.fp16: 18 | # self.mean = self.mean.half() 19 | # self.std = self.std.half() 20 | self.num_images = len(loader) 21 | self.preload() 22 | 23 | def preload(self): 24 | try: 25 | self.src_image1, self.src_image2 = next(self.dataiter) 26 | except StopIteration: 27 | self.dataiter = iter(self.loader) 28 | self.src_image1, self.src_image2 = next(self.dataiter) 29 | 30 | with torch.cuda.stream(self.stream): 31 | self.src_image1 = self.src_image1.cuda(non_blocking=True) 32 | self.src_image1 = self.src_image1.sub_(self.mean).div_(self.std) 33 | self.src_image2 = self.src_image2.cuda(non_blocking=True) 34 | self.src_image2 = self.src_image2.sub_(self.mean).div_(self.std) 35 | 36 | def next(self): 37 | torch.cuda.current_stream().wait_stream(self.stream) 38 | src_image1 = self.src_image1 39 | src_image2 = self.src_image2 40 | self.preload() 41 | return src_image1, src_image2 42 | 43 | def __len__(self): 44 | """Return the number of images.""" 45 | return self.num_images 46 | 47 | class SwappingDataset(data.Dataset): 48 | """Dataset class for the Artworks dataset and content dataset.""" 49 | 50 | def __init__(self, 51 | image_dir, 52 | img_transform, 53 | subffix='jpg', 54 | random_seed=1234): 55 | """Initialize and preprocess the Swapping dataset.""" 56 | self.image_dir = image_dir 57 | self.img_transform = img_transform 58 | self.subffix = subffix 59 | self.dataset = [] 60 | self.random_seed = random_seed 61 | self.preprocess() 62 | self.num_images = len(self.dataset) 63 | 64 | def preprocess(self): 65 | """Preprocess the Swapping dataset.""" 66 | print("processing Swapping dataset images...") 67 | 68 | temp_path = os.path.join(self.image_dir,'*/') 69 | pathes = glob.glob(temp_path) 70 | self.dataset = [] 71 | for dir_item in pathes: 72 | join_path = glob.glob(os.path.join(dir_item,'*.jpg')) 73 | print("processing %s"%dir_item,end='\r') 74 | temp_list = [] 75 | for item in join_path: 76 | temp_list.append(item) 77 | self.dataset.append(temp_list) 78 | random.seed(self.random_seed) 79 | random.shuffle(self.dataset) 80 | print('Finished preprocessing the Swapping dataset, total dirs number: %d...'%len(self.dataset)) 81 | 82 | def __getitem__(self, index): 83 | """Return two src domain images and two dst domain images.""" 84 | dir_tmp1 = self.dataset[index] 85 | dir_tmp1_len = len(dir_tmp1) 86 | 87 | filename1 = dir_tmp1[random.randint(0,dir_tmp1_len-1)] 88 | filename2 = dir_tmp1[random.randint(0,dir_tmp1_len-1)] 89 | image1 = self.img_transform(Image.open(filename1)) 90 | image2 = self.img_transform(Image.open(filename2)) 91 | return image1, image2 92 | 93 | def __len__(self): 94 | """Return the number of images.""" 95 | return self.num_images 96 | 97 | def GetLoader( dataset_roots, 98 | batch_size=16, 99 | dataloader_workers=8, 100 | random_seed = 1234 101 | ): 102 | """Build and return a data loader.""" 103 | 104 | num_workers = dataloader_workers 105 | data_root = dataset_roots 106 | random_seed = random_seed 107 | 108 | c_transforms = [] 109 | 110 | c_transforms.append(T.ToTensor()) 111 | c_transforms = T.Compose(c_transforms) 112 | 113 | content_dataset = SwappingDataset( 114 | data_root, 115 | c_transforms, 116 | "jpg", 117 | random_seed) 118 | content_data_loader = data.DataLoader(dataset=content_dataset,batch_size=batch_size, 119 | drop_last=True,shuffle=True,num_workers=num_workers,pin_memory=True) 120 | prefetcher = data_prefetcher(content_data_loader) 121 | return prefetcher 122 | 123 | def denorm(x): 124 | out = (x + 1) / 2 125 | return out.clamp_(0, 1) 
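Note (editorial addition, not part of the repository file above): `GetLoader` wraps `SwappingDataset` in a `DataLoader` and the CUDA-side `data_prefetcher`. A minimal, hypothetical usage sketch follows; the dataset path and loop length are placeholders, and a CUDA device is assumed because the prefetcher normalizes batches on a `torch.cuda.Stream`.

```python
# Hypothetical sketch of driving the loader above from a training loop.
# "/path/to/vggface2" is a placeholder: it must contain one sub-folder of
# .jpg images per identity, as SwappingDataset.preprocess() expects.
from data.data_loader_Swapping import GetLoader

train_loader = GetLoader("/path/to/vggface2", batch_size=16, dataloader_workers=8)

for step in range(10000):
    # Each call returns two ImageNet-normalized CUDA tensors of shape
    # [batch_size, 3, H, W], both sampled from the same identity folder.
    src_image1, src_image2 = train_loader.next()
    # ... forward src_image1 / src_image2 through the generator and
    # identity encoder here ...
```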
-------------------------------------------------------------------------------- /demo_file/Iron_man.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/demo_file/Iron_man.jpg -------------------------------------------------------------------------------- /demo_file/multi_people.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/demo_file/multi_people.jpg -------------------------------------------------------------------------------- /demo_file/multi_people_1080p.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/demo_file/multi_people_1080p.mp4 -------------------------------------------------------------------------------- /demo_file/multispecific/DST_01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/demo_file/multispecific/DST_01.jpg -------------------------------------------------------------------------------- /demo_file/multispecific/DST_02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/demo_file/multispecific/DST_02.jpg -------------------------------------------------------------------------------- /demo_file/multispecific/DST_03.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/demo_file/multispecific/DST_03.jpg -------------------------------------------------------------------------------- /demo_file/multispecific/SRC_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/demo_file/multispecific/SRC_01.png -------------------------------------------------------------------------------- /demo_file/multispecific/SRC_02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/demo_file/multispecific/SRC_02.png -------------------------------------------------------------------------------- /demo_file/multispecific/SRC_03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/demo_file/multispecific/SRC_03.png -------------------------------------------------------------------------------- /demo_file/specific1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/demo_file/specific1.png -------------------------------------------------------------------------------- /demo_file/specific2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/demo_file/specific2.png 
-------------------------------------------------------------------------------- /demo_file/specific3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/demo_file/specific3.png -------------------------------------------------------------------------------- /docs/css/ie10-viewport-bug-workaround.css: -------------------------------------------------------------------------------- 1 | /*! 2 | * IE10 viewport hack for Surface/desktop Windows 8 bug 3 | * Copyright 2014-2015 Twitter, Inc. 4 | * Licensed under MIT (https://github.com/twbs/bootstrap/blob/master/LICENSE) 5 | */ 6 | 7 | /* 8 | * See the Getting Started docs for more information: 9 | * http://getbootstrap.com/getting-started/#support-ie10-width 10 | */ 11 | @-ms-viewport { width: device-width; } 12 | @-o-viewport { width: device-width; } 13 | @viewport { width: device-width; } 14 | -------------------------------------------------------------------------------- /docs/css/jumbotron.css: -------------------------------------------------------------------------------- 1 | /* Move down content because we have a fixed navbar that is 50px tall */ 2 | body { 3 | padding-top: 50px; 4 | padding-bottom: 20px; 5 | } 6 | -------------------------------------------------------------------------------- /docs/css/page.css: -------------------------------------------------------------------------------- 1 | .which-image-container { 2 | display: flex; 3 | flex-wrap: wrap; 4 | align-items: center; 5 | flex-direction: column; 6 | justify-content: space-between; 7 | height: 100%; 8 | } 9 | 10 | .which-image-container :nth-child(2) { 11 | display: flex; 12 | flex-wrap: wrap; 13 | align-items: center; 14 | flex-direction: column; 15 | justify-content: flex-end; 16 | } 17 | 18 | .which-image { 19 | display: inline-block; 20 | flex: 1; 21 | flex-basis: 45%; 22 | margin: 1% 1% 1% 1%; 23 | width: 100%; 24 | height: auto; 25 | 26 | } 27 | 28 | .which-image img { 29 | float: right; 30 | width: 100%; 31 | } 32 | 33 | .image-display2 img { 34 | float: right; 35 | width: 100%; 36 | } 37 | /* .image-display{ 38 | align-items: center; 39 | } 40 | .image-display2{ 41 | align-items: center; 42 | } */ 43 | .select-show { 44 | border-style: dashed; 45 | border-width: 2px; 46 | border-color: purple; 47 | 48 | /* padding: 2px; */ 49 | } -------------------------------------------------------------------------------- /docs/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/favicon.ico -------------------------------------------------------------------------------- /docs/fonts/glyphicons-halflings-regular.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/fonts/glyphicons-halflings-regular.eot -------------------------------------------------------------------------------- /docs/fonts/glyphicons-halflings-regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/fonts/glyphicons-halflings-regular.ttf -------------------------------------------------------------------------------- /docs/fonts/glyphicons-halflings-regular.woff: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/fonts/glyphicons-halflings-regular.woff
--------------------------------------------------------------------------------
/docs/fonts/glyphicons-halflings-regular.woff2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/fonts/glyphicons-halflings-regular.woff2
--------------------------------------------------------------------------------
/docs/guidance/preparation.md:
--------------------------------------------------------------------------------
 1 | 
 2 | # Preparation
 3 | 
 4 | ### Installation
 5 | **We highly recommend that you use Anaconda for installation**
 6 | ```
 7 | conda create -n simswap python=3.6
 8 | conda activate simswap
 9 | conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=10.2 -c pytorch
10 | (optional): pip install --ignore-installed imageio
11 | pip install insightface==0.2.1 onnxruntime moviepy
12 | (optional): pip install onnxruntime-gpu (if you want to reduce the inference time; onnxruntime-gpu can be difficult to install, and the specific version required may depend on your machine and CUDA version.)
13 | ```
14 | - ***We have updated this preparation document. The main change is that the GPU version of onnxruntime is now supported. If you have already configured the environment, you can now run pip install onnxruntime-gpu to increase the computing speed.***
15 | - We use the face detection and alignment methods from **[insightface](https://github.com/deepinsight/insightface)** for image preprocessing. Please download the relevant files and unzip them to ./insightface_func/models from [this link](https://onedrive.live.com/?authkey=%21ADJ0aAOSsc90neY&cid=4A83B6B633B029CC&id=4A83B6B633B029CC%215837&parId=4A83B6B633B029CC%215834&action=locate).
16 | - We use the face parsing from **[face-parsing.PyTorch](https://github.com/zllrunning/face-parsing.PyTorch)** for image postprocessing. Please download the relevant file and place it in ./parsing_model/checkpoint from [this link](https://drive.google.com/file/d/154JgKpzCPW82qINcVieuPH3fZ2e0P812/view).
17 | - The PyTorch and CUDA versions above are the most recommended; other versions may also work.
18 | - Using a different version of insightface is not recommended. Please use this specific version.
19 | - These settings have been tested and are valid on both Windows and Ubuntu.
20 | 
21 | ### Pretrained model
22 | There are two archive files in the drive: **checkpoints.zip** and **arcface_checkpoint.tar**
23 | 
24 | - **Copy the arcface_checkpoint.tar into ./arcface_model**
25 | - **Unzip checkpoints.zip, place it in the root dir ./**
26 | 
27 | [[Google Drive]](https://drive.google.com/drive/folders/1jV6_0FIMPC53FZ2HzZNJZGMe55bbu17R?usp=sharing)
28 | [[Baidu Drive]](https://pan.baidu.com/s/1wFV11RVZMHqd-ky4YpLdcA) Password: ```jd2v```
29 | 
30 | **Simswap 512 (optional)**
31 | 
32 | The checkpoint of the **Simswap 512 beta version** has been uploaded as a [Github release](https://github.com/neuralchen/SimSwap/releases/download/512_beta/512.zip). If you want to try Simswap 512, feel free to do so.
33 | - **Unzip 512.zip, place it in the root dir ./checkpoints**.
34 | 
35 | 
36 | ### Note
37 | We expect users to have a GPU with at least 3 GB of memory.
For those who do not, we provide [[Colab Notebook implementation]](https://colab.research.google.com/github/neuralchen/SimSwap/blob/main/SimSwap%20colab.ipynb). 38 | -------------------------------------------------------------------------------- /docs/img/LRGT_201110059_201110091.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/LRGT_201110059_201110091.webp -------------------------------------------------------------------------------- /docs/img/anni.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/anni.webp -------------------------------------------------------------------------------- /docs/img/chenglong.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/chenglong.webp -------------------------------------------------------------------------------- /docs/img/girl2-RGB.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/girl2-RGB.png -------------------------------------------------------------------------------- /docs/img/girl2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/girl2.gif -------------------------------------------------------------------------------- /docs/img/id/Iron_man.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/id/Iron_man.jpg -------------------------------------------------------------------------------- /docs/img/id/anni.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/id/anni.jpg -------------------------------------------------------------------------------- /docs/img/id/chenglong.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/id/chenglong.jpg -------------------------------------------------------------------------------- /docs/img/id/wuyifan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/id/wuyifan.png -------------------------------------------------------------------------------- /docs/img/id/zhoujielun.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/id/zhoujielun.jpg -------------------------------------------------------------------------------- /docs/img/id/zhuyin.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/id/zhuyin.jpg 
-------------------------------------------------------------------------------- /docs/img/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/logo.png -------------------------------------------------------------------------------- /docs/img/logo1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/logo1.png -------------------------------------------------------------------------------- /docs/img/logo2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/logo2.png -------------------------------------------------------------------------------- /docs/img/mama_mask_short.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/mama_mask_short.webp -------------------------------------------------------------------------------- /docs/img/mama_mask_wuyifan_short.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/mama_mask_wuyifan_short.webp -------------------------------------------------------------------------------- /docs/img/multi_face_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/multi_face_comparison.png -------------------------------------------------------------------------------- /docs/img/new.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/new.gif -------------------------------------------------------------------------------- /docs/img/nrsig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/nrsig.png -------------------------------------------------------------------------------- /docs/img/result_whole_swap_multispecific_512.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/result_whole_swap_multispecific_512.jpg -------------------------------------------------------------------------------- /docs/img/results1.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/results1.PNG -------------------------------------------------------------------------------- /docs/img/title.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/title.png -------------------------------------------------------------------------------- /docs/img/total.PNG: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/total.PNG -------------------------------------------------------------------------------- /docs/img/vggface2_hq_compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/vggface2_hq_compare.png -------------------------------------------------------------------------------- /docs/img/video.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/video.webp -------------------------------------------------------------------------------- /docs/img/zhoujielun.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/zhoujielun.webp -------------------------------------------------------------------------------- /docs/img/zhuyin.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/zhuyin.webp -------------------------------------------------------------------------------- /docs/js/ie-emulation-modes-warning.js: -------------------------------------------------------------------------------- 1 | // NOTICE!! DO NOT USE ANY OF THIS JAVASCRIPT 2 | // IT'S JUST JUNK FOR OUR DOCS! 3 | // ++++++++++++++++++++++++++++++++++++++++++ 4 | /*! 5 | * Copyright 2014-2015 Twitter, Inc. 6 | * 7 | * Licensed under the Creative Commons Attribution 3.0 Unported License. For 8 | * details, see https://creativecommons.org/licenses/by/3.0/. 9 | */ 10 | // Intended to prevent false-positive bug reports about Bootstrap not working properly in old versions of IE due to folks testing using IE's unreliable emulation modes. 11 | (function () { 12 | 'use strict'; 13 | 14 | function emulatedIEMajorVersion() { 15 | var groups = /MSIE ([0-9.]+)/.exec(window.navigator.userAgent) 16 | if (groups === null) { 17 | return null 18 | } 19 | var ieVersionNum = parseInt(groups[1], 10) 20 | var ieMajorVersion = Math.floor(ieVersionNum) 21 | return ieMajorVersion 22 | } 23 | 24 | function actualNonEmulatedIEMajorVersion() { 25 | // Detects the actual version of IE in use, even if it's in an older-IE emulation mode. 
26 | // IE JavaScript conditional compilation docs: https://msdn.microsoft.com/library/121hztk3%28v=vs.94%29.aspx 27 | // @cc_on docs: https://msdn.microsoft.com/library/8ka90k2e%28v=vs.94%29.aspx 28 | var jscriptVersion = new Function('/*@cc_on return @_jscript_version; @*/')() // jshint ignore:line 29 | if (jscriptVersion === undefined) { 30 | return 11 // IE11+ not in emulation mode 31 | } 32 | if (jscriptVersion < 9) { 33 | return 8 // IE8 (or lower; haven't tested on IE<8) 34 | } 35 | return jscriptVersion // IE9 or IE10 in any mode, or IE11 in non-IE11 mode 36 | } 37 | 38 | var ua = window.navigator.userAgent 39 | if (ua.indexOf('Opera') > -1 || ua.indexOf('Presto') > -1) { 40 | return // Opera, which might pretend to be IE 41 | } 42 | var emulated = emulatedIEMajorVersion() 43 | if (emulated === null) { 44 | return // Not IE 45 | } 46 | var nonEmulated = actualNonEmulatedIEMajorVersion() 47 | 48 | if (emulated !== nonEmulated) { 49 | window.alert('WARNING: You appear to be using IE' + nonEmulated + ' in IE' + emulated + ' emulation mode.\nIE emulation modes can behave significantly differently from ACTUAL older versions of IE.\nPLEASE DON\'T FILE BOOTSTRAP BUGS based on testing in IE emulation modes!') 50 | } 51 | })(); 52 | -------------------------------------------------------------------------------- /docs/js/ie10-viewport-bug-workaround.js: -------------------------------------------------------------------------------- 1 | /*! 2 | * IE10 viewport hack for Surface/desktop Windows 8 bug 3 | * Copyright 2014-2015 Twitter, Inc. 4 | * Licensed under MIT (https://github.com/twbs/bootstrap/blob/master/LICENSE) 5 | */ 6 | 7 | // See the Getting Started docs for more information: 8 | // http://getbootstrap.com/getting-started/#support-ie10-width 9 | 10 | (function () { 11 | 'use strict'; 12 | 13 | if (navigator.userAgent.match(/IEMobile\/10\.0/)) { 14 | var msViewportStyle = document.createElement('style') 15 | msViewportStyle.appendChild( 16 | document.createTextNode( 17 | '@-ms-viewport{width:auto!important}' 18 | ) 19 | ) 20 | document.querySelector('head').appendChild(msViewportStyle) 21 | } 22 | 23 | })(); 24 | -------------------------------------------------------------------------------- /docs/js/which-image.js: -------------------------------------------------------------------------------- 1 | /* 2 | * @FilePath: \SimSwap\docs\js\which-image.js 3 | * @Author: Ziang Liu 4 | * @Date: 2021-07-03 16:34:56 5 | * @LastEditors: AceSix 6 | * @LastEditTime: 2021-07-20 00:46:27 7 | * Copyright (C) 2021 SJTU. All rights reserved. 
8 | */ 9 | 10 | 11 | 12 | function select_source(number) { 13 | var items = ['anni', 'chenglong', 'zhoujielun', 'zhuyin']; 14 | var item_id = items[number]; 15 | 16 | for (i = 0; i < 4; i++) { 17 | if (number == i) { 18 | document.getElementById(items[i]).style.borderWidth = '5px'; 19 | document.getElementById(items[i]).style.borderColor = 'red'; 20 | document.getElementById(items[i]).style.borderStyle = 'outset'; 21 | } else { 22 | document.getElementById(items[i]).style.border = 'none'; 23 | } 24 | } 25 | document.getElementById('jiroujinlun').src = './img/' + item_id + '.webp'; 26 | 27 | } 28 | 29 | function select_source2(number) { 30 | var items = ['Iron_man', 'wuyifan']; 31 | var item_id = items[number]; 32 | 33 | for (i = 0; i < 2; i++) { 34 | if (number == i) { 35 | document.getElementById(items[i]).style.borderWidth = '5px'; 36 | document.getElementById(items[i]).style.borderColor = 'red'; 37 | document.getElementById(items[i]).style.borderStyle = 'outset'; 38 | } else { 39 | document.getElementById(items[i]).style.border = 'none'; 40 | } 41 | } 42 | if (item_id=='Iron_man'){ 43 | document.getElementById('mama').src = './img/mama_mask_short.webp'; 44 | } 45 | else{ 46 | document.getElementById('mama').src = './img/mama_mask_wuyifan_short.webp'; 47 | 48 | } 49 | 50 | } -------------------------------------------------------------------------------- /download-weights.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | wget -P ./arcface_model https://github.com/neuralchen/SimSwap/releases/download/1.0/arcface_checkpoint.tar 3 | wget https://github.com/neuralchen/SimSwap/releases/download/1.0/checkpoints.zip 4 | unzip ./checkpoints.zip -d ./checkpoints 5 | rm checkpoints.zip 6 | wget --no-check-certificate "https://sh23tw.dm.files.1drv.com/y4mmGiIkNVigkSwOKDcV3nwMJulRGhbtHdkheehR5TArc52UjudUYNXAEvKCii2O5LAmzGCGK6IfleocxuDeoKxDZkNzDRSt4ZUlEt8GlSOpCXAFEkBwaZimtWGDRbpIGpb_pz9Nq5jATBQpezBS6G_UtspWTkgrXHHxhviV2nWy8APPx134zOZrUIbkSF6xnsqzs3uZ_SEX_m9Rey0ykpx9w" -O antelope.zip 7 | mkdir -p insightface_func/models 8 | unzip ./antelope.zip -d ./insightface_func/models/ 9 | rm antelope.zip 10 | -------------------------------------------------------------------------------- /insightface_func/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/insightface_func/__init__.py -------------------------------------------------------------------------------- /insightface_func/face_detect_crop_multi.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: Naiyuan liu 3 | Github: https://github.com/NNNNAI 4 | Date: 2021-11-23 17:03:58 5 | LastEditors: Naiyuan liu 6 | LastEditTime: 2021-11-24 16:45:41 7 | Description: 8 | ''' 9 | from __future__ import division 10 | import collections 11 | import numpy as np 12 | import glob 13 | import os 14 | import os.path as osp 15 | import cv2 16 | from insightface.model_zoo import model_zoo 17 | from insightface_func.utils import face_align_ffhqandnewarc as face_align 18 | 19 | __all__ = ['Face_detect_crop', 'Face'] 20 | 21 | Face = collections.namedtuple('Face', [ 22 | 'bbox', 'kps', 'det_score', 'embedding', 'gender', 'age', 23 | 'embedding_norm', 'normed_embedding', 24 | 'landmark' 25 | ]) 26 | 27 | Face.__new__.__defaults__ = (None, ) * len(Face._fields) 28 | 29 | 30 | class Face_detect_crop: 31 | def __init__(self, name, 
root='~/.insightface_func/models'): 32 | self.models = {} 33 | root = os.path.expanduser(root) 34 | onnx_files = glob.glob(osp.join(root, name, '*.onnx')) 35 | onnx_files = sorted(onnx_files) 36 | for onnx_file in onnx_files: 37 | if onnx_file.find('_selfgen_')>0: 38 | #print('ignore:', onnx_file) 39 | continue 40 | model = model_zoo.get_model(onnx_file) 41 | if model.taskname not in self.models: 42 | print('find model:', onnx_file, model.taskname) 43 | self.models[model.taskname] = model 44 | else: 45 | print('duplicated model task type, ignore:', onnx_file, model.taskname) 46 | del model 47 | assert 'detection' in self.models 48 | self.det_model = self.models['detection'] 49 | 50 | 51 | def prepare(self, ctx_id, det_thresh=0.5, det_size=(640, 640), mode ='None'): 52 | self.det_thresh = det_thresh 53 | self.mode = mode 54 | assert det_size is not None 55 | print('set det-size:', det_size) 56 | self.det_size = det_size 57 | for taskname, model in self.models.items(): 58 | if taskname=='detection': 59 | model.prepare(ctx_id, input_size=det_size) 60 | else: 61 | model.prepare(ctx_id) 62 | 63 | def get(self, img, crop_size, max_num=0): 64 | bboxes, kpss = self.det_model.detect(img, 65 | threshold=self.det_thresh, 66 | max_num=max_num, 67 | metric='default') 68 | if bboxes.shape[0] == 0: 69 | return None 70 | ret = [] 71 | # for i in range(bboxes.shape[0]): 72 | # bbox = bboxes[i, 0:4] 73 | # det_score = bboxes[i, 4] 74 | # kps = None 75 | # if kpss is not None: 76 | # kps = kpss[i] 77 | # M, _ = face_align.estimate_norm(kps, crop_size, mode ='None') 78 | # align_img = cv2.warpAffine(img, M, (crop_size, crop_size), borderValue=0.0) 79 | align_img_list = [] 80 | M_list = [] 81 | for i in range(bboxes.shape[0]): 82 | kps = None 83 | if kpss is not None: 84 | kps = kpss[i] 85 | M, _ = face_align.estimate_norm(kps, crop_size, mode = self.mode) 86 | align_img = cv2.warpAffine(img, M, (crop_size, crop_size), borderValue=0.0) 87 | align_img_list.append(align_img) 88 | M_list.append(M) 89 | 90 | # det_score = bboxes[..., 4] 91 | 92 | # best_index = np.argmax(det_score) 93 | 94 | # kps = None 95 | # if kpss is not None: 96 | # kps = kpss[best_index] 97 | # M, _ = face_align.estimate_norm(kps, crop_size, mode ='None') 98 | # align_img = cv2.warpAffine(img, M, (crop_size, crop_size), borderValue=0.0) 99 | 100 | return align_img_list, M_list 101 | -------------------------------------------------------------------------------- /insightface_func/face_detect_crop_single.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: Naiyuan liu 3 | Github: https://github.com/NNNNAI 4 | Date: 2021-11-23 17:03:58 5 | LastEditors: Naiyuan liu 6 | LastEditTime: 2021-11-24 16:46:04 7 | Description: 8 | ''' 9 | from __future__ import division 10 | import collections 11 | import numpy as np 12 | import glob 13 | import os 14 | import os.path as osp 15 | import cv2 16 | from insightface.model_zoo import model_zoo 17 | from insightface_func.utils import face_align_ffhqandnewarc as face_align 18 | 19 | __all__ = ['Face_detect_crop', 'Face'] 20 | 21 | Face = collections.namedtuple('Face', [ 22 | 'bbox', 'kps', 'det_score', 'embedding', 'gender', 'age', 23 | 'embedding_norm', 'normed_embedding', 24 | 'landmark' 25 | ]) 26 | 27 | Face.__new__.__defaults__ = (None, ) * len(Face._fields) 28 | 29 | 30 | class Face_detect_crop: 31 | def __init__(self, name, root='~/.insightface_func/models'): 32 | self.models = {} 33 | root = os.path.expanduser(root) 34 | onnx_files = 
glob.glob(osp.join(root, name, '*.onnx')) 35 | onnx_files = sorted(onnx_files) 36 | for onnx_file in onnx_files: 37 | if onnx_file.find('_selfgen_')>0: 38 | #print('ignore:', onnx_file) 39 | continue 40 | model = model_zoo.get_model(onnx_file) 41 | if model.taskname not in self.models: 42 | print('find model:', onnx_file, model.taskname) 43 | self.models[model.taskname] = model 44 | else: 45 | print('duplicated model task type, ignore:', onnx_file, model.taskname) 46 | del model 47 | assert 'detection' in self.models 48 | self.det_model = self.models['detection'] 49 | 50 | 51 | def prepare(self, ctx_id, det_thresh=0.5, det_size=(640, 640), mode ='None'): 52 | self.det_thresh = det_thresh 53 | self.mode = mode 54 | assert det_size is not None 55 | print('set det-size:', det_size) 56 | self.det_size = det_size 57 | for taskname, model in self.models.items(): 58 | if taskname=='detection': 59 | model.prepare(ctx_id, input_size=det_size) 60 | else: 61 | model.prepare(ctx_id) 62 | 63 | def get(self, img, crop_size, max_num=0): 64 | bboxes, kpss = self.det_model.detect(img, 65 | threshold=self.det_thresh, 66 | max_num=max_num, 67 | metric='default') 68 | if bboxes.shape[0] == 0: 69 | return None 70 | # ret = [] 71 | # for i in range(bboxes.shape[0]): 72 | # bbox = bboxes[i, 0:4] 73 | # det_score = bboxes[i, 4] 74 | # kps = None 75 | # if kpss is not None: 76 | # kps = kpss[i] 77 | # M, _ = face_align.estimate_norm(kps, crop_size, mode ='None') 78 | # align_img = cv2.warpAffine(img, M, (crop_size, crop_size), borderValue=0.0) 79 | # for i in range(bboxes.shape[0]): 80 | # kps = None 81 | # if kpss is not None: 82 | # kps = kpss[i] 83 | # M, _ = face_align.estimate_norm(kps, crop_size, mode ='None') 84 | # align_img = cv2.warpAffine(img, M, (crop_size, crop_size), borderValue=0.0) 85 | 86 | det_score = bboxes[..., 4] 87 | 88 | # select the face with the hightest detection score 89 | best_index = np.argmax(det_score) 90 | 91 | kps = None 92 | if kpss is not None: 93 | kps = kpss[best_index] 94 | M, _ = face_align.estimate_norm(kps, crop_size, mode = self.mode) 95 | align_img = cv2.warpAffine(img, M, (crop_size, crop_size), borderValue=0.0) 96 | 97 | return [align_img], [M] 98 | -------------------------------------------------------------------------------- /insightface_func/utils/face_align_ffhqandnewarc.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: Naiyuan liu 3 | Github: https://github.com/NNNNAI 4 | Date: 2021-11-15 19:42:42 5 | LastEditors: Naiyuan liu 6 | LastEditTime: 2021-11-15 20:01:47 7 | Description: 8 | ''' 9 | 10 | import cv2 11 | import numpy as np 12 | from skimage import transform as trans 13 | 14 | src1 = np.array([[51.642, 50.115], [57.617, 49.990], [35.740, 69.007], 15 | [51.157, 89.050], [57.025, 89.702]], 16 | dtype=np.float32) 17 | #<--left 18 | src2 = np.array([[45.031, 50.118], [65.568, 50.872], [39.677, 68.111], 19 | [45.177, 86.190], [64.246, 86.758]], 20 | dtype=np.float32) 21 | 22 | #---frontal 23 | src3 = np.array([[39.730, 51.138], [72.270, 51.138], [56.000, 68.493], 24 | [42.463, 87.010], [69.537, 87.010]], 25 | dtype=np.float32) 26 | 27 | #-->right 28 | src4 = np.array([[46.845, 50.872], [67.382, 50.118], [72.737, 68.111], 29 | [48.167, 86.758], [67.236, 86.190]], 30 | dtype=np.float32) 31 | 32 | #-->right profile 33 | src5 = np.array([[54.796, 49.990], [60.771, 50.115], [76.673, 69.007], 34 | [55.388, 89.702], [61.257, 89.050]], 35 | dtype=np.float32) 36 | 37 | src = np.array([src1, src2, src3, src4, src5]) 
38 | src_map = src 39 | 40 | ffhq_src = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935], 41 | [201.26117, 371.41043], [313.08905, 371.15118]]) 42 | ffhq_src = np.expand_dims(ffhq_src, axis=0) 43 | 44 | # arcface_src = np.array( 45 | # [[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366], 46 | # [41.5493, 92.3655], [70.7299, 92.2041]], 47 | # dtype=np.float32) 48 | 49 | # arcface_src = np.expand_dims(arcface_src, axis=0) 50 | 51 | # In[66]: 52 | 53 | 54 | # lmk is prediction; src is template 55 | def estimate_norm(lmk, image_size=112, mode='ffhq'): 56 | assert lmk.shape == (5, 2) 57 | tform = trans.SimilarityTransform() 58 | lmk_tran = np.insert(lmk, 2, values=np.ones(5), axis=1) 59 | min_M = [] 60 | min_index = [] 61 | min_error = float('inf') 62 | if mode == 'ffhq': 63 | # assert image_size == 112 64 | src = ffhq_src * image_size / 512 65 | else: 66 | src = src_map * image_size / 112 67 | for i in np.arange(src.shape[0]): 68 | tform.estimate(lmk, src[i]) 69 | M = tform.params[0:2, :] 70 | results = np.dot(M, lmk_tran.T) 71 | results = results.T 72 | error = np.sum(np.sqrt(np.sum((results - src[i])**2, axis=1))) 73 | # print(error) 74 | if error < min_error: 75 | min_error = error 76 | min_M = M 77 | min_index = i 78 | return min_M, min_index 79 | 80 | 81 | def norm_crop(img, landmark, image_size=112, mode='ffhq'): 82 | if mode == 'Both': 83 | M_None, _ = estimate_norm(landmark, image_size, mode = 'newarc') 84 | M_ffhq, _ = estimate_norm(landmark, image_size, mode='ffhq') 85 | warped_None = cv2.warpAffine(img, M_None, (image_size, image_size), borderValue=0.0) 86 | warped_ffhq = cv2.warpAffine(img, M_ffhq, (image_size, image_size), borderValue=0.0) 87 | return warped_ffhq, warped_None 88 | else: 89 | M, pose_index = estimate_norm(landmark, image_size, mode) 90 | warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0) 91 | return warped 92 | 93 | def square_crop(im, S): 94 | if im.shape[0] > im.shape[1]: 95 | height = S 96 | width = int(float(im.shape[1]) / im.shape[0] * S) 97 | scale = float(S) / im.shape[0] 98 | else: 99 | width = S 100 | height = int(float(im.shape[0]) / im.shape[1] * S) 101 | scale = float(S) / im.shape[1] 102 | resized_im = cv2.resize(im, (width, height)) 103 | det_im = np.zeros((S, S, 3), dtype=np.uint8) 104 | det_im[:resized_im.shape[0], :resized_im.shape[1], :] = resized_im 105 | return det_im, scale 106 | 107 | 108 | def transform(data, center, output_size, scale, rotation): 109 | scale_ratio = scale 110 | rot = float(rotation) * np.pi / 180.0 111 | #translation = (output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio) 112 | t1 = trans.SimilarityTransform(scale=scale_ratio) 113 | cx = center[0] * scale_ratio 114 | cy = center[1] * scale_ratio 115 | t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy)) 116 | t3 = trans.SimilarityTransform(rotation=rot) 117 | t4 = trans.SimilarityTransform(translation=(output_size / 2, 118 | output_size / 2)) 119 | t = t1 + t2 + t3 + t4 120 | M = t.params[0:2] 121 | cropped = cv2.warpAffine(data, 122 | M, (output_size, output_size), 123 | borderValue=0.0) 124 | return cropped, M 125 | 126 | 127 | def trans_points2d(pts, M): 128 | new_pts = np.zeros(shape=pts.shape, dtype=np.float32) 129 | for i in range(pts.shape[0]): 130 | pt = pts[i] 131 | new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32) 132 | new_pt = np.dot(M, new_pt) 133 | #print('new_pt', new_pt.shape, new_pt) 134 | new_pts[i] = new_pt[0:2] 135 | 136 | return new_pts 137 | 138 | 139 | 
def trans_points3d(pts, M): 140 | scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1]) 141 | #print(scale) 142 | new_pts = np.zeros(shape=pts.shape, dtype=np.float32) 143 | for i in range(pts.shape[0]): 144 | pt = pts[i] 145 | new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32) 146 | new_pt = np.dot(M, new_pt) 147 | #print('new_pt', new_pt.shape, new_pt) 148 | new_pts[i][0:2] = new_pt[0:2] 149 | new_pts[i][2] = pts[i][2] * scale 150 | 151 | return new_pts 152 | 153 | 154 | def trans_points(pts, M): 155 | if pts.shape[1] == 2: 156 | return trans_points2d(pts, M) 157 | else: 158 | return trans_points3d(pts, M) 159 | 160 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .arcface_models import ArcMarginModel 2 | from .arcface_models import ResNet 3 | from .arcface_models import IRBlock 4 | from .arcface_models import SEBlock -------------------------------------------------------------------------------- /models/arcface_models.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | from torch.nn import Parameter 6 | from .config import device, num_classes 7 | 8 | 9 | 10 | class SEBlock(nn.Module): 11 | def __init__(self, channel, reduction=16): 12 | super(SEBlock, self).__init__() 13 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 14 | self.fc = nn.Sequential( 15 | nn.Linear(channel, channel // reduction), 16 | nn.PReLU(), 17 | nn.Linear(channel // reduction, channel), 18 | nn.Sigmoid() 19 | ) 20 | 21 | def forward(self, x): 22 | b, c, _, _ = x.size() 23 | y = self.avg_pool(x).view(b, c) 24 | y = self.fc(y).view(b, c, 1, 1) 25 | return x * y 26 | 27 | 28 | class IRBlock(nn.Module): 29 | expansion = 1 30 | 31 | def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True): 32 | super(IRBlock, self).__init__() 33 | self.bn0 = nn.BatchNorm2d(inplanes) 34 | self.conv1 = conv3x3(inplanes, inplanes) 35 | self.bn1 = nn.BatchNorm2d(inplanes) 36 | self.prelu = nn.PReLU() 37 | self.conv2 = conv3x3(inplanes, planes, stride) 38 | self.bn2 = nn.BatchNorm2d(planes) 39 | self.downsample = downsample 40 | self.stride = stride 41 | self.use_se = use_se 42 | if self.use_se: 43 | self.se = SEBlock(planes) 44 | 45 | def forward(self, x): 46 | residual = x 47 | out = self.bn0(x) 48 | out = self.conv1(out) 49 | out = self.bn1(out) 50 | out = self.prelu(out) 51 | 52 | out = self.conv2(out) 53 | out = self.bn2(out) 54 | if self.use_se: 55 | out = self.se(out) 56 | 57 | if self.downsample is not None: 58 | residual = self.downsample(x) 59 | 60 | out += residual 61 | out = self.prelu(out) 62 | 63 | return out 64 | 65 | 66 | class ResNet(nn.Module): 67 | 68 | def __init__(self, block, layers, use_se=True): 69 | self.inplanes = 64 70 | self.use_se = use_se 71 | super(ResNet, self).__init__() 72 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, bias=False) 73 | self.bn1 = nn.BatchNorm2d(64) 74 | self.prelu = nn.PReLU() 75 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2) 76 | self.layer1 = self._make_layer(block, 64, layers[0]) 77 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 78 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 79 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 80 | self.bn2 = nn.BatchNorm2d(512) 81 | self.dropout = nn.Dropout() 82 | self.fc = nn.Linear(512 * 
7 * 7, 512) 83 | self.bn3 = nn.BatchNorm1d(512) 84 | 85 | for m in self.modules(): 86 | if isinstance(m, nn.Conv2d): 87 | nn.init.xavier_normal_(m.weight) 88 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): 89 | nn.init.constant_(m.weight, 1) 90 | nn.init.constant_(m.bias, 0) 91 | elif isinstance(m, nn.Linear): 92 | nn.init.xavier_normal_(m.weight) 93 | nn.init.constant_(m.bias, 0) 94 | 95 | def _make_layer(self, block, planes, blocks, stride=1): 96 | downsample = None 97 | if stride != 1 or self.inplanes != planes * block.expansion: 98 | downsample = nn.Sequential( 99 | nn.Conv2d(self.inplanes, planes * block.expansion, 100 | kernel_size=1, stride=stride, bias=False), 101 | nn.BatchNorm2d(planes * block.expansion), 102 | ) 103 | 104 | layers = [] 105 | layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se)) 106 | self.inplanes = planes 107 | for i in range(1, blocks): 108 | layers.append(block(self.inplanes, planes, use_se=self.use_se)) 109 | 110 | return nn.Sequential(*layers) 111 | 112 | def forward(self, x): 113 | x = self.conv1(x) 114 | x = self.bn1(x) 115 | x = self.prelu(x) 116 | x = self.maxpool(x) 117 | 118 | x = self.layer1(x) 119 | x = self.layer2(x) 120 | x = self.layer3(x) 121 | x = self.layer4(x) 122 | 123 | x = self.bn2(x) 124 | x = self.dropout(x) 125 | # feature = x 126 | x = x.view(x.size(0), -1) 127 | x = self.fc(x) 128 | x = self.bn3(x) 129 | 130 | return x 131 | 132 | 133 | class ArcMarginModel(nn.Module): 134 | def __init__(self, args): 135 | super(ArcMarginModel, self).__init__() 136 | 137 | self.weight = Parameter(torch.FloatTensor(num_classes, args.emb_size)) 138 | nn.init.xavier_uniform_(self.weight) 139 | 140 | self.easy_margin = args.easy_margin 141 | self.m = args.margin_m 142 | self.s = args.margin_s 143 | 144 | self.cos_m = math.cos(self.m) 145 | self.sin_m = math.sin(self.m) 146 | self.th = math.cos(math.pi - self.m) 147 | self.mm = math.sin(math.pi - self.m) * self.m 148 | 149 | def forward(self, input, label): 150 | x = F.normalize(input) 151 | W = F.normalize(self.weight) 152 | cosine = F.linear(x, W) 153 | sine = torch.sqrt(1.0 - torch.pow(cosine, 2)) 154 | phi = cosine * self.cos_m - sine * self.sin_m # cos(theta + m) 155 | if self.easy_margin: 156 | phi = torch.where(cosine > 0, phi, cosine) 157 | else: 158 | phi = torch.where(cosine > self.th, phi, cosine - self.mm) 159 | one_hot = torch.zeros(cosine.size(), device=device) 160 | one_hot.scatter_(1, label.view(-1, 1).long(), 1) 161 | output = (one_hot * phi) + ((1.0 - one_hot) * cosine) 162 | output *= self.s 163 | return output -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import sys 4 | 5 | class BaseModel(torch.nn.Module): 6 | def name(self): 7 | return 'BaseModel' 8 | 9 | def initialize(self, opt): 10 | self.opt = opt 11 | self.gpu_ids = opt.gpu_ids 12 | self.isTrain = opt.isTrain 13 | self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor 14 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) 15 | 16 | def set_input(self, input): 17 | self.input = input 18 | 19 | def forward(self): 20 | pass 21 | 22 | # used in test time, no backprop 23 | def test(self): 24 | pass 25 | 26 | def get_image_paths(self): 27 | pass 28 | 29 | def optimize_parameters(self): 30 | pass 31 | 32 | def get_current_visuals(self): 33 | return self.input 34 | 35 | def 
get_current_errors(self): 36 | return {} 37 | 38 | def save(self, label): 39 | pass 40 | 41 | # helper saving function that can be used by subclasses 42 | def save_network(self, network, network_label, epoch_label, gpu_ids=None): 43 | save_filename = '{}_net_{}.pth'.format(epoch_label, network_label) 44 | save_path = os.path.join(self.save_dir, save_filename) 45 | torch.save(network.cpu().state_dict(), save_path) 46 | if torch.cuda.is_available(): 47 | network.cuda() 48 | 49 | def save_optim(self, network, network_label, epoch_label, gpu_ids=None): 50 | save_filename = '{}_optim_{}.pth'.format(epoch_label, network_label) 51 | save_path = os.path.join(self.save_dir, save_filename) 52 | torch.save(network.state_dict(), save_path) 53 | 54 | 55 | # helper loading function that can be used by subclasses 56 | def load_network(self, network, network_label, epoch_label, save_dir=''): 57 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 58 | if not save_dir: 59 | save_dir = self.save_dir 60 | save_path = os.path.join(save_dir, save_filename) 61 | if not os.path.isfile(save_path): 62 | print('%s not exists yet!' % save_path) 63 | if network_label == 'G': 64 | raise('Generator must exist!') 65 | else: 66 | #network.load_state_dict(torch.load(save_path)) 67 | try: 68 | network.load_state_dict(torch.load(save_path)) 69 | except: 70 | pretrained_dict = torch.load(save_path) 71 | model_dict = network.state_dict() 72 | try: 73 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 74 | network.load_state_dict(pretrained_dict) 75 | if self.opt.verbose: 76 | print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label) 77 | except: 78 | print('Pretrained network %s has fewer layers; The following are not initialized:' % network_label) 79 | for k, v in pretrained_dict.items(): 80 | if v.size() == model_dict[k].size(): 81 | model_dict[k] = v 82 | 83 | if sys.version_info >= (3,0): 84 | not_initialized = set() 85 | else: 86 | from sets import Set 87 | not_initialized = Set() 88 | 89 | for k, v in model_dict.items(): 90 | if k not in pretrained_dict or v.size() != pretrained_dict[k].size(): 91 | not_initialized.add(k.split('.')[0]) 92 | 93 | print(sorted(not_initialized)) 94 | network.load_state_dict(model_dict) 95 | 96 | # helper loading function that can be used by subclasses 97 | def load_optim(self, network, network_label, epoch_label, save_dir=''): 98 | save_filename = '%s_optim_%s.pth' % (epoch_label, network_label) 99 | if not save_dir: 100 | save_dir = self.save_dir 101 | save_path = os.path.join(save_dir, save_filename) 102 | if not os.path.isfile(save_path): 103 | print('%s not exists yet!' 
% save_path) 104 | if network_label == 'G': 105 | raise('Generator must exist!') 106 | else: 107 | #network.load_state_dict(torch.load(save_path)) 108 | try: 109 | network.load_state_dict(torch.load(save_path, map_location=torch.device("cpu"))) 110 | except: 111 | pretrained_dict = torch.load(save_path, map_location=torch.device("cpu")) 112 | model_dict = network.state_dict() 113 | try: 114 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 115 | network.load_state_dict(pretrained_dict) 116 | if self.opt.verbose: 117 | print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label) 118 | except: 119 | print('Pretrained network %s has fewer layers; The following are not initialized:' % network_label) 120 | for k, v in pretrained_dict.items(): 121 | if v.size() == model_dict[k].size(): 122 | model_dict[k] = v 123 | 124 | if sys.version_info >= (3,0): 125 | not_initialized = set() 126 | else: 127 | from sets import Set 128 | not_initialized = Set() 129 | 130 | for k, v in model_dict.items(): 131 | if k not in pretrained_dict or v.size() != pretrained_dict[k].size(): 132 | not_initialized.add(k.split('.')[0]) 133 | 134 | print(sorted(not_initialized)) 135 | network.load_state_dict(model_dict) 136 | 137 | def update_learning_rate(): 138 | pass 139 | -------------------------------------------------------------------------------- /models/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # sets device for model and PyTorch tensors 6 | 7 | # Model parameters 8 | image_w = 112 9 | image_h = 112 10 | channel = 3 11 | emb_size = 512 12 | 13 | # Training parameters 14 | num_workers = 1 # for data-loading; right now, only 1 works with h5py 15 | grad_clip = 5. # clip gradients at an absolute value of 16 | print_freq = 100 # print training/validation stats every __ batches 17 | checkpoint = None # path to checkpoint, None if none 18 | 19 | # Data parameters 20 | num_classes = 93431 21 | num_samples = 5179510 22 | DATA_DIR = 'data' 23 | # faces_ms1m_folder = 'data/faces_ms1m_112x112' 24 | faces_ms1m_folder = 'data/ms1m-retinaface-t1' 25 | path_imgidx = os.path.join(faces_ms1m_folder, 'train.idx') 26 | path_imgrec = os.path.join(faces_ms1m_folder, 'train.rec') 27 | IMG_DIR = 'data/images' 28 | pickle_file = 'data/faces_ms1m_112x112.pickle' 29 | -------------------------------------------------------------------------------- /models/fs_networks_fix.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class InstanceNorm(nn.Module): 11 | def __init__(self, epsilon=1e-8): 12 | """ 13 | @notice: avoid in-place ops. 
14 | https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3 15 | """ 16 | super(InstanceNorm, self).__init__() 17 | self.epsilon = epsilon 18 | 19 | def forward(self, x): 20 | x = x - torch.mean(x, (2, 3), True) 21 | tmp = torch.mul(x, x) # or x ** 2 22 | tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon) 23 | return x * tmp 24 | 25 | class ApplyStyle(nn.Module): 26 | """ 27 | @ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb 28 | """ 29 | def __init__(self, latent_size, channels): 30 | super(ApplyStyle, self).__init__() 31 | self.linear = nn.Linear(latent_size, channels * 2) 32 | 33 | def forward(self, x, latent): 34 | style = self.linear(latent) # style => [batch_size, n_channels*2] 35 | shape = [-1, 2, x.size(1), 1, 1] 36 | style = style.view(shape) # [batch_size, 2, n_channels, ...] 37 | #x = x * (style[:, 0] + 1.) + style[:, 1] 38 | x = x * (style[:, 0] * 1 + 1.) + style[:, 1] * 1 39 | return x 40 | 41 | class ResnetBlock_Adain(nn.Module): 42 | def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)): 43 | super(ResnetBlock_Adain, self).__init__() 44 | 45 | p = 0 46 | conv1 = [] 47 | if padding_type == 'reflect': 48 | conv1 += [nn.ReflectionPad2d(1)] 49 | elif padding_type == 'replicate': 50 | conv1 += [nn.ReplicationPad2d(1)] 51 | elif padding_type == 'zero': 52 | p = 1 53 | else: 54 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 55 | conv1 += [nn.Conv2d(dim, dim, kernel_size=3, padding = p), InstanceNorm()] 56 | self.conv1 = nn.Sequential(*conv1) 57 | self.style1 = ApplyStyle(latent_size, dim) 58 | self.act1 = activation 59 | 60 | p = 0 61 | conv2 = [] 62 | if padding_type == 'reflect': 63 | conv2 += [nn.ReflectionPad2d(1)] 64 | elif padding_type == 'replicate': 65 | conv2 += [nn.ReplicationPad2d(1)] 66 | elif padding_type == 'zero': 67 | p = 1 68 | else: 69 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 70 | conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()] 71 | self.conv2 = nn.Sequential(*conv2) 72 | self.style2 = ApplyStyle(latent_size, dim) 73 | 74 | 75 | def forward(self, x, dlatents_in_slice): 76 | y = self.conv1(x) 77 | y = self.style1(y, dlatents_in_slice) 78 | y = self.act1(y) 79 | y = self.conv2(y) 80 | y = self.style2(y, dlatents_in_slice) 81 | out = x + y 82 | return out 83 | 84 | 85 | 86 | class Generator_Adain_Upsample(nn.Module): 87 | def __init__(self, input_nc, output_nc, latent_size, n_blocks=6, deep=False, 88 | norm_layer=nn.BatchNorm2d, 89 | padding_type='reflect'): 90 | assert (n_blocks >= 0) 91 | super(Generator_Adain_Upsample, self).__init__() 92 | 93 | activation = nn.ReLU(True) 94 | 95 | self.deep = deep 96 | 97 | self.first_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 64, kernel_size=7, padding=0), 98 | norm_layer(64), activation) 99 | ### downsample 100 | self.down1 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), 101 | norm_layer(128), activation) 102 | self.down2 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1), 103 | norm_layer(256), activation) 104 | self.down3 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1), 105 | norm_layer(512), activation) 106 | 107 | if self.deep: 108 | self.down4 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1), 109 | norm_layer(512), activation) 110 | 111 | 
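        # Bottleneck (built just below): n_blocks AdaIN residual blocks at 512 channels.
        # Every block is modulated by the same identity latent ("dlatents" in forward),
        # which is how the source identity is injected into the target features.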
### resnet blocks 112 | BN = [] 113 | for i in range(n_blocks): 114 | BN += [ 115 | ResnetBlock_Adain(512, latent_size=latent_size, padding_type=padding_type, activation=activation)] 116 | self.BottleNeck = nn.Sequential(*BN) 117 | 118 | if self.deep: 119 | self.up4 = nn.Sequential( 120 | nn.Upsample(scale_factor=2, mode='bilinear',align_corners=False), 121 | nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1), 122 | nn.BatchNorm2d(512), activation 123 | ) 124 | self.up3 = nn.Sequential( 125 | nn.Upsample(scale_factor=2, mode='bilinear',align_corners=False), 126 | nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1), 127 | nn.BatchNorm2d(256), activation 128 | ) 129 | self.up2 = nn.Sequential( 130 | nn.Upsample(scale_factor=2, mode='bilinear',align_corners=False), 131 | nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1), 132 | nn.BatchNorm2d(128), activation 133 | ) 134 | self.up1 = nn.Sequential( 135 | nn.Upsample(scale_factor=2, mode='bilinear',align_corners=False), 136 | nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1), 137 | nn.BatchNorm2d(64), activation 138 | ) 139 | self.last_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, kernel_size=7, padding=0)) 140 | 141 | def forward(self, input, dlatents): 142 | x = input # 3*224*224 143 | 144 | skip1 = self.first_layer(x) 145 | skip2 = self.down1(skip1) 146 | skip3 = self.down2(skip2) 147 | if self.deep: 148 | skip4 = self.down3(skip3) 149 | x = self.down4(skip4) 150 | else: 151 | x = self.down3(skip3) 152 | bot = [] 153 | bot.append(x) 154 | features = [] 155 | for i in range(len(self.BottleNeck)): 156 | x = self.BottleNeck[i](x, dlatents) 157 | bot.append(x) 158 | 159 | if self.deep: 160 | x = self.up4(x) 161 | features.append(x) 162 | x = self.up3(x) 163 | features.append(x) 164 | x = self.up2(x) 165 | features.append(x) 166 | x = self.up1(x) 167 | features.append(x) 168 | x = self.last_layer(x) 169 | # x = (x + 1) / 2 170 | 171 | # return x, bot, features, dlatents 172 | return x -------------------------------------------------------------------------------- /models/models.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | from torch.nn import Parameter 6 | from .config import device, num_classes 7 | 8 | 9 | def create_model(opt): 10 | #from .pix2pixHD_model import Pix2PixHDModel, InferenceModel 11 | from .fs_model import fsModel 12 | model = fsModel() 13 | 14 | model.initialize(opt) 15 | if opt.verbose: 16 | print("model [%s] was created" % (model.name())) 17 | 18 | if opt.isTrain and len(opt.gpu_ids) and not opt.fp16: 19 | model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids) 20 | 21 | return model 22 | 23 | 24 | 25 | class SEBlock(nn.Module): 26 | def __init__(self, channel, reduction=16): 27 | super(SEBlock, self).__init__() 28 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 29 | self.fc = nn.Sequential( 30 | nn.Linear(channel, channel // reduction), 31 | nn.PReLU(), 32 | nn.Linear(channel // reduction, channel), 33 | nn.Sigmoid() 34 | ) 35 | 36 | def forward(self, x): 37 | b, c, _, _ = x.size() 38 | y = self.avg_pool(x).view(b, c) 39 | y = self.fc(y).view(b, c, 1, 1) 40 | return x * y 41 | 42 | 43 | class IRBlock(nn.Module): 44 | expansion = 1 45 | 46 | def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True): 47 | super(IRBlock, self).__init__() 48 | self.bn0 = nn.BatchNorm2d(inplanes) 49 | self.conv1 = conv3x3(inplanes, inplanes) 50 | 
self.bn1 = nn.BatchNorm2d(inplanes) 51 | self.prelu = nn.PReLU() 52 | self.conv2 = conv3x3(inplanes, planes, stride) 53 | self.bn2 = nn.BatchNorm2d(planes) 54 | self.downsample = downsample 55 | self.stride = stride 56 | self.use_se = use_se 57 | if self.use_se: 58 | self.se = SEBlock(planes) 59 | 60 | def forward(self, x): 61 | residual = x 62 | out = self.bn0(x) 63 | out = self.conv1(out) 64 | out = self.bn1(out) 65 | out = self.prelu(out) 66 | 67 | out = self.conv2(out) 68 | out = self.bn2(out) 69 | if self.use_se: 70 | out = self.se(out) 71 | 72 | if self.downsample is not None: 73 | residual = self.downsample(x) 74 | 75 | out += residual 76 | out = self.prelu(out) 77 | 78 | return out 79 | 80 | 81 | class ResNet(nn.Module): 82 | 83 | def __init__(self, block, layers, use_se=True): 84 | self.inplanes = 64 85 | self.use_se = use_se 86 | super(ResNet, self).__init__() 87 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, bias=False) 88 | self.bn1 = nn.BatchNorm2d(64) 89 | self.prelu = nn.PReLU() 90 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2) 91 | self.layer1 = self._make_layer(block, 64, layers[0]) 92 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 93 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 94 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 95 | self.bn2 = nn.BatchNorm2d(512) 96 | self.dropout = nn.Dropout() 97 | self.fc = nn.Linear(512 * 7 * 7, 512) 98 | self.bn3 = nn.BatchNorm1d(512) 99 | 100 | for m in self.modules(): 101 | if isinstance(m, nn.Conv2d): 102 | nn.init.xavier_normal_(m.weight) 103 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): 104 | nn.init.constant_(m.weight, 1) 105 | nn.init.constant_(m.bias, 0) 106 | elif isinstance(m, nn.Linear): 107 | nn.init.xavier_normal_(m.weight) 108 | nn.init.constant_(m.bias, 0) 109 | 110 | def _make_layer(self, block, planes, blocks, stride=1): 111 | downsample = None 112 | if stride != 1 or self.inplanes != planes * block.expansion: 113 | downsample = nn.Sequential( 114 | nn.Conv2d(self.inplanes, planes * block.expansion, 115 | kernel_size=1, stride=stride, bias=False), 116 | nn.BatchNorm2d(planes * block.expansion), 117 | ) 118 | 119 | layers = [] 120 | layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se)) 121 | self.inplanes = planes 122 | for i in range(1, blocks): 123 | layers.append(block(self.inplanes, planes, use_se=self.use_se)) 124 | 125 | return nn.Sequential(*layers) 126 | 127 | def forward(self, x): 128 | x = self.conv1(x) 129 | x = self.bn1(x) 130 | x = self.prelu(x) 131 | x = self.maxpool(x) 132 | 133 | x = self.layer1(x) 134 | x = self.layer2(x) 135 | x = self.layer3(x) 136 | x = self.layer4(x) 137 | 138 | x = self.bn2(x) 139 | x = self.dropout(x) 140 | x = x.view(x.size(0), -1) 141 | x = self.fc(x) 142 | x = self.bn3(x) 143 | 144 | return x 145 | 146 | 147 | class ArcMarginModel(nn.Module): 148 | def __init__(self, args): 149 | super(ArcMarginModel, self).__init__() 150 | 151 | self.weight = Parameter(torch.FloatTensor(num_classes, args.emb_size)) 152 | nn.init.xavier_uniform_(self.weight) 153 | 154 | self.easy_margin = args.easy_margin 155 | self.m = args.margin_m 156 | self.s = args.margin_s 157 | 158 | self.cos_m = math.cos(self.m) 159 | self.sin_m = math.sin(self.m) 160 | self.th = math.cos(math.pi - self.m) 161 | self.mm = math.sin(math.pi - self.m) * self.m 162 | 163 | def forward(self, input, label): 164 | x = F.normalize(input) 165 | W = F.normalize(self.weight) 166 | cosine = 
F.linear(x, W) 167 | sine = torch.sqrt(1.0 - torch.pow(cosine, 2)) 168 | phi = cosine * self.cos_m - sine * self.sin_m # cos(theta + m) 169 | if self.easy_margin: 170 | phi = torch.where(cosine > 0, phi, cosine) 171 | else: 172 | phi = torch.where(cosine > self.th, phi, cosine - self.mm) 173 | one_hot = torch.zeros(cosine.size(), device=device) 174 | one_hot.scatter_(1, label.view(-1, 1).long(), 1) 175 | output = (one_hot * phi) + ((1.0 - one_hot) * cosine) 176 | output *= self.s 177 | return output 178 | -------------------------------------------------------------------------------- /models/projected_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | ############################################################# 4 | # File: fs_model_fix_idnorm_donggp_saveoptim copy.py 5 | # Created Date: Wednesday January 12th 2022 6 | # Author: Chen Xuanhong 7 | # Email: chenxuanhongzju@outlook.com 8 | # Last Modified: Saturday, 13th May 2023 9:56:35 am 9 | # Modified By: Chen Xuanhong 10 | # Copyright (c) 2022 Shanghai Jiao Tong University 11 | ############################################################# 12 | 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | from .base_model import BaseModel 18 | from .fs_networks_fix import Generator_Adain_Upsample 19 | 20 | from pg_modules.projected_discriminator import ProjectedDiscriminator 21 | 22 | def compute_grad2(d_out, x_in): 23 | batch_size = x_in.size(0) 24 | grad_dout = torch.autograd.grad( 25 | outputs=d_out.sum(), inputs=x_in, 26 | create_graph=True, retain_graph=True, only_inputs=True 27 | )[0] 28 | grad_dout2 = grad_dout.pow(2) 29 | assert(grad_dout2.size() == x_in.size()) 30 | reg = grad_dout2.view(batch_size, -1).sum(1) 31 | return reg 32 | 33 | class fsModel(BaseModel): 34 | def name(self): 35 | return 'fsModel' 36 | 37 | def initialize(self, opt): 38 | BaseModel.initialize(self, opt) 39 | # if opt.resize_or_crop != 'none' or not opt.isTrain: # when training at full res this causes OOM 40 | self.isTrain = opt.isTrain 41 | 42 | # Generator network 43 | self.netG = Generator_Adain_Upsample(input_nc=3, output_nc=3, latent_size=512, n_blocks=9, deep=opt.Gdeep) 44 | self.netG.cuda() 45 | 46 | # Id network 47 | netArc_checkpoint = opt.Arc_path 48 | netArc_checkpoint = torch.load(netArc_checkpoint, map_location=torch.device("cpu")) 49 | self.netArc = netArc_checkpoint 50 | self.netArc = self.netArc.cuda() 51 | self.netArc.eval() 52 | self.netArc.requires_grad_(False) 53 | if not self.isTrain: 54 | pretrained_path = opt.checkpoints_dir 55 | self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path) 56 | return 57 | self.netD = ProjectedDiscriminator(diffaug=False, interp224=False, **{}) 58 | # self.netD.feature_network.requires_grad_(False) 59 | self.netD.cuda() 60 | 61 | 62 | if self.isTrain: 63 | # define loss functions 64 | self.criterionFeat = nn.L1Loss() 65 | self.criterionRec = nn.L1Loss() 66 | 67 | 68 | # initialize optimizers 69 | 70 | # optimizer G 71 | params = list(self.netG.parameters()) 72 | self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.99),eps=1e-8) 73 | 74 | # optimizer D 75 | params = list(self.netD.parameters()) 76 | self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.99),eps=1e-8) 77 | 78 | # load networks 79 | if opt.continue_train: 80 | pretrained_path = '' if not self.isTrain else opt.load_pretrain 81 | # print (pretrained_path) 82 | self.load_network(self.netG, 'G', 
opt.which_epoch, pretrained_path) 83 | self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path) 84 | self.load_optim(self.optimizer_G, 'G', opt.which_epoch, pretrained_path) 85 | self.load_optim(self.optimizer_D, 'D', opt.which_epoch, pretrained_path) 86 | torch.cuda.empty_cache() 87 | 88 | def cosin_metric(self, x1, x2): 89 | #return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2)) 90 | return torch.sum(x1 * x2, dim=1) / (torch.norm(x1, dim=1) * torch.norm(x2, dim=1)) 91 | 92 | 93 | 94 | def save(self, which_epoch): 95 | self.save_network(self.netG, 'G', which_epoch) 96 | self.save_network(self.netD, 'D', which_epoch) 97 | self.save_optim(self.optimizer_G, 'G', which_epoch) 98 | self.save_optim(self.optimizer_D, 'D', which_epoch) 99 | '''if self.gen_features: 100 | self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)''' 101 | 102 | def update_fixed_params(self): 103 | # after fixing the global generator for a number of iterations, also start finetuning it 104 | params = list(self.netG.parameters()) 105 | if self.gen_features: 106 | params += list(self.netE.parameters()) 107 | self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) 108 | if self.opt.verbose: 109 | print('------------ Now also finetuning global generator -----------') 110 | 111 | def update_learning_rate(self): 112 | lrd = self.opt.lr / self.opt.niter_decay 113 | lr = self.old_lr - lrd 114 | for param_group in self.optimizer_D.param_groups: 115 | param_group['lr'] = lr 116 | for param_group in self.optimizer_G.param_groups: 117 | param_group['lr'] = lr 118 | if self.opt.verbose: 119 | print('update learning rate: %f -> %f' % (self.old_lr, lr)) 120 | self.old_lr = lr 121 | 122 | 123 | -------------------------------------------------------------------------------- /models/projectionhead.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class ProjectionHead(nn.Module): 4 | def __init__(self, proj_dim=256): 5 | super(ProjectionHead, self).__init__() 6 | 7 | self.proj = nn.Sequential( 8 | nn.Linear(proj_dim, proj_dim), 9 | nn.ReLU(), 10 | nn.Linear(proj_dim, proj_dim), 11 | ) 12 | 13 | def forward(self, x): 14 | return self.proj(x) -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from util import util 4 | import torch 5 | 6 | class BaseOptions(): 7 | def __init__(self): 8 | self.parser = argparse.ArgumentParser() 9 | self.initialized = False 10 | 11 | def initialize(self): 12 | # experiment specifics 13 | self.parser.add_argument('--name', type=str, default='people', help='name of the experiment. It decides where to store samples and models') 14 | self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 15 | self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 16 | # self.parser.add_argument('--model', type=str, default='pix2pixHD', help='which model to use') 17 | self.parser.add_argument('--norm', type=str, default='batch', help='instance normalization or batch normalization') 18 | self.parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator') 19 | self.parser.add_argument('--data_type', default=32, type=int, choices=[8, 16, 32], help="Supported data type i.e. 
8, 16, 32 bit") 20 | self.parser.add_argument('--verbose', action='store_true', default=False, help='toggles verbose') 21 | self.parser.add_argument('--fp16', action='store_true', default=False, help='train with AMP') 22 | self.parser.add_argument('--local_rank', type=int, default=0, help='local rank for distributed training') 23 | self.parser.add_argument('--isTrain', type=bool, default=True, help='local rank for distributed training') 24 | 25 | # input/output sizes 26 | self.parser.add_argument('--batchSize', type=int, default=8, help='input batch size') 27 | self.parser.add_argument('--loadSize', type=int, default=1024, help='scale images to this size') 28 | self.parser.add_argument('--fineSize', type=int, default=512, help='then crop to this size') 29 | self.parser.add_argument('--label_nc', type=int, default=0, help='# of input label channels') 30 | self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels') 31 | self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels') 32 | 33 | # for setting inputs 34 | self.parser.add_argument('--dataroot', type=str, default='./datasets/cityscapes/') 35 | self.parser.add_argument('--resize_or_crop', type=str, default='scale_width', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]') 36 | self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 37 | self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data argumentation') 38 | self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data') 39 | self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 40 | 41 | # for displays 42 | self.parser.add_argument('--display_winsize', type=int, default=512, help='display window size') 43 | self.parser.add_argument('--tf_log', action='store_true', help='if specified, use tensorboard logging. 
Requires tensorflow installed') 44 | 45 | # for generator 46 | self.parser.add_argument('--netG', type=str, default='global', help='selects model to use for netG') 47 | self.parser.add_argument('--latent_size', type=int, default=512, help='latent size of Adain layer') 48 | self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer') 49 | self.parser.add_argument('--n_downsample_global', type=int, default=3, help='number of downsampling layers in netG') 50 | self.parser.add_argument('--n_blocks_global', type=int, default=6, help='number of residual blocks in the global generator network') 51 | self.parser.add_argument('--n_blocks_local', type=int, default=3, help='number of residual blocks in the local enhancer network') 52 | self.parser.add_argument('--n_local_enhancers', type=int, default=1, help='number of local enhancers to use') 53 | self.parser.add_argument('--niter_fix_global', type=int, default=0, help='number of epochs that we only train the outmost local enhancer') 54 | 55 | # for instance-wise features 56 | self.parser.add_argument('--no_instance', action='store_true', help='if specified, do *not* add instance map as input') 57 | self.parser.add_argument('--instance_feat', action='store_true', help='if specified, add encoded instance features as input') 58 | self.parser.add_argument('--label_feat', action='store_true', help='if specified, add encoded label features as input') 59 | self.parser.add_argument('--feat_num', type=int, default=3, help='vector length for encoded features') 60 | self.parser.add_argument('--load_features', action='store_true', help='if specified, load precomputed feature maps') 61 | self.parser.add_argument('--n_downsample_E', type=int, default=4, help='# of downsampling layers in encoder') 62 | self.parser.add_argument('--nef', type=int, default=16, help='# of encoder filters in the first conv layer') 63 | self.parser.add_argument('--n_clusters', type=int, default=10, help='number of clusters for features') 64 | self.parser.add_argument('--image_size', type=int, default=224, help='number of clusters for features') 65 | self.parser.add_argument('--norm_G', type=str, default='spectralspadesyncbatch3x3', help='instance normalization or batch normalization') 66 | self.parser.add_argument('--semantic_nc', type=int, default=3, help='number of clusters for features') 67 | self.initialized = True 68 | 69 | def parse(self, save=True): 70 | if not self.initialized: 71 | self.initialize() 72 | self.opt = self.parser.parse_args() 73 | self.opt.isTrain = self.isTrain # train or test 74 | 75 | str_ids = self.opt.gpu_ids.split(',') 76 | self.opt.gpu_ids = [] 77 | for str_id in str_ids: 78 | id = int(str_id) 79 | if id >= 0: 80 | self.opt.gpu_ids.append(id) 81 | 82 | # set gpu ids 83 | if len(self.opt.gpu_ids) > 0: 84 | torch.cuda.set_device(self.opt.gpu_ids[0]) 85 | 86 | args = vars(self.opt) 87 | 88 | print('------------ Options -------------') 89 | for k, v in sorted(args.items()): 90 | print('%s: %s' % (str(k), str(v))) 91 | print('-------------- End ----------------') 92 | 93 | # save to the disk 94 | if self.opt.isTrain: 95 | expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name) 96 | util.mkdirs(expr_dir) 97 | if save and not self.opt.continue_train: 98 | file_name = os.path.join(expr_dir, 'opt.txt') 99 | with open(file_name, 'wt') as opt_file: 100 | opt_file.write('------------ Options -------------\n') 101 | for k, v in sorted(args.items()): 102 | opt_file.write('%s: %s\n' % (str(k), str(v))) 103 | 
opt_file.write('-------------- End ----------------\n') 104 | return self.opt 105 | -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: Naiyuan liu 3 | Github: https://github.com/NNNNAI 4 | Date: 2021-11-23 17:03:58 5 | LastEditors: Naiyuan liu 6 | LastEditTime: 2021-11-23 17:08:08 7 | Description: 8 | ''' 9 | from .base_options import BaseOptions 10 | 11 | class TestOptions(BaseOptions): 12 | def initialize(self): 13 | BaseOptions.initialize(self) 14 | self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.') 15 | self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') 16 | self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') 17 | self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 18 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 19 | self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run') 20 | self.parser.add_argument('--cluster_path', type=str, default='features_clustered_010.npy', help='the path for clustered results of encoded features') 21 | self.parser.add_argument('--use_encoded_image', action='store_true', help='if specified, encode the real image to get the feature map') 22 | self.parser.add_argument("--export_onnx", type=str, help="export ONNX model to a given file") 23 | self.parser.add_argument("--engine", type=str, help="run serialized TRT engine") 24 | self.parser.add_argument("--onnx", type=str, help="run ONNX model via TRT") 25 | self.parser.add_argument("--Arc_path", type=str, default='arcface_model/arcface_checkpoint.tar', help="run ONNX model via TRT") 26 | self.parser.add_argument("--pic_a_path", type=str, default='G:/swap_data/ID/elon-musk-hero-image.jpeg', help="Person who provides identity information") 27 | self.parser.add_argument("--pic_b_path", type=str, default='./demo_file/multi_people.jpg', help="Person who provides information other than their identity") 28 | self.parser.add_argument("--pic_specific_path", type=str, default='./crop_224/zrf.jpg', help="The specific person to be swapped") 29 | self.parser.add_argument("--multisepcific_dir", type=str, default='./demo_file/multispecific', help="Dir for multi specific") 30 | self.parser.add_argument("--video_path", type=str, default='G:/swap_data/video/HSB_Demo_Trim.mp4', help="path for the video to swap") 31 | self.parser.add_argument("--temp_path", type=str, default='./temp_results', help="path to save temporarily images") 32 | self.parser.add_argument("--output_path", type=str, default='./output/', help="results path") 33 | self.parser.add_argument('--id_thres', type=float, default=0.03, help='how many test images to run') 34 | self.parser.add_argument('--no_simswaplogo', action='store_true', help='Remove the watermark') 35 | self.parser.add_argument('--use_mask', action='store_true', help='Use mask for better result') 36 | self.parser.add_argument('--crop_size', type=int, default=512, help='Crop of size of input image') 37 | 38 | self.isTrain = False -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from 
.base_options import BaseOptions 2 | 3 | class TrainOptions(BaseOptions): 4 | def initialize(self): 5 | BaseOptions.initialize(self) 6 | # for displays 7 | self.parser.add_argument('--display_freq', type=int, default=99, help='frequency of showing training results on screen') 8 | self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') 9 | self.parser.add_argument('--save_latest_freq', type=int, default=10000, help='frequency of saving the latest results') 10 | self.parser.add_argument('--save_epoch_freq', type=int, default=10000, help='frequency of saving checkpoints at the end of epochs') 11 | self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 12 | self.parser.add_argument('--debug', action='store_true', help='only do one epoch and displays at each iteration') 13 | 14 | # for training 15 | self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 16 | self.parser.add_argument('--load_pretrain', type=str, default='', help='load the pretrained model from the specified location') 17 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 18 | self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') 19 | self.parser.add_argument('--niter', type=int, default=10000, help='# of iter at starting learning rate') 20 | self.parser.add_argument('--niter_decay', type=int, default=10000, help='# of iter to linearly decay learning rate to zero') 21 | self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') 22 | self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam') 23 | 24 | # for discriminators 25 | self.parser.add_argument('--num_D', type=int, default=2, help='number of discriminators to use') 26 | self.parser.add_argument('--n_layers_D', type=int, default=4, help='only used if which_model_netD==n_layers') 27 | self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer') 28 | self.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss') 29 | self.parser.add_argument('--lambda_id', type=float, default=20.0, help='weight for id loss') 30 | self.parser.add_argument('--lambda_rec', type=float, default=10.0, help='weight for reconstruction loss') 31 | self.parser.add_argument('--lambda_GP', type=float, default=10.0, help='weight for gradient penalty loss') 32 | self.parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss') 33 | self.parser.add_argument('--no_vgg_loss', action='store_true', help='if specified, do *not* use VGG feature matching loss') 34 | self.parser.add_argument('--gan_mode', type=str, default='hinge', help='(ls|original|hinge)') 35 | self.parser.add_argument('--pool_size', type=int, default=0, help='the size of image buffer that stores previously generated images') 36 | self.parser.add_argument('--times_G', type=int, default=1, 37 | help='time of training generator before traning discriminator') 38 | 39 | self.isTrain = True 40 | -------------------------------------------------------------------------------- /output/result.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/output/result.jpg -------------------------------------------------------------------------------- /parsing_model/resnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.utils.model_zoo as modelzoo 8 | 9 | # from modules.bn import InPlaceABNSync as BatchNorm2d 10 | 11 | resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' 12 | 13 | 14 | def conv3x3(in_planes, out_planes, stride=1): 15 | """3x3 convolution with padding""" 16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 17 | padding=1, bias=False) 18 | 19 | 20 | class BasicBlock(nn.Module): 21 | def __init__(self, in_chan, out_chan, stride=1): 22 | super(BasicBlock, self).__init__() 23 | self.conv1 = conv3x3(in_chan, out_chan, stride) 24 | self.bn1 = nn.BatchNorm2d(out_chan) 25 | self.conv2 = conv3x3(out_chan, out_chan) 26 | self.bn2 = nn.BatchNorm2d(out_chan) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.downsample = None 29 | if in_chan != out_chan or stride != 1: 30 | self.downsample = nn.Sequential( 31 | nn.Conv2d(in_chan, out_chan, 32 | kernel_size=1, stride=stride, bias=False), 33 | nn.BatchNorm2d(out_chan), 34 | ) 35 | 36 | def forward(self, x): 37 | residual = self.conv1(x) 38 | residual = F.relu(self.bn1(residual)) 39 | residual = self.conv2(residual) 40 | residual = self.bn2(residual) 41 | 42 | shortcut = x 43 | if self.downsample is not None: 44 | shortcut = self.downsample(x) 45 | 46 | out = shortcut + residual 47 | out = self.relu(out) 48 | return out 49 | 50 | 51 | def create_layer_basic(in_chan, out_chan, bnum, stride=1): 52 | layers = [BasicBlock(in_chan, out_chan, stride=stride)] 53 | for i in range(bnum-1): 54 | layers.append(BasicBlock(out_chan, out_chan, stride=1)) 55 | return nn.Sequential(*layers) 56 | 57 | 58 | class Resnet18(nn.Module): 59 | def __init__(self): 60 | super(Resnet18, self).__init__() 61 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 62 | bias=False) 63 | self.bn1 = nn.BatchNorm2d(64) 64 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 65 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) 66 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) 67 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) 68 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) 69 | self.init_weight() 70 | 71 | def forward(self, x): 72 | x = self.conv1(x) 73 | x = F.relu(self.bn1(x)) 74 | x = self.maxpool(x) 75 | 76 | x = self.layer1(x) 77 | feat8 = self.layer2(x) # 1/8 78 | feat16 = self.layer3(feat8) # 1/16 79 | feat32 = self.layer4(feat16) # 1/32 80 | return feat8, feat16, feat32 81 | 82 | def init_weight(self): 83 | state_dict = modelzoo.load_url(resnet18_url) 84 | self_state_dict = self.state_dict() 85 | for k, v in state_dict.items(): 86 | if 'fc' in k: continue 87 | self_state_dict.update({k: v}) 88 | self.load_state_dict(self_state_dict) 89 | 90 | def get_params(self): 91 | wd_params, nowd_params = [], [] 92 | for name, module in self.named_modules(): 93 | if isinstance(module, (nn.Linear, nn.Conv2d)): 94 | wd_params.append(module.weight) 95 | if not module.bias is None: 96 | nowd_params.append(module.bias) 97 | elif isinstance(module, nn.BatchNorm2d): 98 | nowd_params += list(module.parameters()) 99 | return 
wd_params, nowd_params 100 | 101 | 102 | if __name__ == "__main__": 103 | net = Resnet18() 104 | x = torch.randn(16, 3, 224, 224) 105 | out = net(x) 106 | print(out[0].size()) 107 | print(out[1].size()) 108 | print(out[2].size()) 109 | net.get_params() 110 | -------------------------------------------------------------------------------- /pg_modules/diffaug.py: -------------------------------------------------------------------------------- 1 | # Differentiable Augmentation for Data-Efficient GAN Training 2 | # Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han 3 | # https://arxiv.org/pdf/2006.10738 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | 9 | def DiffAugment(x, policy='', channels_first=True): 10 | if policy: 11 | if not channels_first: 12 | x = x.permute(0, 3, 1, 2) 13 | for p in policy.split(','): 14 | for f in AUGMENT_FNS[p]: 15 | x = f(x) 16 | if not channels_first: 17 | x = x.permute(0, 2, 3, 1) 18 | x = x.contiguous() 19 | return x 20 | 21 | 22 | def rand_brightness(x): 23 | x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) 24 | return x 25 | 26 | 27 | def rand_saturation(x): 28 | x_mean = x.mean(dim=1, keepdim=True) 29 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean 30 | return x 31 | 32 | 33 | def rand_contrast(x): 34 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True) 35 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean 36 | return x 37 | 38 | 39 | def rand_translation(x, ratio=0.125): 40 | shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 41 | translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device) 42 | translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device) 43 | grid_batch, grid_x, grid_y = torch.meshgrid( 44 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 45 | torch.arange(x.size(2), dtype=torch.long, device=x.device), 46 | torch.arange(x.size(3), dtype=torch.long, device=x.device), 47 | ) 48 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) 49 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) 50 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]) 51 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2) 52 | return x 53 | 54 | 55 | def rand_cutout(x, ratio=0.2): 56 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 57 | offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device) 58 | offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device) 59 | grid_batch, grid_x, grid_y = torch.meshgrid( 60 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 61 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device), 62 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device), 63 | ) 64 | grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1) 65 | grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1) 66 | mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) 67 | mask[grid_batch, grid_x, grid_y] = 0 68 | x = x * mask.unsqueeze(1) 69 | return x 70 | 71 | 72 | AUGMENT_FNS = { 73 | 'color': [rand_brightness, rand_saturation, rand_contrast], 74 | 'translation': [rand_translation], 75 | 'cutout': [rand_cutout], 76 
| } 77 | -------------------------------------------------------------------------------- /pg_modules/projected_discriminator.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | 6 | from pg_modules.blocks import DownBlock, DownBlockPatch, conv2d 7 | from pg_modules.projector import F_RandomProj 8 | from pg_modules.diffaug import DiffAugment 9 | 10 | 11 | class SingleDisc(nn.Module): 12 | def __init__(self, nc=None, ndf=None, start_sz=256, end_sz=8, head=None, separable=False, patch=False): 13 | super().__init__() 14 | channel_dict = {4: 512, 8: 512, 16: 256, 32: 128, 64: 64, 128: 64, 15 | 256: 32, 512: 16, 1024: 8} 16 | 17 | # interpolate for start sz that are not powers of two 18 | if start_sz not in channel_dict.keys(): 19 | sizes = np.array(list(channel_dict.keys())) 20 | start_sz = sizes[np.argmin(abs(sizes - start_sz))] 21 | self.start_sz = start_sz 22 | 23 | # if given ndf, allocate all layers with the same ndf 24 | if ndf is None: 25 | nfc = channel_dict 26 | else: 27 | nfc = {k: ndf for k, v in channel_dict.items()} 28 | 29 | # for feature map discriminators with nfc not in channel_dict 30 | # this is the case for the pretrained backbone (midas.pretrained) 31 | if nc is not None and head is None: 32 | nfc[start_sz] = nc 33 | 34 | layers = [] 35 | 36 | # Head if the initial input is the full modality 37 | if head: 38 | layers += [conv2d(nc, nfc[256], 3, 1, 1, bias=False), 39 | nn.LeakyReLU(0.2, inplace=True)] 40 | 41 | # Down Blocks 42 | DB = partial(DownBlockPatch, separable=separable) if patch else partial(DownBlock, separable=separable) 43 | while start_sz > end_sz: 44 | layers.append(DB(nfc[start_sz], nfc[start_sz//2])) 45 | start_sz = start_sz // 2 46 | 47 | layers.append(conv2d(nfc[end_sz], 1, 4, 1, 0, bias=False)) 48 | self.main = nn.Sequential(*layers) 49 | 50 | def forward(self, x, c): 51 | return self.main(x) 52 | 53 | 54 | class SingleDiscCond(nn.Module): 55 | def __init__(self, nc=None, ndf=None, start_sz=256, end_sz=8, head=None, separable=False, patch=False, c_dim=1000, cmap_dim=64, embedding_dim=128): 56 | super().__init__() 57 | self.cmap_dim = cmap_dim 58 | 59 | # midas channels 60 | channel_dict = {4: 512, 8: 512, 16: 256, 32: 128, 64: 64, 128: 64, 61 | 256: 32, 512: 16, 1024: 8} 62 | 63 | # interpolate for start sz that are not powers of two 64 | if start_sz not in channel_dict.keys(): 65 | sizes = np.array(list(channel_dict.keys())) 66 | start_sz = sizes[np.argmin(abs(sizes - start_sz))] 67 | self.start_sz = start_sz 68 | 69 | # if given ndf, allocate all layers with the same ndf 70 | if ndf is None: 71 | nfc = channel_dict 72 | else: 73 | nfc = {k: ndf for k, v in channel_dict.items()} 74 | 75 | # for feature map discriminators with nfc not in channel_dict 76 | # this is the case for the pretrained backbone (midas.pretrained) 77 | if nc is not None and head is None: 78 | nfc[start_sz] = nc 79 | 80 | layers = [] 81 | 82 | # Head if the initial input is the full modality 83 | if head: 84 | layers += [conv2d(nc, nfc[256], 3, 1, 1, bias=False), 85 | nn.LeakyReLU(0.2, inplace=True)] 86 | 87 | # Down Blocks 88 | DB = partial(DownBlockPatch, separable=separable) if patch else partial(DownBlock, separable=separable) 89 | while start_sz > end_sz: 90 | layers.append(DB(nfc[start_sz], nfc[start_sz//2])) 91 | start_sz = start_sz // 2 92 | self.main = nn.Sequential(*layers) 93 | 94 | # additions for conditioning on class information 95 | 
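        # self.cls maps the final feature map to cmap_dim logits; embed/embed_proj turn
        # the class label into a cmap_dim vector that is dotted with those logits in
        # forward() (projection-style conditioning).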
self.cls = conv2d(nfc[end_sz], self.cmap_dim, 4, 1, 0, bias=False) 96 | self.embed = nn.Embedding(num_embeddings=c_dim, embedding_dim=embedding_dim) 97 | self.embed_proj = nn.Sequential( 98 | nn.Linear(self.embed.embedding_dim, self.cmap_dim), 99 | nn.LeakyReLU(0.2, inplace=True), 100 | ) 101 | 102 | def forward(self, x, c): 103 | h = self.main(x) 104 | out = self.cls(h) 105 | 106 | # conditioning via projection 107 | cmap = self.embed_proj(self.embed(c.argmax(1))).unsqueeze(-1).unsqueeze(-1) 108 | out = (out * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim)) 109 | 110 | return out 111 | 112 | 113 | class MultiScaleD(nn.Module): 114 | def __init__( 115 | self, 116 | channels, 117 | resolutions, 118 | num_discs=4, 119 | proj_type=2, # 0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing 120 | cond=0, 121 | separable=False, 122 | patch=False, 123 | **kwargs, 124 | ): 125 | super().__init__() 126 | 127 | assert num_discs in [1, 2, 3, 4] 128 | 129 | # the first disc is on the lowest level of the backbone 130 | self.disc_in_channels = channels[:num_discs] 131 | self.disc_in_res = resolutions[:num_discs] 132 | Disc = SingleDiscCond if cond else SingleDisc 133 | 134 | mini_discs = [] 135 | for i, (cin, res) in enumerate(zip(self.disc_in_channels, self.disc_in_res)): 136 | start_sz = res if not patch else 16 137 | mini_discs += [str(i), Disc(nc=cin, start_sz=start_sz, end_sz=8, separable=separable, patch=patch)], 138 | self.mini_discs = nn.ModuleDict(mini_discs) 139 | 140 | def forward(self, features, c): 141 | all_logits = [] 142 | for k, disc in self.mini_discs.items(): 143 | res = disc(features[k], c).view(features[k].size(0), -1) 144 | all_logits.append(res) 145 | 146 | all_logits = torch.cat(all_logits, dim=1) 147 | return all_logits 148 | 149 | 150 | class ProjectedDiscriminator(torch.nn.Module): 151 | def __init__( 152 | self, 153 | diffaug=True, 154 | interp224=True, 155 | backbone_kwargs={}, 156 | **kwargs 157 | ): 158 | super().__init__() 159 | self.diffaug = diffaug 160 | self.interp224 = interp224 161 | self.feature_network = F_RandomProj(**backbone_kwargs) 162 | self.discriminator = MultiScaleD( 163 | channels=self.feature_network.CHANNELS, 164 | resolutions=self.feature_network.RESOLUTIONS, 165 | **backbone_kwargs, 166 | ) 167 | 168 | def train(self, mode=True): 169 | self.feature_network = self.feature_network.train(False) 170 | self.discriminator = self.discriminator.train(mode) 171 | return self 172 | 173 | def eval(self): 174 | return self.train(False) 175 | 176 | def get_feature(self, x): 177 | features = self.feature_network(x, get_features=True) 178 | return features 179 | 180 | def forward(self, x, c): 181 | # if self.diffaug: 182 | # x = DiffAugment(x, policy='color,translation,cutout') 183 | 184 | # if self.interp224: 185 | # x = F.interpolate(x, 224, mode='bilinear', align_corners=False) 186 | 187 | features,backbone_features = self.feature_network(x) 188 | logits = self.discriminator(features, c) 189 | 190 | return logits,backbone_features 191 | 192 | -------------------------------------------------------------------------------- /pg_modules/projector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import timm 4 | from pg_modules.blocks import FeatureFusionBlock 5 | 6 | 7 | def _make_scratch_ccm(scratch, in_channels, cout, expand=False): 8 | # shapes 9 | out_channels = [cout, cout*2, cout*4, cout*8] if expand else [cout]*4 10 | 11 | scratch.layer0_ccm = 
nn.Conv2d(in_channels[0], out_channels[0], kernel_size=1, stride=1, padding=0, bias=True) 12 | scratch.layer1_ccm = nn.Conv2d(in_channels[1], out_channels[1], kernel_size=1, stride=1, padding=0, bias=True) 13 | scratch.layer2_ccm = nn.Conv2d(in_channels[2], out_channels[2], kernel_size=1, stride=1, padding=0, bias=True) 14 | scratch.layer3_ccm = nn.Conv2d(in_channels[3], out_channels[3], kernel_size=1, stride=1, padding=0, bias=True) 15 | 16 | scratch.CHANNELS = out_channels 17 | 18 | return scratch 19 | 20 | 21 | def _make_scratch_csm(scratch, in_channels, cout, expand): 22 | scratch.layer3_csm = FeatureFusionBlock(in_channels[3], nn.ReLU(False), expand=expand, lowest=True) 23 | scratch.layer2_csm = FeatureFusionBlock(in_channels[2], nn.ReLU(False), expand=expand) 24 | scratch.layer1_csm = FeatureFusionBlock(in_channels[1], nn.ReLU(False), expand=expand) 25 | scratch.layer0_csm = FeatureFusionBlock(in_channels[0], nn.ReLU(False)) 26 | 27 | # last refinenet does not expand to save channels in higher dimensions 28 | scratch.CHANNELS = [cout, cout, cout*2, cout*4] if expand else [cout]*4 29 | 30 | return scratch 31 | 32 | 33 | def _make_efficientnet(model): 34 | pretrained = nn.Module() 35 | pretrained.layer0 = nn.Sequential(model.conv_stem, model.bn1, model.act1, *model.blocks[0:2]) 36 | pretrained.layer1 = nn.Sequential(*model.blocks[2:3]) 37 | pretrained.layer2 = nn.Sequential(*model.blocks[3:5]) 38 | pretrained.layer3 = nn.Sequential(*model.blocks[5:9]) 39 | return pretrained 40 | 41 | 42 | def calc_channels(pretrained, inp_res=224): 43 | channels = [] 44 | tmp = torch.zeros(1, 3, inp_res, inp_res) 45 | 46 | # forward pass 47 | tmp = pretrained.layer0(tmp) 48 | channels.append(tmp.shape[1]) 49 | tmp = pretrained.layer1(tmp) 50 | channels.append(tmp.shape[1]) 51 | tmp = pretrained.layer2(tmp) 52 | channels.append(tmp.shape[1]) 53 | tmp = pretrained.layer3(tmp) 54 | channels.append(tmp.shape[1]) 55 | 56 | return channels 57 | 58 | 59 | def _make_projector(im_res, cout, proj_type, expand=False): 60 | assert proj_type in [0, 1, 2], "Invalid projection type" 61 | 62 | ### Build pretrained feature network 63 | model = timm.create_model('tf_efficientnet_lite0', pretrained=True) 64 | pretrained = _make_efficientnet(model) 65 | 66 | # determine resolution of feature maps, this is later used to calculate the number 67 | # of down blocks in the discriminators. 
Interestingly, the best results are achieved 68 | # by fixing this to 256, ie., we use the same number of down blocks per discriminator 69 | # independent of the dataset resolution 70 | im_res = 256 71 | pretrained.RESOLUTIONS = [im_res//4, im_res//8, im_res//16, im_res//32] 72 | pretrained.CHANNELS = calc_channels(pretrained) 73 | 74 | if proj_type == 0: return pretrained, None 75 | 76 | ### Build CCM 77 | scratch = nn.Module() 78 | scratch = _make_scratch_ccm(scratch, in_channels=pretrained.CHANNELS, cout=cout, expand=expand) 79 | pretrained.CHANNELS = scratch.CHANNELS 80 | 81 | if proj_type == 1: return pretrained, scratch 82 | 83 | ### build CSM 84 | scratch = _make_scratch_csm(scratch, in_channels=scratch.CHANNELS, cout=cout, expand=expand) 85 | 86 | # CSM upsamples x2 so the feature map resolution doubles 87 | pretrained.RESOLUTIONS = [res*2 for res in pretrained.RESOLUTIONS] 88 | pretrained.CHANNELS = scratch.CHANNELS 89 | 90 | return pretrained, scratch 91 | 92 | 93 | class F_RandomProj(nn.Module): 94 | def __init__( 95 | self, 96 | im_res=256, 97 | cout=64, 98 | expand=True, 99 | proj_type=2, # 0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing 100 | **kwargs, 101 | ): 102 | super().__init__() 103 | self.proj_type = proj_type 104 | self.cout = cout 105 | self.expand = expand 106 | 107 | # build pretrained feature network and random decoder (scratch) 108 | self.pretrained, self.scratch = _make_projector(im_res=im_res, cout=self.cout, proj_type=self.proj_type, expand=self.expand) 109 | self.CHANNELS = self.pretrained.CHANNELS 110 | self.RESOLUTIONS = self.pretrained.RESOLUTIONS 111 | 112 | def forward(self, x, get_features=False): 113 | # predict feature maps 114 | out0 = self.pretrained.layer0(x) 115 | out1 = self.pretrained.layer1(out0) 116 | out2 = self.pretrained.layer2(out1) 117 | out3 = self.pretrained.layer3(out2) 118 | 119 | # start enumerating at the lowest layer (this is where we put the first discriminator) 120 | backbone_features = { 121 | '0': out0, 122 | '1': out1, 123 | '2': out2, 124 | '3': out3, 125 | } 126 | if get_features: 127 | return backbone_features 128 | 129 | if self.proj_type == 0: return backbone_features 130 | 131 | out0_channel_mixed = self.scratch.layer0_ccm(backbone_features['0']) 132 | out1_channel_mixed = self.scratch.layer1_ccm(backbone_features['1']) 133 | out2_channel_mixed = self.scratch.layer2_ccm(backbone_features['2']) 134 | out3_channel_mixed = self.scratch.layer3_ccm(backbone_features['3']) 135 | 136 | out = { 137 | '0': out0_channel_mixed, 138 | '1': out1_channel_mixed, 139 | '2': out2_channel_mixed, 140 | '3': out3_channel_mixed, 141 | } 142 | 143 | if self.proj_type == 1: return out 144 | 145 | # from bottom to top 146 | out3_scale_mixed = self.scratch.layer3_csm(out3_channel_mixed) 147 | out2_scale_mixed = self.scratch.layer2_csm(out3_scale_mixed, out2_channel_mixed) 148 | out1_scale_mixed = self.scratch.layer1_csm(out2_scale_mixed, out1_channel_mixed) 149 | out0_scale_mixed = self.scratch.layer0_csm(out1_scale_mixed, out0_channel_mixed) 150 | 151 | out = { 152 | '0': out0_scale_mixed, 153 | '1': out1_scale_mixed, 154 | '2': out2_scale_mixed, 155 | '3': out3_scale_mixed, 156 | } 157 | 158 | return out, backbone_features 159 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import cog 2 | import tempfile 3 | from pathlib import Path 4 | import argparse 5 | import cv2 6 | import torch 7 | from 
PIL import Image 8 | import torch.nn.functional as F 9 | from torchvision import transforms 10 | from models.models import create_model 11 | from options.test_options import TestOptions 12 | from util.reverse2original import reverse2wholeimage 13 | from util.norm import SpecificNorm 14 | from test_wholeimage_swapmulti import _totensor 15 | from insightface_func.face_detect_crop_multi import Face_detect_crop as Face_detect_crop_multi 16 | from insightface_func.face_detect_crop_single import Face_detect_crop as Face_detect_crop_single 17 | 18 | 19 | class Predictor(cog.Predictor): 20 | def setup(self): 21 | self.transformer_Arcface = transforms.Compose([ 22 | transforms.ToTensor(), 23 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 24 | ]) 25 | 26 | @cog.input("source", type=Path, help="source image") 27 | @cog.input("target", type=Path, help="target image") 28 | @cog.input("mode", type=str, options=['single', 'all'], default='all', 29 | help="swap a single face (the one with highest confidence by face detection) or all faces in the target image") 30 | def predict(self, source, target, mode='all'): 31 | 32 | app = Face_detect_crop_multi(name='antelope', root='./insightface_func/models') 33 | 34 | if mode == 'single': 35 | app = Face_detect_crop_single(name='antelope', root='./insightface_func/models') 36 | 37 | app.prepare(ctx_id=0, det_thresh=0.6, det_size=(640, 640)) 38 | 39 | options = TestOptions() 40 | options.initialize() 41 | opt = options.parser.parse_args(["--Arc_path", 'arcface_model/arcface_checkpoint.tar', "--pic_a_path", str(source), 42 | "--pic_b_path", str(target), "--isTrain", False, "--no_simswaplogo"]) 43 | 44 | str_ids = opt.gpu_ids.split(',') 45 | opt.gpu_ids = [] 46 | for str_id in str_ids: 47 | id = int(str_id) 48 | if id >= 0: 49 | opt.gpu_ids.append(id) 50 | 51 | # set gpu ids 52 | if len(opt.gpu_ids) > 0: 53 | torch.cuda.set_device(opt.gpu_ids[0]) 54 | 55 | torch.nn.Module.dump_patches = True 56 | model = create_model(opt) 57 | model.eval() 58 | 59 | crop_size = opt.crop_size 60 | spNorm = SpecificNorm() 61 | 62 | with torch.no_grad(): 63 | pic_a = opt.pic_a_path 64 | img_a_whole = cv2.imread(pic_a) 65 | img_a_align_crop, _ = app.get(img_a_whole, crop_size) 66 | img_a_align_crop_pil = Image.fromarray(cv2.cvtColor(img_a_align_crop[0], cv2.COLOR_BGR2RGB)) 67 | img_a = self.transformer_Arcface(img_a_align_crop_pil) 68 | img_id = img_a.view(-1, img_a.shape[0], img_a.shape[1], img_a.shape[2]) 69 | 70 | # convert numpy to tensor 71 | img_id = img_id.cuda() 72 | 73 | # create latent id 74 | img_id_downsample = F.interpolate(img_id, size=(112,112)) 75 | latend_id = model.netArc(img_id_downsample) 76 | latend_id = F.normalize(latend_id, p=2, dim=1) 77 | 78 | ############## Forward Pass ###################### 79 | 80 | pic_b = opt.pic_b_path 81 | img_b_whole = cv2.imread(pic_b) 82 | img_b_align_crop_list, b_mat_list = app.get(img_b_whole, crop_size) 83 | 84 | swap_result_list = [] 85 | b_align_crop_tenor_list = [] 86 | 87 | for b_align_crop in img_b_align_crop_list: 88 | b_align_crop_tenor = _totensor(cv2.cvtColor(b_align_crop, cv2.COLOR_BGR2RGB))[None, ...].cuda() 89 | 90 | swap_result = model(None, b_align_crop_tenor, latend_id, None, True)[0] 91 | swap_result_list.append(swap_result) 92 | b_align_crop_tenor_list.append(b_align_crop_tenor) 93 | 94 | net = None 95 | 96 | out_path = Path(tempfile.mkdtemp()) / "output.png" 97 | 98 | reverse2wholeimage(b_align_crop_tenor_list, swap_result_list, b_mat_list, crop_size, img_b_whole, None, 99 | str(out_path), 
opt.no_simswaplogo, 100 | pasring_model=net, use_mask=opt.use_mask, norm=spNorm) 101 | return out_path 102 | -------------------------------------------------------------------------------- /simswaplogo/simswaplogo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/simswaplogo/simswaplogo.png -------------------------------------------------------------------------------- /simswaplogo/socialbook_logo.2020.357eed90add7705e54a8.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 10 | 14 | 15 | 56 | 58 | 59 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /test_one_image.py: -------------------------------------------------------------------------------- 1 | 2 | import cv2 3 | import torch 4 | import fractions 5 | import numpy as np 6 | from PIL import Image 7 | import torch.nn.functional as F 8 | from torchvision import transforms 9 | from models.models import create_model 10 | from options.test_options import TestOptions 11 | 12 | 13 | def lcm(a, b): return abs(a * b) / fractions.gcd(a, b) if a and b else 0 14 | 15 | transformer = transforms.Compose([ 16 | transforms.ToTensor(), 17 | #transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 18 | ]) 19 | 20 | transformer_Arcface = transforms.Compose([ 21 | transforms.ToTensor(), 22 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 23 | ]) 24 | 25 | detransformer = transforms.Compose([ 26 | transforms.Normalize([0, 0, 0], [1/0.229, 1/0.224, 1/0.225]), 27 | transforms.Normalize([-0.485, -0.456, -0.406], [1, 1, 1]) 28 | ]) 29 | if __name__ == '__main__': 30 | opt = TestOptions().parse() 31 | 32 | start_epoch, epoch_iter = 1, 0 33 | 34 | torch.nn.Module.dump_patches = True 35 | model = create_model(opt) 36 | model.eval() 37 | 38 | with torch.no_grad(): 39 | 40 | pic_a = opt.pic_a_path 41 | img_a = Image.open(pic_a).convert('RGB') 42 | img_a = transformer_Arcface(img_a) 43 | img_id = img_a.view(-1, img_a.shape[0], img_a.shape[1], img_a.shape[2]) 44 | 45 | pic_b = opt.pic_b_path 46 | 47 | img_b = Image.open(pic_b).convert('RGB') 48 | img_b = transformer(img_b) 49 | img_att = img_b.view(-1, img_b.shape[0], img_b.shape[1], img_b.shape[2]) 50 | 51 | # convert numpy to tensor 52 | img_id = img_id.cuda() 53 | img_att = img_att.cuda() 54 | 55 | #create latent id 56 | img_id_downsample = F.interpolate(img_id, size=(112,112)) 57 | latend_id = model.netArc(img_id_downsample) 58 | latend_id = latend_id.detach().to('cpu') 59 | latend_id = latend_id/np.linalg.norm(latend_id,axis=1,keepdims=True) 60 | latend_id = latend_id.to('cuda') 61 | 62 | 63 | ############## Forward Pass ###################### 64 | img_fake = model(img_id, img_att, latend_id, latend_id, True) 65 | 66 | 67 | for i in range(img_id.shape[0]): 68 | if i == 0: 69 | row1 = img_id[i] 70 | row2 = img_att[i] 71 | row3 = img_fake[i] 72 | else: 73 | row1 = torch.cat([row1, img_id[i]], dim=2) 74 | row2 = torch.cat([row2, img_att[i]], dim=2) 75 | row3 = torch.cat([row3, img_fake[i]], dim=2) 76 | 77 | #full = torch.cat([row1, row2, row3], dim=1).detach() 78 | full = row3.detach() 79 | full = full.permute(1, 2, 0) 80 | output = full.to('cpu') 81 | output = np.array(output) 82 | output = output[..., ::-1] 83 | 84 | output = output*255 85 | 86 | cv2.imwrite(opt.output_path + 'result.jpg', output) 
-------------------------------------------------------------------------------- /test_video_swap_multispecific.py: -------------------------------------------------------------------------------- 1 | 2 | import cv2 3 | import torch 4 | import fractions 5 | from PIL import Image 6 | import torch.nn.functional as F 7 | from torchvision import transforms 8 | from models.models import create_model 9 | from options.test_options import TestOptions 10 | from insightface_func.face_detect_crop_multi import Face_detect_crop 11 | from util.videoswap_multispecific import video_swap 12 | import os 13 | import glob 14 | 15 | def lcm(a, b): return abs(a * b) / fractions.gcd(a, b) if a and b else 0 16 | 17 | transformer = transforms.Compose([ 18 | transforms.ToTensor(), 19 | #transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 20 | ]) 21 | 22 | transformer_Arcface = transforms.Compose([ 23 | transforms.ToTensor(), 24 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 25 | ]) 26 | 27 | # detransformer = transforms.Compose([ 28 | # transforms.Normalize([0, 0, 0], [1/0.229, 1/0.224, 1/0.225]), 29 | # transforms.Normalize([-0.485, -0.456, -0.406], [1, 1, 1]) 30 | # ]) 31 | 32 | 33 | if __name__ == '__main__': 34 | opt = TestOptions().parse() 35 | pic_specific = opt.pic_specific_path 36 | start_epoch, epoch_iter = 1, 0 37 | crop_size = opt.crop_size 38 | 39 | multisepcific_dir = opt.multisepcific_dir 40 | torch.nn.Module.dump_patches = True 41 | if crop_size == 512: 42 | opt.which_epoch = 550000 43 | opt.name = '512' 44 | mode = 'ffhq' 45 | else: 46 | mode = 'None' 47 | model = create_model(opt) 48 | model.eval() 49 | 50 | 51 | app = Face_detect_crop(name='antelope', root='./insightface_func/models') 52 | app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640),mode=mode) 53 | 54 | # The specific person to be swapped(source) 55 | 56 | source_specific_id_nonorm_list = [] 57 | source_path = os.path.join(multisepcific_dir,'SRC_*') 58 | source_specific_images_path = sorted(glob.glob(source_path)) 59 | with torch.no_grad(): 60 | for source_specific_image_path in source_specific_images_path: 61 | specific_person_whole = cv2.imread(source_specific_image_path) 62 | specific_person_align_crop, _ = app.get(specific_person_whole,crop_size) 63 | specific_person_align_crop_pil = Image.fromarray(cv2.cvtColor(specific_person_align_crop[0],cv2.COLOR_BGR2RGB)) 64 | specific_person = transformer_Arcface(specific_person_align_crop_pil) 65 | specific_person = specific_person.view(-1, specific_person.shape[0], specific_person.shape[1], specific_person.shape[2]) 66 | # convert numpy to tensor 67 | specific_person = specific_person.cuda() 68 | #create latent id 69 | specific_person_downsample = F.interpolate(specific_person, size=(112,112)) 70 | specific_person_id_nonorm = model.netArc(specific_person_downsample) 71 | source_specific_id_nonorm_list.append(specific_person_id_nonorm.clone()) 72 | 73 | 74 | # The person who provides id information (list) 75 | target_id_norm_list = [] 76 | target_path = os.path.join(multisepcific_dir,'DST_*') 77 | target_images_path = sorted(glob.glob(target_path)) 78 | 79 | for target_image_path in target_images_path: 80 | img_a_whole = cv2.imread(target_image_path) 81 | img_a_align_crop, _ = app.get(img_a_whole,crop_size) 82 | img_a_align_crop_pil = Image.fromarray(cv2.cvtColor(img_a_align_crop[0],cv2.COLOR_BGR2RGB)) 83 | img_a = transformer_Arcface(img_a_align_crop_pil) 84 | img_id = img_a.view(-1, img_a.shape[0], img_a.shape[1], img_a.shape[2]) 85 | # convert numpy 
to tensor 86 | img_id = img_id.cuda() 87 | #create latent id 88 | img_id_downsample = F.interpolate(img_id, size=(112,112)) 89 | latend_id = model.netArc(img_id_downsample) 90 | latend_id = F.normalize(latend_id, p=2, dim=1) 91 | target_id_norm_list.append(latend_id.clone()) 92 | 93 | assert len(target_id_norm_list) == len(source_specific_id_nonorm_list), "The number of images in source and target directory must be same !!!" 94 | 95 | 96 | 97 | video_swap(opt.video_path, target_id_norm_list,source_specific_id_nonorm_list, opt.id_thres, \ 98 | model, app, opt.output_path,temp_results_dir=opt.temp_path,no_simswaplogo=opt.no_simswaplogo,use_mask=opt.use_mask,crop_size=crop_size) 99 | 100 | -------------------------------------------------------------------------------- /test_video_swapmulti.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: Naiyuan liu 3 | Github: https://github.com/NNNNAI 4 | Date: 2021-11-23 17:03:58 5 | LastEditors: Naiyuan liu 6 | LastEditTime: 2021-11-24 19:00:34 7 | Description: 8 | ''' 9 | 10 | import cv2 11 | import torch 12 | import fractions 13 | import numpy as np 14 | from PIL import Image 15 | import torch.nn.functional as F 16 | from torchvision import transforms 17 | from models.models import create_model 18 | from options.test_options import TestOptions 19 | from insightface_func.face_detect_crop_multi import Face_detect_crop 20 | from util.videoswap import video_swap 21 | import os 22 | 23 | def lcm(a, b): return abs(a * b) / fractions.gcd(a, b) if a and b else 0 24 | 25 | transformer = transforms.Compose([ 26 | transforms.ToTensor(), 27 | #transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 28 | ]) 29 | 30 | transformer_Arcface = transforms.Compose([ 31 | transforms.ToTensor(), 32 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 33 | ]) 34 | 35 | # detransformer = transforms.Compose([ 36 | # transforms.Normalize([0, 0, 0], [1/0.229, 1/0.224, 1/0.225]), 37 | # transforms.Normalize([-0.485, -0.456, -0.406], [1, 1, 1]) 38 | # ]) 39 | 40 | 41 | if __name__ == '__main__': 42 | opt = TestOptions().parse() 43 | 44 | start_epoch, epoch_iter = 1, 0 45 | crop_size = opt.crop_size 46 | 47 | torch.nn.Module.dump_patches = True 48 | 49 | if crop_size == 512: 50 | opt.which_epoch = 550000 51 | opt.name = '512' 52 | mode = 'ffhq' 53 | else: 54 | mode = 'None' 55 | model = create_model(opt) 56 | model.eval() 57 | 58 | app = Face_detect_crop(name='antelope', root='./insightface_func/models') 59 | app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640),mode = mode) 60 | 61 | with torch.no_grad(): 62 | pic_a = opt.pic_a_path 63 | # img_a = Image.open(pic_a).convert('RGB') 64 | img_a_whole = cv2.imread(pic_a) 65 | img_a_align_crop, _ = app.get(img_a_whole,crop_size) 66 | img_a_align_crop_pil = Image.fromarray(cv2.cvtColor(img_a_align_crop[0],cv2.COLOR_BGR2RGB)) 67 | img_a = transformer_Arcface(img_a_align_crop_pil) 68 | img_id = img_a.view(-1, img_a.shape[0], img_a.shape[1], img_a.shape[2]) 69 | 70 | # pic_b = opt.pic_b_path 71 | # img_b_whole = cv2.imread(pic_b) 72 | # img_b_align_crop, b_mat = app.get(img_b_whole,crop_size) 73 | # img_b_align_crop_pil = Image.fromarray(cv2.cvtColor(img_b_align_crop,cv2.COLOR_BGR2RGB)) 74 | # img_b = transformer(img_b_align_crop_pil) 75 | # img_att = img_b.view(-1, img_b.shape[0], img_b.shape[1], img_b.shape[2]) 76 | 77 | # convert numpy to tensor 78 | img_id = img_id.cuda() 79 | # img_att = img_att.cuda() 80 | 81 | #create latent id 82 | 
img_id_downsample = F.interpolate(img_id, size=(112,112)) 83 | latend_id = model.netArc(img_id_downsample) 84 | latend_id = F.normalize(latend_id, p=2, dim=1) 85 | 86 | video_swap(opt.video_path, latend_id, model, app, opt.output_path,temp_results_dir=opt.temp_path,\ 87 | no_simswaplogo=opt.no_simswaplogo,use_mask=opt.use_mask,crop_size=crop_size) 88 | 89 | -------------------------------------------------------------------------------- /test_video_swapsingle.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: Naiyuan liu 3 | Github: https://github.com/NNNNAI 4 | Date: 2021-11-23 17:03:58 5 | LastEditors: Naiyuan liu 6 | LastEditTime: 2021-11-24 19:00:38 7 | Description: 8 | ''' 9 | 10 | import cv2 11 | import torch 12 | import fractions 13 | import numpy as np 14 | from PIL import Image 15 | import torch.nn.functional as F 16 | from torchvision import transforms 17 | from models.models import create_model 18 | from options.test_options import TestOptions 19 | from insightface_func.face_detect_crop_single import Face_detect_crop 20 | from util.videoswap import video_swap 21 | import os 22 | 23 | def lcm(a, b): return abs(a * b) / fractions.gcd(a, b) if a and b else 0 24 | 25 | transformer = transforms.Compose([ 26 | transforms.ToTensor(), 27 | #transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 28 | ]) 29 | 30 | transformer_Arcface = transforms.Compose([ 31 | transforms.ToTensor(), 32 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 33 | ]) 34 | 35 | # detransformer = transforms.Compose([ 36 | # transforms.Normalize([0, 0, 0], [1/0.229, 1/0.224, 1/0.225]), 37 | # transforms.Normalize([-0.485, -0.456, -0.406], [1, 1, 1]) 38 | # ]) 39 | 40 | 41 | if __name__ == '__main__': 42 | opt = TestOptions().parse() 43 | 44 | start_epoch, epoch_iter = 1, 0 45 | crop_size = opt.crop_size 46 | 47 | torch.nn.Module.dump_patches = True 48 | if crop_size == 512: 49 | opt.which_epoch = 550000 50 | opt.name = '512' 51 | mode = 'ffhq' 52 | else: 53 | mode = 'None' 54 | model = create_model(opt) 55 | model.eval() 56 | 57 | 58 | app = Face_detect_crop(name='antelope', root='./insightface_func/models') 59 | app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640),mode=mode) 60 | with torch.no_grad(): 61 | pic_a = opt.pic_a_path 62 | # img_a = Image.open(pic_a).convert('RGB') 63 | img_a_whole = cv2.imread(pic_a) 64 | img_a_align_crop, _ = app.get(img_a_whole,crop_size) 65 | img_a_align_crop_pil = Image.fromarray(cv2.cvtColor(img_a_align_crop[0],cv2.COLOR_BGR2RGB)) 66 | img_a = transformer_Arcface(img_a_align_crop_pil) 67 | img_id = img_a.view(-1, img_a.shape[0], img_a.shape[1], img_a.shape[2]) 68 | 69 | # pic_b = opt.pic_b_path 70 | # img_b_whole = cv2.imread(pic_b) 71 | # img_b_align_crop, b_mat = app.get(img_b_whole,crop_size) 72 | # img_b_align_crop_pil = Image.fromarray(cv2.cvtColor(img_b_align_crop,cv2.COLOR_BGR2RGB)) 73 | # img_b = transformer(img_b_align_crop_pil) 74 | # img_att = img_b.view(-1, img_b.shape[0], img_b.shape[1], img_b.shape[2]) 75 | 76 | # convert numpy to tensor 77 | img_id = img_id.cuda() 78 | # img_att = img_att.cuda() 79 | 80 | #create latent id 81 | img_id_downsample = F.interpolate(img_id, size=(112,112)) 82 | latend_id = model.netArc(img_id_downsample) 83 | latend_id = F.normalize(latend_id, p=2, dim=1) 84 | 85 | video_swap(opt.video_path, latend_id, model, app, opt.output_path,temp_results_dir=opt.temp_path,\ 86 | no_simswaplogo=opt.no_simswaplogo,use_mask=opt.use_mask,crop_size=crop_size) 87 | 
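A hedged usage sketch for test_video_swapsingle.py, assuming the CLI flags exposed by TestOptions correspond to the opt attributes used above (crop_size, use_mask, name, Arc_path, pic_a_path, video_path, output_path, temp_path) and taking the bundled demo assets under demo_file/ as inputs:

    python test_video_swapsingle.py --crop_size 224 --use_mask --name people --Arc_path arcface_model/arcface_checkpoint.tar --pic_a_path ./demo_file/Iron_man.jpg --video_path ./demo_file/multi_people_1080p.mp4 --output_path ./output/demo.mp4 --temp_path ./temp_results

Passing --crop_size 512 instead would trigger the branch above that switches to the '512' checkpoint (which_epoch = 550000, ffhq alignment mode). test_video_swapmulti.py follows the same pattern with the multi-face detector, and test_video_swapspecific.py (below) additionally reads --pic_specific_path and --id_thres.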
88 | -------------------------------------------------------------------------------- /test_video_swapspecific.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: Naiyuan liu 3 | Github: https://github.com/NNNNAI 4 | Date: 2021-11-23 17:03:58 5 | LastEditors: Naiyuan liu 6 | LastEditTime: 2021-11-24 19:00:42 7 | Description: 8 | ''' 9 | 10 | import cv2 11 | import torch 12 | import fractions 13 | import numpy as np 14 | from PIL import Image 15 | import torch.nn.functional as F 16 | from torchvision import transforms 17 | from models.models import create_model 18 | from options.test_options import TestOptions 19 | from insightface_func.face_detect_crop_multi import Face_detect_crop 20 | from util.videoswap_specific import video_swap 21 | import os 22 | 23 | def lcm(a, b): return abs(a * b) / fractions.gcd(a, b) if a and b else 0 24 | 25 | transformer = transforms.Compose([ 26 | transforms.ToTensor(), 27 | #transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 28 | ]) 29 | 30 | transformer_Arcface = transforms.Compose([ 31 | transforms.ToTensor(), 32 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 33 | ]) 34 | 35 | # detransformer = transforms.Compose([ 36 | # transforms.Normalize([0, 0, 0], [1/0.229, 1/0.224, 1/0.225]), 37 | # transforms.Normalize([-0.485, -0.456, -0.406], [1, 1, 1]) 38 | # ]) 39 | 40 | 41 | if __name__ == '__main__': 42 | opt = TestOptions().parse() 43 | pic_specific = opt.pic_specific_path 44 | start_epoch, epoch_iter = 1, 0 45 | crop_size = opt.crop_size 46 | 47 | torch.nn.Module.dump_patches = True 48 | if crop_size == 512: 49 | opt.which_epoch = 550000 50 | opt.name = '512' 51 | mode = 'ffhq' 52 | else: 53 | mode = 'None' 54 | model = create_model(opt) 55 | model.eval() 56 | 57 | 58 | app = Face_detect_crop(name='antelope', root='./insightface_func/models') 59 | app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640),mode=mode) 60 | with torch.no_grad(): 61 | pic_a = opt.pic_a_path 62 | # img_a = Image.open(pic_a).convert('RGB') 63 | img_a_whole = cv2.imread(pic_a) 64 | img_a_align_crop, _ = app.get(img_a_whole,crop_size) 65 | img_a_align_crop_pil = Image.fromarray(cv2.cvtColor(img_a_align_crop[0],cv2.COLOR_BGR2RGB)) 66 | img_a = transformer_Arcface(img_a_align_crop_pil) 67 | img_id = img_a.view(-1, img_a.shape[0], img_a.shape[1], img_a.shape[2]) 68 | 69 | # pic_b = opt.pic_b_path 70 | # img_b_whole = cv2.imread(pic_b) 71 | # img_b_align_crop, b_mat = app.get(img_b_whole,crop_size) 72 | # img_b_align_crop_pil = Image.fromarray(cv2.cvtColor(img_b_align_crop,cv2.COLOR_BGR2RGB)) 73 | # img_b = transformer(img_b_align_crop_pil) 74 | # img_att = img_b.view(-1, img_b.shape[0], img_b.shape[1], img_b.shape[2]) 75 | 76 | # convert numpy to tensor 77 | img_id = img_id.cuda() 78 | # img_att = img_att.cuda() 79 | 80 | #create latent id 81 | img_id_downsample = F.interpolate(img_id, size=(112,112)) 82 | latend_id = model.netArc(img_id_downsample) 83 | latend_id = F.normalize(latend_id, p=2, dim=1) 84 | 85 | 86 | # The specific person to be swapped 87 | specific_person_whole = cv2.imread(pic_specific) 88 | specific_person_align_crop, _ = app.get(specific_person_whole,crop_size) 89 | specific_person_align_crop_pil = Image.fromarray(cv2.cvtColor(specific_person_align_crop[0],cv2.COLOR_BGR2RGB)) 90 | specific_person = transformer_Arcface(specific_person_align_crop_pil) 91 | specific_person = specific_person.view(-1, specific_person.shape[0], specific_person.shape[1], specific_person.shape[2]) 92 | 
specific_person = specific_person.cuda() 93 | specific_person_downsample = F.interpolate(specific_person, size=(112,112)) 94 | specific_person_id_nonorm = model.netArc(specific_person_downsample) 95 | 96 | video_swap(opt.video_path, latend_id,specific_person_id_nonorm, opt.id_thres, \ 97 | model, app, opt.output_path,temp_results_dir=opt.temp_path,no_simswaplogo=opt.no_simswaplogo,use_mask=opt.use_mask,crop_size=crop_size) 98 | 99 | -------------------------------------------------------------------------------- /test_wholeimage_swap_multispecific.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: Naiyuan liu 3 | Github: https://github.com/NNNNAI 4 | Date: 2021-11-23 17:03:58 5 | LastEditors: Naiyuan liu 6 | LastEditTime: 2021-11-24 19:19:22 7 | Description: 8 | ''' 9 | 10 | import cv2 11 | import torch 12 | import fractions 13 | import numpy as np 14 | from PIL import Image 15 | import torch.nn.functional as F 16 | from torchvision import transforms 17 | from models.models import create_model 18 | from options.test_options import TestOptions 19 | from insightface_func.face_detect_crop_multi import Face_detect_crop 20 | from util.reverse2original import reverse2wholeimage 21 | import os 22 | from util.add_watermark import watermark_image 23 | import torch.nn as nn 24 | from util.norm import SpecificNorm 25 | import glob 26 | from parsing_model.model import BiSeNet 27 | 28 | def lcm(a, b): return abs(a * b) / fractions.gcd(a, b) if a and b else 0 29 | 30 | transformer_Arcface = transforms.Compose([ 31 | transforms.ToTensor(), 32 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 33 | ]) 34 | 35 | def _totensor(array): 36 | tensor = torch.from_numpy(array) 37 | img = tensor.transpose(0, 1).transpose(0, 2).contiguous() 38 | return img.float().div(255) 39 | 40 | def _toarctensor(array): 41 | tensor = torch.from_numpy(array) 42 | img = tensor.transpose(0, 1).transpose(0, 2).contiguous() 43 | return img.float().div(255) 44 | 45 | if __name__ == '__main__': 46 | opt = TestOptions().parse() 47 | 48 | start_epoch, epoch_iter = 1, 0 49 | crop_size = opt.crop_size 50 | 51 | multisepcific_dir = opt.multisepcific_dir 52 | 53 | torch.nn.Module.dump_patches = True 54 | 55 | if crop_size == 512: 56 | opt.which_epoch = 550000 57 | opt.name = '512' 58 | mode = 'ffhq' 59 | else: 60 | mode = 'None' 61 | 62 | logoclass = watermark_image('./simswaplogo/simswaplogo.png') 63 | model = create_model(opt) 64 | model.eval() 65 | mse = torch.nn.MSELoss().cuda() 66 | 67 | spNorm =SpecificNorm() 68 | 69 | 70 | app = Face_detect_crop(name='antelope', root='./insightface_func/models') 71 | app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640),mode = mode) 72 | 73 | with torch.no_grad(): 74 | # The specific person to be swapped(source) 75 | 76 | source_specific_id_nonorm_list = [] 77 | source_path = os.path.join(multisepcific_dir,'SRC_*') 78 | source_specific_images_path = sorted(glob.glob(source_path)) 79 | 80 | for source_specific_image_path in source_specific_images_path: 81 | specific_person_whole = cv2.imread(source_specific_image_path) 82 | specific_person_align_crop, _ = app.get(specific_person_whole,crop_size) 83 | specific_person_align_crop_pil = Image.fromarray(cv2.cvtColor(specific_person_align_crop[0],cv2.COLOR_BGR2RGB)) 84 | specific_person = transformer_Arcface(specific_person_align_crop_pil) 85 | specific_person = specific_person.view(-1, specific_person.shape[0], specific_person.shape[1], specific_person.shape[2]) 86 | # convert 
numpy to tensor 87 | specific_person = specific_person.cuda() 88 | #create latent id 89 | specific_person_downsample = F.interpolate(specific_person, size=(112,112)) 90 | specific_person_id_nonorm = model.netArc(specific_person_downsample) 91 | source_specific_id_nonorm_list.append(specific_person_id_nonorm.clone()) 92 | 93 | 94 | # The person who provides id information (list) 95 | target_id_norm_list = [] 96 | target_path = os.path.join(multisepcific_dir,'DST_*') 97 | target_images_path = sorted(glob.glob(target_path)) 98 | 99 | for target_image_path in target_images_path: 100 | img_a_whole = cv2.imread(target_image_path) 101 | img_a_align_crop, _ = app.get(img_a_whole,crop_size) 102 | img_a_align_crop_pil = Image.fromarray(cv2.cvtColor(img_a_align_crop[0],cv2.COLOR_BGR2RGB)) 103 | img_a = transformer_Arcface(img_a_align_crop_pil) 104 | img_id = img_a.view(-1, img_a.shape[0], img_a.shape[1], img_a.shape[2]) 105 | # convert numpy to tensor 106 | img_id = img_id.cuda() 107 | #create latent id 108 | img_id_downsample = F.interpolate(img_id, size=(112,112)) 109 | latend_id = model.netArc(img_id_downsample) 110 | latend_id = F.normalize(latend_id, p=2, dim=1) 111 | target_id_norm_list.append(latend_id.clone()) 112 | 113 | assert len(target_id_norm_list) == len(source_specific_id_nonorm_list), "The number of images in source and target directory must be same !!!" 114 | 115 | ############## Forward Pass ###################### 116 | 117 | pic_b = opt.pic_b_path 118 | img_b_whole = cv2.imread(pic_b) 119 | 120 | img_b_align_crop_list, b_mat_list = app.get(img_b_whole,crop_size) 121 | # detect_results = None 122 | swap_result_list = [] 123 | 124 | id_compare_values = [] 125 | b_align_crop_tenor_list = [] 126 | for b_align_crop in img_b_align_crop_list: 127 | 128 | b_align_crop_tenor = _totensor(cv2.cvtColor(b_align_crop,cv2.COLOR_BGR2RGB))[None,...].cuda() 129 | 130 | b_align_crop_tenor_arcnorm = spNorm(b_align_crop_tenor) 131 | b_align_crop_tenor_arcnorm_downsample = F.interpolate(b_align_crop_tenor_arcnorm, size=(112,112)) 132 | b_align_crop_id_nonorm = model.netArc(b_align_crop_tenor_arcnorm_downsample) 133 | 134 | id_compare_values.append([]) 135 | for source_specific_id_nonorm_tmp in source_specific_id_nonorm_list: 136 | id_compare_values[-1].append(mse(b_align_crop_id_nonorm,source_specific_id_nonorm_tmp).detach().cpu().numpy()) 137 | b_align_crop_tenor_list.append(b_align_crop_tenor) 138 | 139 | id_compare_values_array = np.array(id_compare_values).transpose(1,0) 140 | min_indexs = np.argmin(id_compare_values_array,axis=0) 141 | min_value = np.min(id_compare_values_array,axis=0) 142 | 143 | swap_result_list = [] 144 | swap_result_matrix_list = [] 145 | swap_result_ori_pic_list = [] 146 | 147 | for tmp_index, min_index in enumerate(min_indexs): 148 | if min_value[tmp_index] < opt.id_thres: 149 | swap_result = model(None, b_align_crop_tenor_list[tmp_index], target_id_norm_list[min_index], None, True)[0] 150 | swap_result_list.append(swap_result) 151 | swap_result_matrix_list.append(b_mat_list[tmp_index]) 152 | swap_result_ori_pic_list.append(b_align_crop_tenor_list[tmp_index]) 153 | else: 154 | pass 155 | 156 | if len(swap_result_list) !=0: 157 | 158 | if opt.use_mask: 159 | n_classes = 19 160 | net = BiSeNet(n_classes=n_classes) 161 | net.cuda() 162 | save_pth = os.path.join('./parsing_model/checkpoint', '79999_iter.pth') 163 | net.load_state_dict(torch.load(save_pth)) 164 | net.eval() 165 | else: 166 | net =None 167 | 168 | reverse2wholeimage(swap_result_ori_pic_list, swap_result_list, 
swap_result_matrix_list, crop_size, img_b_whole, logoclass,\ 169 | os.path.join(opt.output_path, 'result_whole_swap_multispecific.jpg'), opt.no_simswaplogo,pasring_model =net,use_mask=opt.use_mask, norm = spNorm) 170 | 171 | print(' ') 172 | 173 | print('************ Done ! ************') 174 | 175 | else: 176 | print('The people you specified are not found on the picture: {}'.format(pic_b)) 177 | -------------------------------------------------------------------------------- /test_wholeimage_swapmulti.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: Naiyuan liu 3 | Github: https://github.com/NNNNAI 4 | Date: 2021-11-23 17:03:58 5 | LastEditors: Naiyuan liu 6 | LastEditTime: 2021-11-24 19:19:26 7 | Description: 8 | ''' 9 | 10 | import cv2 11 | import torch 12 | import fractions 13 | import numpy as np 14 | from PIL import Image 15 | import torch.nn.functional as F 16 | from torchvision import transforms 17 | from models.models import create_model 18 | from options.test_options import TestOptions 19 | from insightface_func.face_detect_crop_multi import Face_detect_crop 20 | from util.reverse2original import reverse2wholeimage 21 | import os 22 | from util.add_watermark import watermark_image 23 | from util.norm import SpecificNorm 24 | from parsing_model.model import BiSeNet 25 | 26 | def lcm(a, b): return abs(a * b) / fractions.gcd(a, b) if a and b else 0 27 | 28 | transformer_Arcface = transforms.Compose([ 29 | transforms.ToTensor(), 30 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 31 | ]) 32 | 33 | def _totensor(array): 34 | tensor = torch.from_numpy(array) 35 | img = tensor.transpose(0, 1).transpose(0, 2).contiguous() 36 | return img.float().div(255) 37 | 38 | if __name__ == '__main__': 39 | opt = TestOptions().parse() 40 | 41 | start_epoch, epoch_iter = 1, 0 42 | crop_size = opt.crop_size 43 | 44 | torch.nn.Module.dump_patches = True 45 | if crop_size == 512: 46 | opt.which_epoch = 550000 47 | opt.name = '512' 48 | mode = 'ffhq' 49 | else: 50 | mode = 'None' 51 | logoclass = watermark_image('./simswaplogo/simswaplogo.png') 52 | model = create_model(opt) 53 | model.eval() 54 | spNorm =SpecificNorm() 55 | 56 | app = Face_detect_crop(name='antelope', root='./insightface_func/models') 57 | app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640),mode=mode) 58 | 59 | with torch.no_grad(): 60 | pic_a = opt.pic_a_path 61 | 62 | img_a_whole = cv2.imread(pic_a) 63 | img_a_align_crop, _ = app.get(img_a_whole,crop_size) 64 | img_a_align_crop_pil = Image.fromarray(cv2.cvtColor(img_a_align_crop[0],cv2.COLOR_BGR2RGB)) 65 | img_a = transformer_Arcface(img_a_align_crop_pil) 66 | img_id = img_a.view(-1, img_a.shape[0], img_a.shape[1], img_a.shape[2]) 67 | 68 | # convert numpy to tensor 69 | img_id = img_id.cuda() 70 | 71 | #create latent id 72 | img_id_downsample = F.interpolate(img_id, size=(112,112)) 73 | latend_id = model.netArc(img_id_downsample) 74 | latend_id = F.normalize(latend_id, p=2, dim=1) 75 | 76 | 77 | ############## Forward Pass ###################### 78 | 79 | pic_b = opt.pic_b_path 80 | img_b_whole = cv2.imread(pic_b) 81 | 82 | img_b_align_crop_list, b_mat_list = app.get(img_b_whole,crop_size) 83 | # detect_results = None 84 | swap_result_list = [] 85 | b_align_crop_tenor_list = [] 86 | 87 | for b_align_crop in img_b_align_crop_list: 88 | 89 | b_align_crop_tenor = _totensor(cv2.cvtColor(b_align_crop,cv2.COLOR_BGR2RGB))[None,...].cuda() 90 | 91 | swap_result = model(None, b_align_crop_tenor, latend_id, 
None, True)[0] 92 | swap_result_list.append(swap_result) 93 | b_align_crop_tenor_list.append(b_align_crop_tenor) 94 | 95 | 96 | if opt.use_mask: 97 | n_classes = 19 98 | net = BiSeNet(n_classes=n_classes) 99 | net.cuda() 100 | save_pth = os.path.join('./parsing_model/checkpoint', '79999_iter.pth') 101 | net.load_state_dict(torch.load(save_pth)) 102 | net.eval() 103 | else: 104 | net =None 105 | 106 | reverse2wholeimage(b_align_crop_tenor_list,swap_result_list, b_mat_list, crop_size, img_b_whole, logoclass, \ 107 | os.path.join(opt.output_path, 'result_whole_swapmulti.jpg'),opt.no_simswaplogo,pasring_model =net,use_mask=opt.use_mask, norm = spNorm) 108 | print(' ') 109 | 110 | print('************ Done ! ************') 111 | -------------------------------------------------------------------------------- /test_wholeimage_swapsingle.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: Naiyuan liu 3 | Github: https://github.com/NNNNAI 4 | Date: 2021-11-23 17:03:58 5 | LastEditors: Naiyuan liu 6 | LastEditTime: 2021-11-24 19:19:43 7 | Description: 8 | ''' 9 | 10 | import cv2 11 | import torch 12 | import fractions 13 | import numpy as np 14 | from PIL import Image 15 | import torch.nn.functional as F 16 | from torchvision import transforms 17 | from models.models import create_model 18 | from options.test_options import TestOptions 19 | from insightface_func.face_detect_crop_single import Face_detect_crop 20 | from util.reverse2original import reverse2wholeimage 21 | import os 22 | from util.add_watermark import watermark_image 23 | from util.norm import SpecificNorm 24 | from parsing_model.model import BiSeNet 25 | 26 | def lcm(a, b): return abs(a * b) / fractions.gcd(a, b) if a and b else 0 27 | 28 | transformer_Arcface = transforms.Compose([ 29 | transforms.ToTensor(), 30 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 31 | ]) 32 | 33 | def _totensor(array): 34 | tensor = torch.from_numpy(array) 35 | img = tensor.transpose(0, 1).transpose(0, 2).contiguous() 36 | return img.float().div(255) 37 | if __name__ == '__main__': 38 | opt = TestOptions().parse() 39 | 40 | start_epoch, epoch_iter = 1, 0 41 | crop_size = opt.crop_size 42 | 43 | torch.nn.Module.dump_patches = True 44 | if crop_size == 512: 45 | opt.which_epoch = 550000 46 | opt.name = '512' 47 | mode = 'ffhq' 48 | else: 49 | mode = 'None' 50 | logoclass = watermark_image('./simswaplogo/simswaplogo.png') 51 | model = create_model(opt) 52 | model.eval() 53 | 54 | spNorm =SpecificNorm() 55 | app = Face_detect_crop(name='antelope', root='./insightface_func/models') 56 | app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640),mode=mode) 57 | 58 | with torch.no_grad(): 59 | pic_a = opt.pic_a_path 60 | 61 | img_a_whole = cv2.imread(pic_a) 62 | img_a_align_crop, _ = app.get(img_a_whole,crop_size) 63 | img_a_align_crop_pil = Image.fromarray(cv2.cvtColor(img_a_align_crop[0],cv2.COLOR_BGR2RGB)) 64 | img_a = transformer_Arcface(img_a_align_crop_pil) 65 | img_id = img_a.view(-1, img_a.shape[0], img_a.shape[1], img_a.shape[2]) 66 | 67 | # convert numpy to tensor 68 | img_id = img_id.cuda() 69 | 70 | #create latent id 71 | img_id_downsample = F.interpolate(img_id, size=(112,112)) 72 | latend_id = model.netArc(img_id_downsample) 73 | latend_id = F.normalize(latend_id, p=2, dim=1) 74 | 75 | 76 | ############## Forward Pass ###################### 77 | 78 | pic_b = opt.pic_b_path 79 | img_b_whole = cv2.imread(pic_b) 80 | 81 | img_b_align_crop_list, b_mat_list = 
app.get(img_b_whole,crop_size) 82 | # detect_results = None 83 | swap_result_list = [] 84 | 85 | b_align_crop_tenor_list = [] 86 | 87 | for b_align_crop in img_b_align_crop_list: 88 | 89 | b_align_crop_tenor = _totensor(cv2.cvtColor(b_align_crop,cv2.COLOR_BGR2RGB))[None,...].cuda() 90 | 91 | swap_result = model(None, b_align_crop_tenor, latend_id, None, True)[0] 92 | swap_result_list.append(swap_result) 93 | b_align_crop_tenor_list.append(b_align_crop_tenor) 94 | 95 | if opt.use_mask: 96 | n_classes = 19 97 | net = BiSeNet(n_classes=n_classes) 98 | net.cuda() 99 | save_pth = os.path.join('./parsing_model/checkpoint', '79999_iter.pth') 100 | net.load_state_dict(torch.load(save_pth)) 101 | net.eval() 102 | else: 103 | net =None 104 | 105 | reverse2wholeimage(b_align_crop_tenor_list, swap_result_list, b_mat_list, crop_size, img_b_whole, logoclass, \ 106 | os.path.join(opt.output_path, 'result_whole_swapsingle.jpg'), opt.no_simswaplogo,pasring_model =net,use_mask=opt.use_mask, norm = spNorm) 107 | 108 | print(' ') 109 | 110 | print('************ Done ! ************') 111 | -------------------------------------------------------------------------------- /test_wholeimage_swapspecific.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: Naiyuan liu 3 | Github: https://github.com/NNNNAI 4 | Date: 2021-11-23 17:03:58 5 | LastEditors: Naiyuan liu 6 | LastEditTime: 2021-11-24 19:19:47 7 | Description: 8 | ''' 9 | 10 | import cv2 11 | import torch 12 | import fractions 13 | import numpy as np 14 | from PIL import Image 15 | import torch.nn.functional as F 16 | from torchvision import transforms 17 | from models.models import create_model 18 | from options.test_options import TestOptions 19 | from insightface_func.face_detect_crop_multi import Face_detect_crop 20 | from util.reverse2original import reverse2wholeimage 21 | import os 22 | from util.add_watermark import watermark_image 23 | import torch.nn as nn 24 | from util.norm import SpecificNorm 25 | from parsing_model.model import BiSeNet 26 | 27 | def lcm(a, b): return abs(a * b) / fractions.gcd(a, b) if a and b else 0 28 | 29 | transformer_Arcface = transforms.Compose([ 30 | transforms.ToTensor(), 31 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 32 | ]) 33 | 34 | def _totensor(array): 35 | tensor = torch.from_numpy(array) 36 | img = tensor.transpose(0, 1).transpose(0, 2).contiguous() 37 | return img.float().div(255) 38 | 39 | def _toarctensor(array): 40 | tensor = torch.from_numpy(array) 41 | img = tensor.transpose(0, 1).transpose(0, 2).contiguous() 42 | return img.float().div(255) 43 | 44 | if __name__ == '__main__': 45 | opt = TestOptions().parse() 46 | 47 | start_epoch, epoch_iter = 1, 0 48 | crop_size = opt.crop_size 49 | 50 | torch.nn.Module.dump_patches = True 51 | if crop_size == 512: 52 | opt.which_epoch = 550000 53 | opt.name = '512' 54 | mode = 'ffhq' 55 | else: 56 | mode = 'None' 57 | logoclass = watermark_image('./simswaplogo/simswaplogo.png') 58 | model = create_model(opt) 59 | model.eval() 60 | mse = torch.nn.MSELoss().cuda() 61 | 62 | spNorm =SpecificNorm() 63 | 64 | 65 | app = Face_detect_crop(name='antelope', root='./insightface_func/models') 66 | app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640),mode=mode) 67 | 68 | pic_a = opt.pic_a_path 69 | pic_specific = opt.pic_specific_path 70 | 71 | # The person who provides id information 72 | img_a_whole = cv2.imread(pic_a) 73 | img_a_align_crop, _ = app.get(img_a_whole,crop_size) 74 | 
img_a_align_crop_pil = Image.fromarray(cv2.cvtColor(img_a_align_crop[0],cv2.COLOR_BGR2RGB)) 75 | img_a = transformer_Arcface(img_a_align_crop_pil) 76 | img_id = img_a.view(-1, img_a.shape[0], img_a.shape[1], img_a.shape[2]) 77 | 78 | # convert numpy to tensor 79 | img_id = img_id.cuda() 80 | 81 | #create latent id 82 | img_id_downsample = F.interpolate(img_id, size=(112,112)) 83 | latend_id = model.netArc(img_id_downsample) 84 | latend_id = F.normalize(latend_id, p=2, dim=1) 85 | 86 | 87 | # The specific person to be swapped 88 | specific_person_whole = cv2.imread(pic_specific) 89 | specific_person_align_crop, _ = app.get(specific_person_whole,crop_size) 90 | specific_person_align_crop_pil = Image.fromarray(cv2.cvtColor(specific_person_align_crop[0],cv2.COLOR_BGR2RGB)) 91 | specific_person = transformer_Arcface(specific_person_align_crop_pil) 92 | specific_person = specific_person.view(-1, specific_person.shape[0], specific_person.shape[1], specific_person.shape[2]) 93 | 94 | # convert numpy to tensor 95 | specific_person = specific_person.cuda() 96 | 97 | #create latent id 98 | specific_person_downsample = F.interpolate(specific_person, size=(112,112)) 99 | specific_person_id_nonorm = model.netArc(specific_person_downsample) 100 | # specific_person_id_norm = F.normalize(specific_person_id_nonorm, p=2, dim=1) 101 | 102 | ############## Forward Pass ###################### 103 | 104 | pic_b = opt.pic_b_path 105 | img_b_whole = cv2.imread(pic_b) 106 | 107 | img_b_align_crop_list, b_mat_list = app.get(img_b_whole,crop_size) 108 | # detect_results = None 109 | swap_result_list = [] 110 | 111 | id_compare_values = [] 112 | b_align_crop_tenor_list = [] 113 | for b_align_crop in img_b_align_crop_list: 114 | 115 | b_align_crop_tenor = _totensor(cv2.cvtColor(b_align_crop,cv2.COLOR_BGR2RGB))[None,...].cuda() 116 | 117 | b_align_crop_tenor_arcnorm = spNorm(b_align_crop_tenor) 118 | b_align_crop_tenor_arcnorm_downsample = F.interpolate(b_align_crop_tenor_arcnorm, size=(112,112)) 119 | b_align_crop_id_nonorm = model.netArc(b_align_crop_tenor_arcnorm_downsample) 120 | 121 | id_compare_values.append(mse(b_align_crop_id_nonorm,specific_person_id_nonorm).detach().cpu().numpy()) 122 | b_align_crop_tenor_list.append(b_align_crop_tenor) 123 | 124 | id_compare_values_array = np.array(id_compare_values) 125 | min_index = np.argmin(id_compare_values_array) 126 | min_value = id_compare_values_array[min_index] 127 | 128 | if opt.use_mask: 129 | n_classes = 19 130 | net = BiSeNet(n_classes=n_classes) 131 | net.cuda() 132 | save_pth = os.path.join('./parsing_model/checkpoint', '79999_iter.pth') 133 | net.load_state_dict(torch.load(save_pth)) 134 | net.eval() 135 | else: 136 | net =None 137 | 138 | if min_value < opt.id_thres: 139 | 140 | swap_result = model(None, b_align_crop_tenor_list[min_index], latend_id, None, True)[0] 141 | 142 | reverse2wholeimage([b_align_crop_tenor_list[min_index]], [swap_result], [b_mat_list[min_index]], crop_size, img_b_whole, logoclass, \ 143 | os.path.join(opt.output_path, 'result_whole_swapspecific.jpg'), opt.no_simswaplogo,pasring_model =net,use_mask=opt.use_mask, norm = spNorm) 144 | 145 | print(' ') 146 | 147 | print('************ Done ! 
************') 148 | 149 | else: 150 | print('The person you specified is not found on the picture: {}'.format(pic_b)) 151 | -------------------------------------------------------------------------------- /util/add_watermark.py: -------------------------------------------------------------------------------- 1 | 2 | import cv2 3 | import numpy as np 4 | from PIL import Image 5 | import math 6 | import numpy as np 7 | # import torch 8 | # from torchvision import transforms 9 | 10 | def rotate_image(image, angle, center = None, scale = 1.0): 11 | (h, w) = image.shape[:2] 12 | 13 | if center is None: 14 | center = (w / 2, h / 2) 15 | 16 | # Perform the rotation 17 | M = cv2.getRotationMatrix2D(center, angle, scale) 18 | rotated = cv2.warpAffine(image, M, (w, h)) 19 | 20 | return rotated 21 | 22 | class watermark_image: 23 | def __init__(self, logo_path, size=0.3, oritation="DR", margin=(5,20,20,20), angle=15, rgb_weight=(0,1,1.5), input_frame_shape=None) -> None: 24 | 25 | logo_image = cv2.imread(logo_path, cv2.IMREAD_UNCHANGED) 26 | h,w,c = logo_image.shape 27 | if angle%360 != 0: 28 | new_h = w*math.sin(angle/180*math.pi) + h*math.cos(angle/180*math.pi) 29 | pad_h = int((new_h-h)//2) 30 | 31 | padding = np.zeros((pad_h, w, c), dtype=np.uint8) 32 | logo_image = cv2.vconcat([logo_image, padding]) 33 | logo_image = cv2.vconcat([padding, logo_image]) 34 | 35 | logo_image = rotate_image(logo_image, angle) 36 | print(logo_image.shape) 37 | self.logo_image = logo_image 38 | 39 | if self.logo_image.shape[2] < 4: 40 | print("No alpha channel found!") 41 | self.logo_image = self.__addAlpha__(self.logo_image) #add alpha channel 42 | self.size = size 43 | self.oritation = oritation 44 | self.margin = margin 45 | self.ori_shape = self.logo_image.shape 46 | self.resized = False 47 | self.rgb_weight = rgb_weight 48 | 49 | self.logo_image[:, :, 2] = self.logo_image[:, :, 2]*self.rgb_weight[0] 50 | self.logo_image[:, :, 1] = self.logo_image[:, :, 1]*self.rgb_weight[1] 51 | self.logo_image[:, :, 0] = self.logo_image[:, :, 0]*self.rgb_weight[2] 52 | 53 | if input_frame_shape is not None: 54 | 55 | logo_w = input_frame_shape[1] * self.size 56 | ratio = logo_w / self.ori_shape[1] 57 | logo_h = int(ratio * self.ori_shape[0]) 58 | logo_w = int(logo_w) 59 | 60 | size = (logo_w, logo_h) 61 | self.logo_image = cv2.resize(self.logo_image, size, interpolation = cv2.INTER_CUBIC) 62 | self.resized = True 63 | if oritation == "UL": 64 | self.coor_h = self.margin[1] 65 | self.coor_w = self.margin[0] 66 | elif oritation == "UR": 67 | self.coor_h = self.margin[1] 68 | self.coor_w = input_frame_shape[1] - (logo_w + self.margin[2]) 69 | elif oritation == "DL": 70 | self.coor_h = input_frame_shape[0] - (logo_h + self.margin[1]) 71 | self.coor_w = self.margin[0] 72 | else: 73 | self.coor_h = input_frame_shape[0] - (logo_h + self.margin[3]) 74 | self.coor_w = input_frame_shape[1] - (logo_w + self.margin[2]) 75 | self.logo_w = logo_w 76 | self.logo_h = logo_h 77 | self.mask = self.logo_image[:,:,3] 78 | self.mask = cv2.bitwise_not(self.mask//255) 79 | 80 | def apply_frames(self, frame): 81 | 82 | if not self.resized: 83 | shape = frame.shape 84 | logo_w = shape[1] * self.size 85 | ratio = logo_w / self.ori_shape[1] 86 | logo_h = int(ratio * self.ori_shape[0]) 87 | logo_w = int(logo_w) 88 | 89 | size = (logo_w, logo_h) 90 | self.logo_image = cv2.resize(self.logo_image, size, interpolation = cv2.INTER_CUBIC) 91 | self.resized = True 92 | if self.oritation == "UL": 93 | self.coor_h = self.margin[1] 94 | self.coor_w = self.margin[0] 
95 | elif self.oritation == "UR": 96 | self.coor_h = self.margin[1] 97 | self.coor_w = shape[1] - (logo_w + self.margin[2]) 98 | elif self.oritation == "DL": 99 | self.coor_h = shape[0] - (logo_h + self.margin[1]) 100 | self.coor_w = self.margin[0] 101 | else: 102 | self.coor_h = shape[0] - (logo_h + self.margin[3]) 103 | self.coor_w = shape[1] - (logo_w + self.margin[2]) 104 | self.logo_w = logo_w 105 | self.logo_h = logo_h 106 | self.mask = self.logo_image[:,:,3] 107 | self.mask = cv2.bitwise_not(self.mask//255) 108 | 109 | original_frame = frame[self.coor_h:(self.coor_h+self.logo_h), self.coor_w:(self.coor_w+self.logo_w),:] 110 | blending_logo = cv2.add(self.logo_image[:,:,0:3],original_frame,mask = self.mask) 111 | frame[self.coor_h:(self.coor_h+self.logo_h), self.coor_w:(self.coor_w+self.logo_w),:] = blending_logo 112 | return frame 113 | 114 | def __addAlpha__(self, image): 115 | shape = image.shape 116 | alpha_channel = np.ones((shape[0],shape[1],1),np.uint8)*255 117 | return np.concatenate((image,alpha_channel),2) 118 | 119 | -------------------------------------------------------------------------------- /util/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import * 3 | import os 4 | 5 | 6 | class HTML: 7 | def __init__(self, web_dir, title, refresh=0): 8 | self.title = title 9 | self.web_dir = web_dir 10 | self.img_dir = os.path.join(self.web_dir, 'images') 11 | if not os.path.exists(self.web_dir): 12 | os.makedirs(self.web_dir) 13 | if not os.path.exists(self.img_dir): 14 | os.makedirs(self.img_dir) 15 | 16 | self.doc = dominate.document(title=title) 17 | if refresh > 0: 18 | with self.doc.head: 19 | meta(http_equiv="refresh", content=str(refresh)) 20 | 21 | def get_image_dir(self): 22 | return self.img_dir 23 | 24 | def add_header(self, str): 25 | with self.doc: 26 | h3(str) 27 | 28 | def add_table(self, border=1): 29 | self.t = table(border=border, style="table-layout: fixed;") 30 | self.doc.add(self.t) 31 | 32 | def add_images(self, ims, txts, links, width=512): 33 | self.add_table() 34 | with self.t: 35 | with tr(): 36 | for im, txt, link in zip(ims, txts, links): 37 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 38 | with p(): 39 | with a(href=os.path.join('images', link)): 40 | img(style="width:%dpx" % (width), src=os.path.join('images', im)) 41 | br() 42 | p(txt) 43 | 44 | def save(self): 45 | html_file = '%s/index.html' % self.web_dir 46 | f = open(html_file, 'wt') 47 | f.write(self.doc.render()) 48 | f.close() 49 | 50 | 51 | if __name__ == '__main__': 52 | html = HTML('web/', 'test_html') 53 | html.add_header('hello world') 54 | 55 | ims = [] 56 | txts = [] 57 | links = [] 58 | for n in range(4): 59 | ims.append('image_%d.jpg' % n) 60 | txts.append('text_%d' % n) 61 | links.append('image_%d.jpg' % n) 62 | html.add_images(ims, txts, links) 63 | html.save() 64 | -------------------------------------------------------------------------------- /util/image_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from torch.autograd import Variable 4 | class ImagePool(): 5 | def __init__(self, pool_size): 6 | self.pool_size = pool_size 7 | if self.pool_size > 0: 8 | self.num_imgs = 0 9 | self.images = [] 10 | 11 | def query(self, images): 12 | if self.pool_size == 0: 13 | return images 14 | return_images = [] 15 | for image in images.data: 16 | image = torch.unsqueeze(image, 0) 17 | if 
self.num_imgs < self.pool_size: 18 | self.num_imgs = self.num_imgs + 1 19 | self.images.append(image) 20 | return_images.append(image) 21 | else: 22 | p = random.uniform(0, 1) 23 | if p > 0.5: 24 | random_id = random.randint(0, self.pool_size-1) 25 | tmp = self.images[random_id].clone() 26 | self.images[random_id] = image 27 | return_images.append(tmp) 28 | else: 29 | return_images.append(image) 30 | return_images = Variable(torch.cat(return_images, 0)) 31 | return return_images 32 | -------------------------------------------------------------------------------- /util/json_config.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | def readConfig(path): 5 | with open(path,'r') as cf: 6 | nodelocaltionstr = cf.read() 7 | nodelocaltioninf = json.loads(nodelocaltionstr) 8 | if isinstance(nodelocaltioninf,str): 9 | nodelocaltioninf = json.loads(nodelocaltioninf) 10 | return nodelocaltioninf 11 | 12 | def writeConfig(path, info): 13 | with open(path, 'w') as cf: 14 | configjson = json.dumps(info, indent=4) 15 | cf.writelines(configjson) -------------------------------------------------------------------------------- /util/logo_class.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | ############################################################# 4 | # File: logo_class.py 5 | # Created Date: Tuesday June 29th 2021 6 | # Author: Chen Xuanhong 7 | # Email: chenxuanhongzju@outlook.com 8 | # Last Modified: Monday, 11th October 2021 12:39:55 am 9 | # Modified By: Chen Xuanhong 10 | # Copyright (c) 2021 Shanghai Jiao Tong University 11 | ############################################################# 12 | 13 | class logo_class: 14 | 15 | @staticmethod 16 | def print_group_logo(): 17 | logo_str = """ 18 | 19 | ███╗ ██╗██████╗ ███████╗██╗ ██████╗ ███████╗ ██╗████████╗██╗ ██╗ 20 | ████╗ ██║██╔══██╗██╔════╝██║██╔════╝ ██╔════╝ ██║╚══██╔══╝██║ ██║ 21 | ██╔██╗ ██║██████╔╝███████╗██║██║ ███╗ ███████╗ ██║ ██║ ██║ ██║ 22 | ██║╚██╗██║██╔══██╗╚════██║██║██║ ██║ ╚════██║██ ██║ ██║ ██║ ██║ 23 | ██║ ╚████║██║ ██║███████║██║╚██████╔╝ ███████║╚█████╔╝ ██║ ╚██████╔╝ 24 | ╚═╝ ╚═══╝╚═╝ ╚═╝╚══════╝╚═╝ ╚═════╝ ╚══════╝ ╚════╝ ╚═╝ ╚═════╝ 25 | Neural Rendering Special Interesting Group of SJTU 26 | 27 | """ 28 | print(logo_str) 29 | 30 | @staticmethod 31 | def print_start_training(): 32 | logo_str = """ 33 | _____ __ __ ______ _ _ 34 | / ___/ / /_ ____ _ _____ / /_ /_ __/_____ ____ _ (_)____ (_)____ ____ _ 35 | \__ \ / __// __ `// ___// __/ / / / ___// __ `// // __ \ / // __ \ / __ `/ 36 | ___/ // /_ / /_/ // / / /_ / / / / / /_/ // // / / // // / / // /_/ / 37 | /____/ \__/ \__,_//_/ \__/ /_/ /_/ \__,_//_//_/ /_//_//_/ /_/ \__, / 38 | /____/ 39 | """ 40 | print(logo_str) 41 | 42 | if __name__=="__main__": 43 | # logo_class.print_group_logo() 44 | logo_class.print_start_training() -------------------------------------------------------------------------------- /util/norm.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | import torch 4 | class SpecificNorm(nn.Module): 5 | def __init__(self, epsilon=1e-8): 6 | """ 7 | @notice: avoid in-place ops. 
8 | https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3 9 | """ 10 | super(SpecificNorm, self).__init__() 11 | self.mean = np.array([0.485, 0.456, 0.406]) 12 | self.mean = torch.from_numpy(self.mean).float().cuda() 13 | self.mean = self.mean.view([1, 3, 1, 1]) 14 | 15 | self.std = np.array([0.229, 0.224, 0.225]) 16 | self.std = torch.from_numpy(self.std).float().cuda() 17 | self.std = self.std.view([1, 3, 1, 1]) 18 | 19 | def forward(self, x): 20 | mean = self.mean.expand([1, 3, x.shape[2], x.shape[3]]) 21 | std = self.std.expand([1, 3, x.shape[2], x.shape[3]]) 22 | 23 | x = (x - mean) / std 24 | 25 | return x -------------------------------------------------------------------------------- /util/plot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import PIL 4 | 5 | def postprocess(x): 6 | """[0,1] to uint8.""" 7 | 8 | x = np.clip(255 * x, 0, 255) 9 | x = np.cast[np.uint8](x) 10 | return x 11 | 12 | def tile(X, rows, cols): 13 | """Tile images for display.""" 14 | tiling = np.zeros((rows * X.shape[1], cols * X.shape[2], X.shape[3]), dtype = X.dtype) 15 | for i in range(rows): 16 | for j in range(cols): 17 | idx = i * cols + j 18 | if idx < X.shape[0]: 19 | img = X[idx,...] 20 | tiling[ 21 | i*X.shape[1]:(i+1)*X.shape[1], 22 | j*X.shape[2]:(j+1)*X.shape[2], 23 | :] = img 24 | return tiling 25 | 26 | 27 | def plot_batch(X, out_path): 28 | """Save batch of images tiled.""" 29 | n_channels = X.shape[3] 30 | if n_channels > 3: 31 | X = X[:,:,:,np.random.choice(n_channels, size = 3)] 32 | X = postprocess(X) 33 | rc = math.sqrt(X.shape[0]) 34 | rows = cols = math.ceil(rc) 35 | canvas = tile(X, rows, cols) 36 | canvas = np.squeeze(canvas) 37 | PIL.Image.fromarray(canvas).save(out_path) -------------------------------------------------------------------------------- /util/reverse2original.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | # import time 4 | import torch 5 | from torch.nn import functional as F 6 | import torch.nn as nn 7 | 8 | 9 | def encode_segmentation_rgb(segmentation, no_neck=True): 10 | parse = segmentation 11 | 12 | face_part_ids = [1, 2, 3, 4, 5, 6, 10, 12, 13] if no_neck else [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 13, 14] 13 | mouth_id = 11 14 | # hair_id = 17 15 | face_map = np.zeros([parse.shape[0], parse.shape[1]]) 16 | mouth_map = np.zeros([parse.shape[0], parse.shape[1]]) 17 | # hair_map = np.zeros([parse.shape[0], parse.shape[1]]) 18 | 19 | for valid_id in face_part_ids: 20 | valid_index = np.where(parse==valid_id) 21 | face_map[valid_index] = 255 22 | valid_index = np.where(parse==mouth_id) 23 | mouth_map[valid_index] = 255 24 | # valid_index = np.where(parse==hair_id) 25 | # hair_map[valid_index] = 255 26 | #return np.stack([face_map, mouth_map,hair_map], axis=2) 27 | return np.stack([face_map, mouth_map], axis=2) 28 | 29 | 30 | class SoftErosion(nn.Module): 31 | def __init__(self, kernel_size=15, threshold=0.6, iterations=1): 32 | super(SoftErosion, self).__init__() 33 | r = kernel_size // 2 34 | self.padding = r 35 | self.iterations = iterations 36 | self.threshold = threshold 37 | 38 | # Create kernel 39 | y_indices, x_indices = torch.meshgrid(torch.arange(0., kernel_size), torch.arange(0., kernel_size)) 40 | dist = torch.sqrt((x_indices - r) ** 2 + (y_indices - r) ** 2) 41 | kernel = dist.max() - dist 42 | 
kernel /= kernel.sum() 43 | kernel = kernel.view(1, 1, *kernel.shape) 44 | self.register_buffer('weight', kernel) 45 | 46 | def forward(self, x): 47 | x = x.float() 48 | for i in range(self.iterations - 1): 49 | x = torch.min(x, F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding)) 50 | x = F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding) 51 | 52 | mask = x >= self.threshold 53 | x[mask] = 1.0 54 | x[~mask] /= x[~mask].max() 55 | 56 | return x, mask 57 | 58 | 59 | def postprocess(swapped_face, target, target_mask,smooth_mask): 60 | # target_mask = cv2.resize(target_mask, (self.size, self.size)) 61 | 62 | mask_tensor = torch.from_numpy(target_mask.copy().transpose((2, 0, 1))).float().mul_(1/255.0).cuda() 63 | face_mask_tensor = mask_tensor[0] + mask_tensor[1] 64 | 65 | soft_face_mask_tensor, _ = smooth_mask(face_mask_tensor.unsqueeze_(0).unsqueeze_(0)) 66 | soft_face_mask_tensor.squeeze_() 67 | 68 | soft_face_mask = soft_face_mask_tensor.cpu().numpy() 69 | soft_face_mask = soft_face_mask[:, :, np.newaxis] 70 | 71 | result = swapped_face * soft_face_mask + target * (1 - soft_face_mask) 72 | result = result[:,:,::-1]# .astype(np.uint8) 73 | return result 74 | 75 | def reverse2wholeimage(b_align_crop_tenor_list,swaped_imgs, mats, crop_size, oriimg, logoclass, save_path = '', \ 76 | no_simswaplogo = False,pasring_model =None,norm = None, use_mask = False): 77 | 78 | target_image_list = [] 79 | img_mask_list = [] 80 | if use_mask: 81 | smooth_mask = SoftErosion(kernel_size=17, threshold=0.9, iterations=7).cuda() 82 | else: 83 | pass 84 | 85 | # print(len(swaped_imgs)) 86 | # print(mats) 87 | # print(len(b_align_crop_tenor_list)) 88 | for swaped_img, mat ,source_img in zip(swaped_imgs, mats,b_align_crop_tenor_list): 89 | swaped_img = swaped_img.cpu().detach().numpy().transpose((1, 2, 0)) 90 | img_white = np.full((crop_size,crop_size), 255, dtype=float) 91 | 92 | # inverse the Affine transformation matrix 93 | mat_rev = np.zeros([2,3]) 94 | div1 = mat[0][0]*mat[1][1]-mat[0][1]*mat[1][0] 95 | mat_rev[0][0] = mat[1][1]/div1 96 | mat_rev[0][1] = -mat[0][1]/div1 97 | mat_rev[0][2] = -(mat[0][2]*mat[1][1]-mat[0][1]*mat[1][2])/div1 98 | div2 = mat[0][1]*mat[1][0]-mat[0][0]*mat[1][1] 99 | mat_rev[1][0] = mat[1][0]/div2 100 | mat_rev[1][1] = -mat[0][0]/div2 101 | mat_rev[1][2] = -(mat[0][2]*mat[1][0]-mat[0][0]*mat[1][2])/div2 102 | 103 | orisize = (oriimg.shape[1], oriimg.shape[0]) 104 | if use_mask: 105 | source_img_norm = norm(source_img) 106 | source_img_512 = F.interpolate(source_img_norm,size=(512,512)) 107 | out = pasring_model(source_img_512)[0] 108 | parsing = out.squeeze(0).detach().cpu().numpy().argmax(0) 109 | vis_parsing_anno = parsing.copy().astype(np.uint8) 110 | tgt_mask = encode_segmentation_rgb(vis_parsing_anno) 111 | if tgt_mask.sum() >= 5000: 112 | # face_mask_tensor = tgt_mask[...,0] + tgt_mask[...,1] 113 | target_mask = cv2.resize(tgt_mask, (crop_size, crop_size)) 114 | # print(source_img) 115 | target_image_parsing = postprocess(swaped_img, source_img[0].cpu().detach().numpy().transpose((1, 2, 0)), target_mask,smooth_mask) 116 | 117 | 118 | target_image = cv2.warpAffine(target_image_parsing, mat_rev, orisize) 119 | # target_image_parsing = cv2.warpAffine(swaped_img, mat_rev, orisize) 120 | else: 121 | target_image = cv2.warpAffine(swaped_img, mat_rev, orisize)[..., ::-1] 122 | else: 123 | target_image = cv2.warpAffine(swaped_img, mat_rev, orisize) 124 | # source_image = cv2.warpAffine(source_img, mat_rev, orisize) 125 | 126 | img_white = 
cv2.warpAffine(img_white, mat_rev, orisize) 127 | 128 | 129 | img_white[img_white>20] =255 130 | 131 | img_mask = img_white 132 | 133 | # if use_mask: 134 | # kernel = np.ones((40,40),np.uint8) 135 | # img_mask = cv2.erode(img_mask,kernel,iterations = 1) 136 | # else: 137 | kernel = np.ones((40,40),np.uint8) 138 | img_mask = cv2.erode(img_mask,kernel,iterations = 1) 139 | kernel_size = (20, 20) 140 | blur_size = tuple(2*i+1 for i in kernel_size) 141 | img_mask = cv2.GaussianBlur(img_mask, blur_size, 0) 142 | 143 | # kernel = np.ones((10,10),np.uint8) 144 | # img_mask = cv2.erode(img_mask,kernel,iterations = 1) 145 | 146 | 147 | 148 | img_mask /= 255 149 | 150 | img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1]) 151 | 152 | # parsing mask 153 | 154 | # target_image_parsing = postprocess(target_image, source_image, tgt_mask) 155 | 156 | if use_mask: 157 | target_image = np.array(target_image, dtype=float) * 255 158 | else: 159 | target_image = np.array(target_image, dtype=float)[..., ::-1] * 255 160 | 161 | 162 | img_mask_list.append(img_mask) 163 | target_image_list.append(target_image) 164 | 165 | 166 | # target_image /= 255 167 | # target_image = 0 168 | img = np.array(oriimg, dtype=float) 169 | for img_mask, target_image in zip(img_mask_list, target_image_list): 170 | img = img_mask * target_image + (1-img_mask) * img 171 | 172 | final_img = img.astype(np.uint8) 173 | if not no_simswaplogo: 174 | final_img = logoclass.apply_frames(final_img) 175 | cv2.imwrite(save_path, final_img) 176 | -------------------------------------------------------------------------------- /util/save_heatmap.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | ############################################################# 4 | # File: save_heatmap.py 5 | # Created Date: Friday January 15th 2021 6 | # Author: Chen Xuanhong 7 | # Email: chenxuanhongzju@outlook.com 8 | # Last Modified: Wednesday, 19th January 2022 1:22:47 am 9 | # Modified By: Chen Xuanhong 10 | # Copyright (c) 2021 Shanghai Jiao Tong University 11 | ############################################################# 12 | 13 | import os 14 | import shutil 15 | import seaborn as sns 16 | import matplotlib.pyplot as plt 17 | import cv2 18 | import numpy as np 19 | 20 | def SaveHeatmap(heatmaps, path, row=-1, dpi=72): 21 | """ 22 | The input tensor must be B X 1 X H X W 23 | """ 24 | batch_size = heatmaps.shape[0] 25 | temp_path = ".temp/" 26 | if not os.path.exists(temp_path): 27 | os.makedirs(temp_path) 28 | final_img = None 29 | if row < 1: 30 | col = batch_size 31 | row = 1 32 | else: 33 | col = batch_size // row 34 | if row * col = col: 51 | col_i = 0 52 | row_i += 1 53 | cv2.imwrite(path,final_img) 54 | 55 | if __name__ == "__main__": 56 | random_map = np.random.randn(16,1,10,10) 57 | SaveHeatmap(random_map,"./wocao.png",1) 58 | -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | import numpy as np 6 | import os 7 | 8 | # Converts a Tensor into a Numpy array 9 | # |imtype|: the desired type of the converted numpy array 10 | def tensor2im(image_tensor, imtype=np.uint8, normalize=True): 11 | if isinstance(image_tensor, list): 12 | image_numpy = [] 13 | for i in range(len(image_tensor)): 14 | 
image_numpy.append(tensor2im(image_tensor[i], imtype, normalize)) 15 | return image_numpy 16 | image_numpy = image_tensor.cpu().float().numpy() 17 | if normalize: 18 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 19 | else: 20 | image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 21 | image_numpy = np.clip(image_numpy, 0, 255) 22 | if image_numpy.shape[2] == 1 or image_numpy.shape[2] > 3: 23 | image_numpy = image_numpy[:,:,0] 24 | return image_numpy.astype(imtype) 25 | 26 | # Converts a one-hot tensor into a colorful label map 27 | def tensor2label(label_tensor, n_label, imtype=np.uint8): 28 | if n_label == 0: 29 | return tensor2im(label_tensor, imtype) 30 | label_tensor = label_tensor.cpu().float() 31 | if label_tensor.size()[0] > 1: 32 | label_tensor = label_tensor.max(0, keepdim=True)[1] 33 | label_tensor = Colorize(n_label)(label_tensor) 34 | label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0)) 35 | return label_numpy.astype(imtype) 36 | 37 | def save_image(image_numpy, image_path): 38 | image_pil = Image.fromarray(image_numpy) 39 | image_pil.save(image_path) 40 | 41 | def mkdirs(paths): 42 | if isinstance(paths, list) and not isinstance(paths, str): 43 | for path in paths: 44 | mkdir(path) 45 | else: 46 | mkdir(paths) 47 | 48 | def mkdir(path): 49 | if not os.path.exists(path): 50 | os.makedirs(path) 51 | 52 | ############################################################################### 53 | # Code from 54 | # https://github.com/ycszen/pytorch-seg/blob/master/transform.py 55 | # Modified so it complies with the Citscape label map colors 56 | ############################################################################### 57 | def uint82bin(n, count=8): 58 | """returns the binary of integer n, count refers to amount of bits""" 59 | return ''.join([str((n >> y) & 1) for y in range(count-1, -1, -1)]) 60 | 61 | def labelcolormap(N): 62 | if N == 35: # cityscape 63 | cmap = np.array([( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), (111, 74, 0), ( 81, 0, 81), 64 | (128, 64,128), (244, 35,232), (250,170,160), (230,150,140), ( 70, 70, 70), (102,102,156), (190,153,153), 65 | (180,165,180), (150,100,100), (150,120, 90), (153,153,153), (153,153,153), (250,170, 30), (220,220, 0), 66 | (107,142, 35), (152,251,152), ( 70,130,180), (220, 20, 60), (255, 0, 0), ( 0, 0,142), ( 0, 0, 70), 67 | ( 0, 60,100), ( 0, 0, 90), ( 0, 0,110), ( 0, 80,100), ( 0, 0,230), (119, 11, 32), ( 0, 0,142)], 68 | dtype=np.uint8) 69 | else: 70 | cmap = np.zeros((N, 3), dtype=np.uint8) 71 | for i in range(N): 72 | r, g, b = 0, 0, 0 73 | id = i 74 | for j in range(7): 75 | str_id = uint82bin(id) 76 | r = r ^ (np.uint8(str_id[-1]) << (7-j)) 77 | g = g ^ (np.uint8(str_id[-2]) << (7-j)) 78 | b = b ^ (np.uint8(str_id[-3]) << (7-j)) 79 | id = id >> 3 80 | cmap[i, 0] = r 81 | cmap[i, 1] = g 82 | cmap[i, 2] = b 83 | return cmap 84 | 85 | class Colorize(object): 86 | def __init__(self, n=35): 87 | self.cmap = labelcolormap(n) 88 | self.cmap = torch.from_numpy(self.cmap[:n]) 89 | 90 | def __call__(self, gray_image): 91 | size = gray_image.size() 92 | color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0) 93 | 94 | for label in range(0, len(self.cmap)): 95 | mask = (label == gray_image[0]).cpu() 96 | color_image[0][mask] = self.cmap[label][0] 97 | color_image[1][mask] = self.cmap[label][1] 98 | color_image[2][mask] = self.cmap[label][2] 99 | 100 | return color_image 101 | -------------------------------------------------------------------------------- /util/videoswap.py: 
-------------------------------------------------------------------------------- 1 | ''' 2 | Author: Naiyuan liu 3 | Github: https://github.com/NNNNAI 4 | Date: 2021-11-23 17:03:58 5 | LastEditors: Naiyuan liu 6 | LastEditTime: 2021-11-24 19:19:52 7 | Description: 8 | ''' 9 | import os 10 | import cv2 11 | import glob 12 | import torch 13 | import shutil 14 | import numpy as np 15 | from tqdm import tqdm 16 | from util.reverse2original import reverse2wholeimage 17 | import moviepy.editor as mp 18 | from moviepy.editor import AudioFileClip, VideoFileClip 19 | from moviepy.video.io.ImageSequenceClip import ImageSequenceClip 20 | import time 21 | from util.add_watermark import watermark_image 22 | from util.norm import SpecificNorm 23 | from parsing_model.model import BiSeNet 24 | 25 | def _totensor(array): 26 | tensor = torch.from_numpy(array) 27 | img = tensor.transpose(0, 1).transpose(0, 2).contiguous() 28 | return img.float().div(255) 29 | 30 | def video_swap(video_path, id_vetor, swap_model, detect_model, save_path, temp_results_dir='./temp_results', crop_size=224, no_simswaplogo = False,use_mask =False): 31 | video_forcheck = VideoFileClip(video_path) 32 | if video_forcheck.audio is None: 33 | no_audio = True 34 | else: 35 | no_audio = False 36 | 37 | del video_forcheck 38 | 39 | if not no_audio: 40 | video_audio_clip = AudioFileClip(video_path) 41 | 42 | video = cv2.VideoCapture(video_path) 43 | logoclass = watermark_image('./simswaplogo/simswaplogo.png') 44 | ret = True 45 | frame_index = 0 46 | 47 | frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) 48 | 49 | # video_WIDTH = int(video.get(cv2.CAP_PROP_FRAME_WIDTH)) 50 | 51 | # video_HEIGHT = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)) 52 | 53 | fps = video.get(cv2.CAP_PROP_FPS) 54 | if os.path.exists(temp_results_dir): 55 | shutil.rmtree(temp_results_dir) 56 | 57 | spNorm =SpecificNorm() 58 | if use_mask: 59 | n_classes = 19 60 | net = BiSeNet(n_classes=n_classes) 61 | net.cuda() 62 | save_pth = os.path.join('./parsing_model/checkpoint', '79999_iter.pth') 63 | net.load_state_dict(torch.load(save_pth)) 64 | net.eval() 65 | else: 66 | net =None 67 | 68 | # while ret: 69 | for frame_index in tqdm(range(frame_count)): 70 | ret, frame = video.read() 71 | if ret: 72 | detect_results = detect_model.get(frame,crop_size) 73 | 74 | if detect_results is not None: 75 | # print(frame_index) 76 | if not os.path.exists(temp_results_dir): 77 | os.mkdir(temp_results_dir) 78 | frame_align_crop_list = detect_results[0] 79 | frame_mat_list = detect_results[1] 80 | swap_result_list = [] 81 | frame_align_crop_tenor_list = [] 82 | for frame_align_crop in frame_align_crop_list: 83 | 84 | # BGR TO RGB 85 | # frame_align_crop_RGB = frame_align_crop[...,::-1] 86 | 87 | frame_align_crop_tenor = _totensor(cv2.cvtColor(frame_align_crop,cv2.COLOR_BGR2RGB))[None,...].cuda() 88 | 89 | swap_result = swap_model(None, frame_align_crop_tenor, id_vetor, None, True)[0] 90 | cv2.imwrite(os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)), frame) 91 | swap_result_list.append(swap_result) 92 | frame_align_crop_tenor_list.append(frame_align_crop_tenor) 93 | 94 | 95 | 96 | reverse2wholeimage(frame_align_crop_tenor_list,swap_result_list, frame_mat_list, crop_size, frame, logoclass,\ 97 | os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)),no_simswaplogo,pasring_model =net,use_mask=use_mask, norm = spNorm) 98 | 99 | else: 100 | if not os.path.exists(temp_results_dir): 101 | os.mkdir(temp_results_dir) 102 | frame = frame.astype(np.uint8) 
103 | if not no_simswaplogo: 104 | frame = logoclass.apply_frames(frame) 105 | cv2.imwrite(os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)), frame) 106 | else: 107 | break 108 | 109 | video.release() 110 | 111 | # image_filename_list = [] 112 | path = os.path.join(temp_results_dir,'*.jpg') 113 | image_filenames = sorted(glob.glob(path)) 114 | 115 | clips = ImageSequenceClip(image_filenames,fps = fps) 116 | 117 | if not no_audio: 118 | clips = clips.set_audio(video_audio_clip) 119 | 120 | 121 | clips.write_videofile(save_path,audio_codec='aac') 122 | 123 | -------------------------------------------------------------------------------- /util/videoswap_multispecific.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import glob 4 | import torch 5 | import shutil 6 | import numpy as np 7 | from tqdm import tqdm 8 | from util.reverse2original import reverse2wholeimage 9 | import moviepy.editor as mp 10 | from moviepy.editor import AudioFileClip, VideoFileClip 11 | from moviepy.video.io.ImageSequenceClip import ImageSequenceClip 12 | import time 13 | from util.add_watermark import watermark_image 14 | from util.norm import SpecificNorm 15 | import torch.nn.functional as F 16 | from parsing_model.model import BiSeNet 17 | 18 | def _totensor(array): 19 | tensor = torch.from_numpy(array) 20 | img = tensor.transpose(0, 1).transpose(0, 2).contiguous() 21 | return img.float().div(255) 22 | 23 | def video_swap(video_path, target_id_norm_list,source_specific_id_nonorm_list,id_thres, swap_model, detect_model, save_path, temp_results_dir='./temp_results', crop_size=224, no_simswaplogo = False,use_mask =False): 24 | video_forcheck = VideoFileClip(video_path) 25 | if video_forcheck.audio is None: 26 | no_audio = True 27 | else: 28 | no_audio = False 29 | 30 | del video_forcheck 31 | 32 | if not no_audio: 33 | video_audio_clip = AudioFileClip(video_path) 34 | 35 | video = cv2.VideoCapture(video_path) 36 | logoclass = watermark_image('./simswaplogo/simswaplogo.png') 37 | ret = True 38 | frame_index = 0 39 | 40 | frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) 41 | 42 | # video_WIDTH = int(video.get(cv2.CAP_PROP_FRAME_WIDTH)) 43 | 44 | # video_HEIGHT = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)) 45 | 46 | fps = video.get(cv2.CAP_PROP_FPS) 47 | if os.path.exists(temp_results_dir): 48 | shutil.rmtree(temp_results_dir) 49 | 50 | spNorm =SpecificNorm() 51 | mse = torch.nn.MSELoss().cuda() 52 | 53 | if use_mask: 54 | n_classes = 19 55 | net = BiSeNet(n_classes=n_classes) 56 | net.cuda() 57 | save_pth = os.path.join('./parsing_model/checkpoint', '79999_iter.pth') 58 | net.load_state_dict(torch.load(save_pth)) 59 | net.eval() 60 | else: 61 | net =None 62 | 63 | # while ret: 64 | for frame_index in tqdm(range(frame_count)): 65 | ret, frame = video.read() 66 | if ret: 67 | detect_results = detect_model.get(frame,crop_size) 68 | 69 | if detect_results is not None: 70 | # print(frame_index) 71 | if not os.path.exists(temp_results_dir): 72 | os.mkdir(temp_results_dir) 73 | frame_align_crop_list = detect_results[0] 74 | frame_mat_list = detect_results[1] 75 | 76 | id_compare_values = [] 77 | frame_align_crop_tenor_list = [] 78 | for frame_align_crop in frame_align_crop_list: 79 | 80 | # BGR TO RGB 81 | # frame_align_crop_RGB = frame_align_crop[...,::-1] 82 | 83 | frame_align_crop_tenor = _totensor(cv2.cvtColor(frame_align_crop,cv2.COLOR_BGR2RGB))[None,...].cuda() 84 | 85 | frame_align_crop_tenor_arcnorm = 
spNorm(frame_align_crop_tenor) 86 | frame_align_crop_tenor_arcnorm_downsample = F.interpolate(frame_align_crop_tenor_arcnorm, size=(112,112)) 87 | frame_align_crop_crop_id_nonorm = swap_model.netArc(frame_align_crop_tenor_arcnorm_downsample) 88 | id_compare_values.append([]) 89 | for source_specific_id_nonorm_tmp in source_specific_id_nonorm_list: 90 | id_compare_values[-1].append(mse(frame_align_crop_crop_id_nonorm,source_specific_id_nonorm_tmp).detach().cpu().numpy()) 91 | frame_align_crop_tenor_list.append(frame_align_crop_tenor) 92 | 93 | id_compare_values_array = np.array(id_compare_values).transpose(1,0) 94 | min_indexs = np.argmin(id_compare_values_array,axis=0) 95 | min_value = np.min(id_compare_values_array,axis=0) 96 | 97 | swap_result_list = [] 98 | swap_result_matrix_list = [] 99 | swap_result_ori_pic_list = [] 100 | for tmp_index, min_index in enumerate(min_indexs): 101 | if min_value[tmp_index] < id_thres: 102 | swap_result = swap_model(None, frame_align_crop_tenor_list[tmp_index], target_id_norm_list[min_index], None, True)[0] 103 | swap_result_list.append(swap_result) 104 | swap_result_matrix_list.append(frame_mat_list[tmp_index]) 105 | swap_result_ori_pic_list.append(frame_align_crop_tenor_list[tmp_index]) 106 | else: 107 | pass 108 | 109 | 110 | 111 | if len(swap_result_list) !=0: 112 | 113 | reverse2wholeimage(swap_result_ori_pic_list,swap_result_list, swap_result_matrix_list, crop_size, frame, logoclass,\ 114 | os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)),no_simswaplogo,pasring_model =net,use_mask=use_mask, norm = spNorm) 115 | else: 116 | if not os.path.exists(temp_results_dir): 117 | os.mkdir(temp_results_dir) 118 | frame = frame.astype(np.uint8) 119 | if not no_simswaplogo: 120 | frame = logoclass.apply_frames(frame) 121 | cv2.imwrite(os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)), frame) 122 | 123 | else: 124 | if not os.path.exists(temp_results_dir): 125 | os.mkdir(temp_results_dir) 126 | frame = frame.astype(np.uint8) 127 | if not no_simswaplogo: 128 | frame = logoclass.apply_frames(frame) 129 | cv2.imwrite(os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)), frame) 130 | else: 131 | break 132 | 133 | video.release() 134 | 135 | # image_filename_list = [] 136 | path = os.path.join(temp_results_dir,'*.jpg') 137 | image_filenames = sorted(glob.glob(path)) 138 | 139 | clips = ImageSequenceClip(image_filenames,fps = fps) 140 | 141 | if not no_audio: 142 | clips = clips.set_audio(video_audio_clip) 143 | 144 | 145 | clips.write_videofile(save_path,audio_codec='aac') 146 | 147 | -------------------------------------------------------------------------------- /util/videoswap_specific.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import glob 4 | import torch 5 | import shutil 6 | import numpy as np 7 | from tqdm import tqdm 8 | from util.reverse2original import reverse2wholeimage 9 | import moviepy.editor as mp 10 | from moviepy.editor import AudioFileClip, VideoFileClip 11 | from moviepy.video.io.ImageSequenceClip import ImageSequenceClip 12 | import time 13 | from util.add_watermark import watermark_image 14 | from util.norm import SpecificNorm 15 | import torch.nn.functional as F 16 | from parsing_model.model import BiSeNet 17 | 18 | def _totensor(array): 19 | tensor = torch.from_numpy(array) 20 | img = tensor.transpose(0, 1).transpose(0, 2).contiguous() 21 | return img.float().div(255) 22 | 23 | def video_swap(video_path, 
id_vetor,specific_person_id_nonorm,id_thres, swap_model, detect_model, save_path, temp_results_dir='./temp_results', crop_size=224, no_simswaplogo = False,use_mask =False): 24 | video_forcheck = VideoFileClip(video_path) 25 | if video_forcheck.audio is None: 26 | no_audio = True 27 | else: 28 | no_audio = False 29 | 30 | del video_forcheck 31 | 32 | if not no_audio: 33 | video_audio_clip = AudioFileClip(video_path) 34 | 35 | video = cv2.VideoCapture(video_path) 36 | logoclass = watermark_image('./simswaplogo/simswaplogo.png') 37 | ret = True 38 | frame_index = 0 39 | 40 | frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) 41 | 42 | # video_WIDTH = int(video.get(cv2.CAP_PROP_FRAME_WIDTH)) 43 | 44 | # video_HEIGHT = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)) 45 | 46 | fps = video.get(cv2.CAP_PROP_FPS) 47 | if os.path.exists(temp_results_dir): 48 | shutil.rmtree(temp_results_dir) 49 | 50 | spNorm =SpecificNorm() 51 | mse = torch.nn.MSELoss().cuda() 52 | 53 | if use_mask: 54 | n_classes = 19 55 | net = BiSeNet(n_classes=n_classes) 56 | net.cuda() 57 | save_pth = os.path.join('./parsing_model/checkpoint', '79999_iter.pth') 58 | net.load_state_dict(torch.load(save_pth)) 59 | net.eval() 60 | else: 61 | net =None 62 | 63 | # while ret: 64 | for frame_index in tqdm(range(frame_count)): 65 | ret, frame = video.read() 66 | if ret: 67 | detect_results = detect_model.get(frame,crop_size) 68 | 69 | if detect_results is not None: 70 | # print(frame_index) 71 | if not os.path.exists(temp_results_dir): 72 | os.mkdir(temp_results_dir) 73 | frame_align_crop_list = detect_results[0] 74 | frame_mat_list = detect_results[1] 75 | 76 | id_compare_values = [] 77 | frame_align_crop_tenor_list = [] 78 | for frame_align_crop in frame_align_crop_list: 79 | 80 | # BGR TO RGB 81 | # frame_align_crop_RGB = frame_align_crop[...,::-1] 82 | 83 | frame_align_crop_tenor = _totensor(cv2.cvtColor(frame_align_crop,cv2.COLOR_BGR2RGB))[None,...].cuda() 84 | 85 | frame_align_crop_tenor_arcnorm = spNorm(frame_align_crop_tenor) 86 | frame_align_crop_tenor_arcnorm_downsample = F.interpolate(frame_align_crop_tenor_arcnorm, size=(112,112)) 87 | frame_align_crop_crop_id_nonorm = swap_model.netArc(frame_align_crop_tenor_arcnorm_downsample) 88 | 89 | id_compare_values.append(mse(frame_align_crop_crop_id_nonorm,specific_person_id_nonorm).detach().cpu().numpy()) 90 | frame_align_crop_tenor_list.append(frame_align_crop_tenor) 91 | id_compare_values_array = np.array(id_compare_values) 92 | min_index = np.argmin(id_compare_values_array) 93 | min_value = id_compare_values_array[min_index] 94 | if min_value < id_thres: 95 | swap_result = swap_model(None, frame_align_crop_tenor_list[min_index], id_vetor, None, True)[0] 96 | 97 | reverse2wholeimage([frame_align_crop_tenor_list[min_index]], [swap_result], [frame_mat_list[min_index]], crop_size, frame, logoclass,\ 98 | os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)),no_simswaplogo,pasring_model =net,use_mask= use_mask, norm = spNorm) 99 | else: 100 | if not os.path.exists(temp_results_dir): 101 | os.mkdir(temp_results_dir) 102 | frame = frame.astype(np.uint8) 103 | if not no_simswaplogo: 104 | frame = logoclass.apply_frames(frame) 105 | cv2.imwrite(os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)), frame) 106 | 107 | else: 108 | if not os.path.exists(temp_results_dir): 109 | os.mkdir(temp_results_dir) 110 | frame = frame.astype(np.uint8) 111 | if not no_simswaplogo: 112 | frame = logoclass.apply_frames(frame) 113 | 
cv2.imwrite(os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)), frame) 114 | else: 115 | break 116 | 117 | video.release() 118 | 119 | # image_filename_list = [] 120 | path = os.path.join(temp_results_dir,'*.jpg') 121 | image_filenames = sorted(glob.glob(path)) 122 | 123 | clips = ImageSequenceClip(image_filenames,fps = fps) 124 | 125 | if not no_audio: 126 | clips = clips.set_audio(video_audio_clip) 127 | 128 | 129 | clips.write_videofile(save_path,audio_codec='aac') 130 | 131 | -------------------------------------------------------------------------------- /util/visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import ntpath 4 | import time 5 | from . import util 6 | from . import html 7 | import scipy.misc 8 | try: 9 | from StringIO import StringIO # Python 2.7 10 | except ImportError: 11 | from io import BytesIO # Python 3.x 12 | 13 | class Visualizer(): 14 | def __init__(self, opt): 15 | # self.opt = opt 16 | self.tf_log = opt.tf_log 17 | self.use_html = opt.isTrain and not opt.no_html 18 | self.win_size = opt.display_winsize 19 | self.name = opt.name 20 | if self.tf_log: 21 | import tensorflow as tf 22 | self.tf = tf 23 | self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, 'logs') 24 | self.writer = tf.summary.FileWriter(self.log_dir) 25 | 26 | if self.use_html: 27 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 28 | self.img_dir = os.path.join(self.web_dir, 'images') 29 | print('create web directory %s...' % self.web_dir) 30 | util.mkdirs([self.web_dir, self.img_dir]) 31 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 32 | with open(self.log_name, "a") as log_file: 33 | now = time.strftime("%c") 34 | log_file.write('================ Training Loss (%s) ================\n' % now) 35 | 36 | # |visuals|: dictionary of images to display or save 37 | def display_current_results(self, visuals, epoch, step): 38 | if self.tf_log: # show images in tensorboard output 39 | img_summaries = [] 40 | for label, image_numpy in visuals.items(): 41 | # Write the image to a string 42 | try: 43 | s = StringIO() 44 | except: 45 | s = BytesIO() 46 | scipy.misc.toimage(image_numpy).save(s, format="jpeg") 47 | # Create an Image object 48 | img_sum = self.tf.Summary.Image(encoded_image_string=s.getvalue(), height=image_numpy.shape[0], width=image_numpy.shape[1]) 49 | # Create a Summary value 50 | img_summaries.append(self.tf.Summary.Value(tag=label, image=img_sum)) 51 | 52 | # Create and write Summary 53 | summary = self.tf.Summary(value=img_summaries) 54 | self.writer.add_summary(summary, step) 55 | 56 | if self.use_html: # save images to a html file 57 | for label, image_numpy in visuals.items(): 58 | if isinstance(image_numpy, list): 59 | for i in range(len(image_numpy)): 60 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s_%d.jpg' % (epoch, label, i)) 61 | util.save_image(image_numpy[i], img_path) 62 | else: 63 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.jpg' % (epoch, label)) 64 | util.save_image(image_numpy, img_path) 65 | 66 | # update website 67 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=30) 68 | for n in range(epoch, 0, -1): 69 | webpage.add_header('epoch [%d]' % n) 70 | ims = [] 71 | txts = [] 72 | links = [] 73 | 74 | for label, image_numpy in visuals.items(): 75 | if isinstance(image_numpy, list): 76 | for i in range(len(image_numpy)): 77 | img_path = 'epoch%.3d_%s_%d.jpg' % (n, 
label, i) 78 | ims.append(img_path) 79 | txts.append(label+str(i)) 80 | links.append(img_path) 81 | else: 82 | img_path = 'epoch%.3d_%s.jpg' % (n, label) 83 | ims.append(img_path) 84 | txts.append(label) 85 | links.append(img_path) 86 | if len(ims) < 10: 87 | webpage.add_images(ims, txts, links, width=self.win_size) 88 | else: 89 | num = int(round(len(ims)/2.0)) 90 | webpage.add_images(ims[:num], txts[:num], links[:num], width=self.win_size) 91 | webpage.add_images(ims[num:], txts[num:], links[num:], width=self.win_size) 92 | webpage.save() 93 | 94 | # errors: dictionary of error labels and values 95 | def plot_current_errors(self, errors, step): 96 | if self.tf_log: 97 | for tag, value in errors.items(): 98 | summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)]) 99 | self.writer.add_summary(summary, step) 100 | 101 | # errors: same format as |errors| of plotCurrentErrors 102 | def print_current_errors(self, epoch, i, errors, t): 103 | message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t) 104 | for k, v in errors.items(): 105 | if v != 0: 106 | message += '%s: %.3f ' % (k, v) 107 | 108 | print(message) 109 | with open(self.log_name, "a") as log_file: 110 | log_file.write('%s\n' % message) 111 | 112 | # save image to the disk 113 | def save_images(self, webpage, visuals, image_path): 114 | image_dir = webpage.get_image_dir() 115 | short_path = ntpath.basename(image_path[0]) 116 | name = os.path.splitext(short_path)[0] 117 | 118 | webpage.add_header(name) 119 | ims = [] 120 | txts = [] 121 | links = [] 122 | 123 | for label, image_numpy in visuals.items(): 124 | image_name = '%s_%s.jpg' % (name, label) 125 | save_path = os.path.join(image_dir, image_name) 126 | util.save_image(image_numpy, save_path) 127 | 128 | ims.append(image_name) 129 | txts.append(label) 130 | links.append(image_name) 131 | webpage.add_images(ims, txts, links, width=self.win_size) 132 | --------------------------------------------------------------------------------
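The video-swap utilities above all repeat the same identity-extraction recipe before comparing or injecting faces: normalize an aligned RGB crop with SpecificNorm (ImageNet mean/std), downsample it to 112x112, and pass it through the ArcFace branch of the swap model. The sketch below isolates that step with dummy data; `swap_model` is a hypothetical, already-loaded SimSwap model, so the last two lines are left as comments.

```python
import torch
import torch.nn.functional as F
from util.norm import SpecificNorm

# Dummy aligned crop: RGB in [0, 1], shape (1, 3, 224, 224), on the GPU
# (SpecificNorm keeps its mean/std buffers on CUDA, so CPU-only use needs edits).
crop = torch.rand(1, 3, 224, 224).cuda()

sp_norm = SpecificNorm()
crop_norm = sp_norm(crop)                              # ImageNet mean/std normalization
crop_112 = F.interpolate(crop_norm, size=(112, 112))   # ArcFace expects 112x112 input

# identity = swap_model.netArc(crop_112)               # hypothetical: needs a loaded model
# identity = F.normalize(identity, p=2, dim=1)         # unit-norm latent used as the target id
```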
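reverse2wholeimage() inverts every 2x3 alignment matrix by hand before warping the swapped crop back onto the original frame. That closed form is the standard affine inverse [A | t]^-1 = [A^-1 | -A^-1 t] and agrees with OpenCV's cv2.invertAffineTransform; the stand-alone check below (illustrative only, not part of the repo) confirms the two match on an arbitrary similarity transform.

```python
import cv2
import numpy as np

# Any 2x3 affine will do; a rotation+scale about (112, 112) mimics a face-alignment matrix.
mat = cv2.getRotationMatrix2D((112, 112), 30, 0.9)

# Hand-rolled inverse, exactly as in util/reverse2original.py.
mat_rev = np.zeros([2, 3])
div1 = mat[0][0] * mat[1][1] - mat[0][1] * mat[1][0]
mat_rev[0][0] = mat[1][1] / div1
mat_rev[0][1] = -mat[0][1] / div1
mat_rev[0][2] = -(mat[0][2] * mat[1][1] - mat[0][1] * mat[1][2]) / div1
div2 = mat[0][1] * mat[1][0] - mat[0][0] * mat[1][1]
mat_rev[1][0] = mat[1][0] / div2
mat_rev[1][1] = -mat[0][0] / div2
mat_rev[1][2] = -(mat[0][2] * mat[1][0] - mat[0][0] * mat[1][2]) / div2

# OpenCV's built-in inverse gives the same matrix.
assert np.allclose(mat_rev, cv2.invertAffineTransform(mat), atol=1e-6)
```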
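Visualizer's tf_log branch depends on the TensorFlow 1.x summary API (tf.summary.FileWriter, tf.Summary) and on scipy.misc.toimage, all of which have been removed from current TensorFlow and SciPy releases. If scalar and image logging is all that is needed, a minimal stand-in built on torch.utils.tensorboard (an alternative sketch, not something this repo ships) could look like:

```python
from torch.utils.tensorboard import SummaryWriter

class TBLogger:
    """TensorBoard-based stand-in for Visualizer's tf_log branch (sketch only)."""

    def __init__(self, log_dir='./checkpoints/simswap/logs'):
        self.writer = SummaryWriter(log_dir)

    def display_current_results(self, visuals, epoch, step):
        # visuals: dict of label -> HWC uint8 numpy image, as produced by util.tensor2im
        for label, image_numpy in visuals.items():
            self.writer.add_image(label, image_numpy, step, dataformats='HWC')

    def plot_current_errors(self, errors, step):
        # errors: dict of loss name -> scalar value
        for tag, value in errors.items():
            self.writer.add_scalar(tag, value, step)
```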