├── .gitattributes
├── .gitignore
├── LICENSE
├── MultiSpecific.ipynb
├── README.md
├── SimSwap colab.ipynb
├── cog.yaml
├── crop_224
│   ├── 1_source.jpg
│   ├── 2.jpg
│   ├── 6.jpg
│   ├── cage.jpg
│   ├── dnl.jpg
│   ├── ds.jpg
│   ├── gdg.jpg
│   ├── gy.jpg
│   ├── hzc.jpg
│   ├── hzxc.jpg
│   ├── james.jpg
│   ├── jl.jpg
│   ├── lcw.jpg
│   ├── ljm.jpg
│   ├── ljm2.jpg
│   ├── ljm3.jpg
│   ├── mars2.jpg
│   ├── mouth_open.jpg
│   ├── mtdm.jpg
│   ├── trump.jpg
│   ├── wlh.jpg
│   ├── zjl.jpg
│   ├── zrf.jpg
│   └── zxy.jpg
├── data
│   └── data_loader_Swapping.py
├── demo_file
│   ├── Iron_man.jpg
│   ├── multi_people.jpg
│   ├── multi_people_1080p.mp4
│   ├── multispecific
│   │   ├── DST_01.jpg
│   │   ├── DST_02.jpg
│   │   ├── DST_03.jpg
│   │   ├── SRC_01.png
│   │   ├── SRC_02.png
│   │   └── SRC_03.png
│   ├── specific1.png
│   ├── specific2.png
│   └── specific3.png
├── docs
│   ├── css
│   │   ├── bootstrap-theme.min.css
│   │   ├── bootstrap.min.css
│   │   ├── ie10-viewport-bug-workaround.css
│   │   ├── jumbotron.css
│   │   └── page.css
│   ├── favicon.ico
│   ├── fonts
│   │   ├── glyphicons-halflings-regular.eot
│   │   ├── glyphicons-halflings-regular.svg
│   │   ├── glyphicons-halflings-regular.ttf
│   │   ├── glyphicons-halflings-regular.woff
│   │   └── glyphicons-halflings-regular.woff2
│   ├── guidance
│   │   ├── preparation.md
│   │   └── usage.md
│   ├── img
│   │   ├── LRGT_201110059_201110091.webp
│   │   ├── anni.webp
│   │   ├── chenglong.webp
│   │   ├── girl2-RGB.png
│   │   ├── girl2.gif
│   │   ├── id
│   │   │   ├── Iron_man.jpg
│   │   │   ├── anni.jpg
│   │   │   ├── chenglong.jpg
│   │   │   ├── wuyifan.png
│   │   │   ├── zhoujielun.jpg
│   │   │   └── zhuyin.jpg
│   │   ├── logo.png
│   │   ├── logo1.png
│   │   ├── logo2.png
│   │   ├── mama_mask_short.webp
│   │   ├── mama_mask_wuyifan_short.webp
│   │   ├── multi_face_comparison.png
│   │   ├── new.gif
│   │   ├── nrsig.png
│   │   ├── result_whole_swap_multispecific_512.jpg
│   │   ├── results1.PNG
│   │   ├── title.png
│   │   ├── total.PNG
│   │   ├── vggface2_hq_compare.png
│   │   ├── video.webp
│   │   ├── zhoujielun.webp
│   │   └── zhuyin.webp
│   ├── index.html
│   └── js
│       ├── bootstrap.min.js
│       ├── ie-emulation-modes-warning.js
│       ├── ie10-viewport-bug-workaround.js
│       ├── vendor
│       │   └── jquery.min.js
│       └── which-image.js
├── download-weights.sh
├── insightface_func
│   ├── __init__.py
│   ├── face_detect_crop_multi.py
│   ├── face_detect_crop_single.py
│   └── utils
│       └── face_align_ffhqandnewarc.py
├── models
│   ├── __init__.py
│   ├── arcface_models.py
│   ├── base_model.py
│   ├── config.py
│   ├── fs_model.py
│   ├── fs_networks.py
│   ├── fs_networks_512.py
│   ├── fs_networks_fix.py
│   ├── models.py
│   ├── networks.py
│   ├── pix2pixHD_model.py
│   ├── projected_model.py
│   ├── projectionhead.py
│   └── ui_model.py
├── options
│   ├── base_options.py
│   ├── test_options.py
│   └── train_options.py
├── output
│   └── result.jpg
├── parsing_model
│   ├── model.py
│   └── resnet.py
├── pg_modules
│   ├── blocks.py
│   ├── diffaug.py
│   ├── projected_discriminator.py
│   └── projector.py
├── predict.py
├── simswaplogo
│   ├── simswaplogo.png
│   └── socialbook_logo.2020.357eed90add7705e54a8.svg
├── test_one_image.py
├── test_video_swap_multispecific.py
├── test_video_swapmulti.py
├── test_video_swapsingle.py
├── test_video_swapspecific.py
├── test_wholeimage_swap_multispecific.py
├── test_wholeimage_swapmulti.py
├── test_wholeimage_swapsingle.py
├── test_wholeimage_swapspecific.py
├── train.ipynb
├── train.py
└── util
    ├── add_watermark.py
    ├── html.py
    ├── image_pool.py
    ├── json_config.py
    ├── logo_class.py
    ├── norm.py
    ├── plot.py
    ├── reverse2original.py
    ├── save_heatmap.py
    ├── util.py
    ├── videoswap.py
    ├── videoswap_multispecific.py
    ├── videoswap_specific.py
    └── visualizer.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | docs/ppt/
132 | checkpoints/
133 | *.tar
134 | *.patch
135 | *.zip
136 | *.avi
137 | *.pdf
138 | *.pptx
139 |
140 | *.pth
141 | *.onnx
142 | wandb/
143 | temp_results/
144 | output/*.*
--------------------------------------------------------------------------------
/cog.yaml:
--------------------------------------------------------------------------------
1 | build:
2 | gpu: true
3 | python_version: "3.8"
4 | system_packages:
5 | - "libgl1-mesa-glx"
6 | - "libglib2.0-0"
7 | python_packages:
8 | - "imageio==2.9.0"
9 | - "torch==1.8.0"
10 | - "torchvision==0.9.0"
11 | - "numpy==1.21.1"
12 | - "insightface==0.2.1"
13 | - "ipython==7.21.0"
14 | - "Pillow==8.3.1"
15 | - "opencv-python==4.5.3.56"
16 | - "Fraction==1.5.1"
17 | - "onnxruntime-gpu==1.8.1"
18 | - "moviepy==1.0.3"
19 |
20 | predict: "predict.py:Predictor"
21 |
--------------------------------------------------------------------------------
/crop_224/1_source.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/1_source.jpg
--------------------------------------------------------------------------------
/crop_224/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/2.jpg
--------------------------------------------------------------------------------
/crop_224/6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/6.jpg
--------------------------------------------------------------------------------
/crop_224/cage.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/cage.jpg
--------------------------------------------------------------------------------
/crop_224/dnl.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/dnl.jpg
--------------------------------------------------------------------------------
/crop_224/ds.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/ds.jpg
--------------------------------------------------------------------------------
/crop_224/gdg.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/gdg.jpg
--------------------------------------------------------------------------------
/crop_224/gy.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/gy.jpg
--------------------------------------------------------------------------------
/crop_224/hzc.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/hzc.jpg
--------------------------------------------------------------------------------
/crop_224/hzxc.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/hzxc.jpg
--------------------------------------------------------------------------------
/crop_224/james.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/james.jpg
--------------------------------------------------------------------------------
/crop_224/jl.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/jl.jpg
--------------------------------------------------------------------------------
/crop_224/lcw.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/lcw.jpg
--------------------------------------------------------------------------------
/crop_224/ljm.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/ljm.jpg
--------------------------------------------------------------------------------
/crop_224/ljm2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/ljm2.jpg
--------------------------------------------------------------------------------
/crop_224/ljm3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/ljm3.jpg
--------------------------------------------------------------------------------
/crop_224/mars2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/mars2.jpg
--------------------------------------------------------------------------------
/crop_224/mouth_open.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/mouth_open.jpg
--------------------------------------------------------------------------------
/crop_224/mtdm.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/mtdm.jpg
--------------------------------------------------------------------------------
/crop_224/trump.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/trump.jpg
--------------------------------------------------------------------------------
/crop_224/wlh.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/wlh.jpg
--------------------------------------------------------------------------------
/crop_224/zjl.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/zjl.jpg
--------------------------------------------------------------------------------
/crop_224/zrf.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/zrf.jpg
--------------------------------------------------------------------------------
/crop_224/zxy.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/crop_224/zxy.jpg
--------------------------------------------------------------------------------
/data/data_loader_Swapping.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import torch
4 | import random
5 | from PIL import Image
6 | from torch.utils import data
7 | from torchvision import transforms as T
8 |
9 | class data_prefetcher():
10 | def __init__(self, loader):
11 | self.loader = loader
12 | self.dataiter = iter(loader)
13 | self.stream = torch.cuda.Stream()
14 | self.mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(1,3,1,1)
15 | self.std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(1,3,1,1)
16 | # With Amp, it isn't necessary to manually convert data to half.
17 | # if args.fp16:
18 | # self.mean = self.mean.half()
19 | # self.std = self.std.half()
20 | self.num_images = len(loader)
21 | self.preload()
22 |
23 | def preload(self):
24 | try:
25 | self.src_image1, self.src_image2 = next(self.dataiter)
26 | except StopIteration:
27 | self.dataiter = iter(self.loader)
28 | self.src_image1, self.src_image2 = next(self.dataiter)
29 |
30 | with torch.cuda.stream(self.stream):
31 | self.src_image1 = self.src_image1.cuda(non_blocking=True)
32 | self.src_image1 = self.src_image1.sub_(self.mean).div_(self.std)
33 | self.src_image2 = self.src_image2.cuda(non_blocking=True)
34 | self.src_image2 = self.src_image2.sub_(self.mean).div_(self.std)
35 |
36 | def next(self):
37 | torch.cuda.current_stream().wait_stream(self.stream)
38 | src_image1 = self.src_image1
39 | src_image2 = self.src_image2
40 | self.preload()
41 | return src_image1, src_image2
42 |
43 | def __len__(self):
44 | """Return the number of images."""
45 | return self.num_images
46 |
47 | class SwappingDataset(data.Dataset):
48 | """Dataset class for the Artworks dataset and content dataset."""
49 |
50 | def __init__(self,
51 | image_dir,
52 | img_transform,
53 | subffix='jpg',
54 | random_seed=1234):
55 | """Initialize and preprocess the Swapping dataset."""
56 | self.image_dir = image_dir
57 | self.img_transform = img_transform
58 | self.subffix = subffix
59 | self.dataset = []
60 | self.random_seed = random_seed
61 | self.preprocess()
62 | self.num_images = len(self.dataset)
63 |
64 | def preprocess(self):
65 | """Preprocess the Swapping dataset."""
66 | print("processing Swapping dataset images...")
67 |
68 | temp_path = os.path.join(self.image_dir,'*/')
69 | pathes = glob.glob(temp_path)
70 | self.dataset = []
71 | for dir_item in pathes:
72 | join_path = glob.glob(os.path.join(dir_item,'*.jpg'))
73 | print("processing %s"%dir_item,end='\r')
74 | temp_list = []
75 | for item in join_path:
76 | temp_list.append(item)
77 | self.dataset.append(temp_list)
78 | random.seed(self.random_seed)
79 | random.shuffle(self.dataset)
80 | print('Finished preprocessing the Swapping dataset, total dirs number: %d...'%len(self.dataset))
81 |
82 | def __getitem__(self, index):
83 | """Return two src domain images and two dst domain images."""
84 | dir_tmp1 = self.dataset[index]
85 | dir_tmp1_len = len(dir_tmp1)
86 |
87 | filename1 = dir_tmp1[random.randint(0,dir_tmp1_len-1)]
88 | filename2 = dir_tmp1[random.randint(0,dir_tmp1_len-1)]
89 | image1 = self.img_transform(Image.open(filename1))
90 | image2 = self.img_transform(Image.open(filename2))
91 | return image1, image2
92 |
93 | def __len__(self):
94 | """Return the number of images."""
95 | return self.num_images
96 |
97 | def GetLoader( dataset_roots,
98 | batch_size=16,
99 | dataloader_workers=8,
100 | random_seed = 1234
101 | ):
102 | """Build and return a data loader."""
103 |
104 | num_workers = dataloader_workers
105 | data_root = dataset_roots
106 | random_seed = random_seed
107 |
108 | c_transforms = []
109 |
110 | c_transforms.append(T.ToTensor())
111 | c_transforms = T.Compose(c_transforms)
112 |
113 | content_dataset = SwappingDataset(
114 | data_root,
115 | c_transforms,
116 | "jpg",
117 | random_seed)
118 | content_data_loader = data.DataLoader(dataset=content_dataset,batch_size=batch_size,
119 | drop_last=True,shuffle=True,num_workers=num_workers,pin_memory=True)
120 | prefetcher = data_prefetcher(content_data_loader)
121 | return prefetcher
122 |
123 | def denorm(x):
124 | out = (x + 1) / 2
125 | return out.clamp_(0, 1)
--------------------------------------------------------------------------------
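The loader above expects a VGGFace2-style layout with one sub-folder per identity and returns a CUDA-side prefetcher rather than a plain `DataLoader`. A minimal usage sketch (the dataset path below is a placeholder, and a CUDA device is assumed because `data_prefetcher` moves and normalizes batches on the GPU):

```python
# Minimal sketch: iterate one epoch with the prefetcher returned by GetLoader.
# "/path/to/vggface2_crop_224" stands for your own <root>/<identity>/*.jpg layout.
from data.data_loader_Swapping import GetLoader

train_loader = GetLoader("/path/to/vggface2_crop_224",
                         batch_size=16, dataloader_workers=4, random_seed=1234)

for step in range(len(train_loader)):                # __len__ gives the number of batches
    src_image1, src_image2 = train_loader.next()     # two crops of the same identity, already on GPU
    # ... feed the pair to the swapping model here ...
```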
/demo_file/Iron_man.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/demo_file/Iron_man.jpg
--------------------------------------------------------------------------------
/demo_file/multi_people.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/demo_file/multi_people.jpg
--------------------------------------------------------------------------------
/demo_file/multi_people_1080p.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/demo_file/multi_people_1080p.mp4
--------------------------------------------------------------------------------
/demo_file/multispecific/DST_01.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/demo_file/multispecific/DST_01.jpg
--------------------------------------------------------------------------------
/demo_file/multispecific/DST_02.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/demo_file/multispecific/DST_02.jpg
--------------------------------------------------------------------------------
/demo_file/multispecific/DST_03.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/demo_file/multispecific/DST_03.jpg
--------------------------------------------------------------------------------
/demo_file/multispecific/SRC_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/demo_file/multispecific/SRC_01.png
--------------------------------------------------------------------------------
/demo_file/multispecific/SRC_02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/demo_file/multispecific/SRC_02.png
--------------------------------------------------------------------------------
/demo_file/multispecific/SRC_03.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/demo_file/multispecific/SRC_03.png
--------------------------------------------------------------------------------
/demo_file/specific1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/demo_file/specific1.png
--------------------------------------------------------------------------------
/demo_file/specific2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/demo_file/specific2.png
--------------------------------------------------------------------------------
/demo_file/specific3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/demo_file/specific3.png
--------------------------------------------------------------------------------
/docs/css/ie10-viewport-bug-workaround.css:
--------------------------------------------------------------------------------
1 | /*!
2 | * IE10 viewport hack for Surface/desktop Windows 8 bug
3 | * Copyright 2014-2015 Twitter, Inc.
4 | * Licensed under MIT (https://github.com/twbs/bootstrap/blob/master/LICENSE)
5 | */
6 |
7 | /*
8 | * See the Getting Started docs for more information:
9 | * http://getbootstrap.com/getting-started/#support-ie10-width
10 | */
11 | @-ms-viewport { width: device-width; }
12 | @-o-viewport { width: device-width; }
13 | @viewport { width: device-width; }
14 |
--------------------------------------------------------------------------------
/docs/css/jumbotron.css:
--------------------------------------------------------------------------------
1 | /* Move down content because we have a fixed navbar that is 50px tall */
2 | body {
3 | padding-top: 50px;
4 | padding-bottom: 20px;
5 | }
6 |
--------------------------------------------------------------------------------
/docs/css/page.css:
--------------------------------------------------------------------------------
1 | .which-image-container {
2 | display: flex;
3 | flex-wrap: wrap;
4 | align-items: center;
5 | flex-direction: column;
6 | justify-content: space-between;
7 | height: 100%;
8 | }
9 |
10 | .which-image-container :nth-child(2) {
11 | display: flex;
12 | flex-wrap: wrap;
13 | align-items: center;
14 | flex-direction: column;
15 | justify-content: flex-end;
16 | }
17 |
18 | .which-image {
19 | display: inline-block;
20 | flex: 1;
21 | flex-basis: 45%;
22 | margin: 1% 1% 1% 1%;
23 | width: 100%;
24 | height: auto;
25 |
26 | }
27 |
28 | .which-image img {
29 | float: right;
30 | width: 100%;
31 | }
32 |
33 | .image-display2 img {
34 | float: right;
35 | width: 100%;
36 | }
37 | /* .image-display{
38 | align-items: center;
39 | }
40 | .image-display2{
41 | align-items: center;
42 | } */
43 | .select-show {
44 | border-style: dashed;
45 | border-width: 2px;
46 | border-color: purple;
47 |
48 | /* padding: 2px; */
49 | }
--------------------------------------------------------------------------------
/docs/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/favicon.ico
--------------------------------------------------------------------------------
/docs/fonts/glyphicons-halflings-regular.eot:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/fonts/glyphicons-halflings-regular.eot
--------------------------------------------------------------------------------
/docs/fonts/glyphicons-halflings-regular.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/fonts/glyphicons-halflings-regular.ttf
--------------------------------------------------------------------------------
/docs/fonts/glyphicons-halflings-regular.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/fonts/glyphicons-halflings-regular.woff
--------------------------------------------------------------------------------
/docs/fonts/glyphicons-halflings-regular.woff2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/fonts/glyphicons-halflings-regular.woff2
--------------------------------------------------------------------------------
/docs/guidance/preparation.md:
--------------------------------------------------------------------------------
1 |
2 | # Preparation
3 |
4 | ### Installation
5 | **We highly recommend that you use Anaconda for installation.**
6 | ```
7 | conda create -n simswap python=3.6
8 | conda activate simswap
9 | conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=10.2 -c pytorch
10 | (optional): pip install --ignore-installed imageio
11 | pip install insightface==0.2.1 onnxruntime moviepy
12 | (optional): pip install onnxruntime-gpu (if you want to reduce inference time; installing onnxruntime-gpu can be difficult, and the required version depends on your machine and CUDA version.)
13 | ```
14 | - ***We have updated the preparation document. The main change is that the GPU version of onnxruntime is now supported. If you configured the environment before, run pip install onnxruntime-gpu to speed up inference.***
15 | - We use the face detection and alignment methods from **[insightface](https://github.com/deepinsight/insightface)** for image preprocessing. Please download the relevant files and unzip them to ./insightface_func/models from [this link](https://onedrive.live.com/?authkey=%21ADJ0aAOSsc90neY&cid=4A83B6B633B029CC&id=4A83B6B633B029CC%215837&parId=4A83B6B633B029CC%215834&action=locate).
16 | - We use the face parsing from **[face-parsing.PyTorch](https://github.com/zllrunning/face-parsing.PyTorch)** for image postprocessing. Please download the relevant file and place it in ./parsing_model/checkpoint from [this link](https://drive.google.com/file/d/154JgKpzCPW82qINcVieuPH3fZ2e0P812/view).
17 | - The PyTorch and CUDA versions above are the recommended ones; other versions may also work.
18 | - Using a different version of insightface is not recommended. Please use this specific version.
19 | - These settings have been tested on both Windows and Ubuntu.
20 |
21 | ### Pretrained model
22 | There are two archive files in the drive: **checkpoints.zip** and **arcface_checkpoint.tar**
23 |
24 | - **Copy the arcface_checkpoint.tar into ./arcface_model**
25 | - **Unzip checkpoints.zip, place it in the root dir ./**
26 |
27 | [[Google Drive]](https://drive.google.com/drive/folders/1jV6_0FIMPC53FZ2HzZNJZGMe55bbu17R?usp=sharing)
28 | [[Baidu Drive]](https://pan.baidu.com/s/1wFV11RVZMHqd-ky4YpLdcA) Password: ```jd2v```
29 |
30 | **Simswap 512 (optional)**
31 |
32 | The checkpoint of the **SimSwap 512 beta version** has been uploaded as a [GitHub release](https://github.com/neuralchen/SimSwap/releases/download/512_beta/512.zip). If you want to try SimSwap 512, feel free to do so.
33 | - **Unzip 512.zip, place it in the root dir ./checkpoints**.
34 |
35 |
36 | ### Note
37 | We expect users to have a GPU with at least 3 GB of memory. For those who do not, we provide a [[Colab Notebook implementation]](https://colab.research.google.com/github/neuralchen/SimSwap/blob/main/SimSwap%20colab.ipynb).
38 |
--------------------------------------------------------------------------------
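For reference, after the steps above the downloaded weights are expected to sit in the repository roughly like this (the `people`, `antelope`, and `79999_iter.pth` names are the usual ones from these downloads, but treat them as assumptions and keep whatever names the archives actually extract to):

```
SimSwap/
├── arcface_model/
│   └── arcface_checkpoint.tar
├── checkpoints/
│   ├── people/              # from checkpoints.zip
│   └── 512/                 # optional, from 512.zip
├── insightface_func/
│   └── models/
│       └── antelope/        # detection/alignment models from the insightface link
└── parsing_model/
    └── checkpoint/
        └── 79999_iter.pth   # face-parsing model from the Google Drive link
```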
/docs/img/LRGT_201110059_201110091.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/LRGT_201110059_201110091.webp
--------------------------------------------------------------------------------
/docs/img/anni.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/anni.webp
--------------------------------------------------------------------------------
/docs/img/chenglong.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/chenglong.webp
--------------------------------------------------------------------------------
/docs/img/girl2-RGB.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/girl2-RGB.png
--------------------------------------------------------------------------------
/docs/img/girl2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/girl2.gif
--------------------------------------------------------------------------------
/docs/img/id/Iron_man.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/id/Iron_man.jpg
--------------------------------------------------------------------------------
/docs/img/id/anni.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/id/anni.jpg
--------------------------------------------------------------------------------
/docs/img/id/chenglong.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/id/chenglong.jpg
--------------------------------------------------------------------------------
/docs/img/id/wuyifan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/id/wuyifan.png
--------------------------------------------------------------------------------
/docs/img/id/zhoujielun.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/id/zhoujielun.jpg
--------------------------------------------------------------------------------
/docs/img/id/zhuyin.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/id/zhuyin.jpg
--------------------------------------------------------------------------------
/docs/img/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/logo.png
--------------------------------------------------------------------------------
/docs/img/logo1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/logo1.png
--------------------------------------------------------------------------------
/docs/img/logo2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/logo2.png
--------------------------------------------------------------------------------
/docs/img/mama_mask_short.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/mama_mask_short.webp
--------------------------------------------------------------------------------
/docs/img/mama_mask_wuyifan_short.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/mama_mask_wuyifan_short.webp
--------------------------------------------------------------------------------
/docs/img/multi_face_comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/multi_face_comparison.png
--------------------------------------------------------------------------------
/docs/img/new.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/new.gif
--------------------------------------------------------------------------------
/docs/img/nrsig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/nrsig.png
--------------------------------------------------------------------------------
/docs/img/result_whole_swap_multispecific_512.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/result_whole_swap_multispecific_512.jpg
--------------------------------------------------------------------------------
/docs/img/results1.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/results1.PNG
--------------------------------------------------------------------------------
/docs/img/title.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/title.png
--------------------------------------------------------------------------------
/docs/img/total.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/total.PNG
--------------------------------------------------------------------------------
/docs/img/vggface2_hq_compare.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/vggface2_hq_compare.png
--------------------------------------------------------------------------------
/docs/img/video.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/video.webp
--------------------------------------------------------------------------------
/docs/img/zhoujielun.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/zhoujielun.webp
--------------------------------------------------------------------------------
/docs/img/zhuyin.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/docs/img/zhuyin.webp
--------------------------------------------------------------------------------
/docs/js/ie-emulation-modes-warning.js:
--------------------------------------------------------------------------------
1 | // NOTICE!! DO NOT USE ANY OF THIS JAVASCRIPT
2 | // IT'S JUST JUNK FOR OUR DOCS!
3 | // ++++++++++++++++++++++++++++++++++++++++++
4 | /*!
5 | * Copyright 2014-2015 Twitter, Inc.
6 | *
7 | * Licensed under the Creative Commons Attribution 3.0 Unported License. For
8 | * details, see https://creativecommons.org/licenses/by/3.0/.
9 | */
10 | // Intended to prevent false-positive bug reports about Bootstrap not working properly in old versions of IE due to folks testing using IE's unreliable emulation modes.
11 | (function () {
12 | 'use strict';
13 |
14 | function emulatedIEMajorVersion() {
15 | var groups = /MSIE ([0-9.]+)/.exec(window.navigator.userAgent)
16 | if (groups === null) {
17 | return null
18 | }
19 | var ieVersionNum = parseInt(groups[1], 10)
20 | var ieMajorVersion = Math.floor(ieVersionNum)
21 | return ieMajorVersion
22 | }
23 |
24 | function actualNonEmulatedIEMajorVersion() {
25 | // Detects the actual version of IE in use, even if it's in an older-IE emulation mode.
26 | // IE JavaScript conditional compilation docs: https://msdn.microsoft.com/library/121hztk3%28v=vs.94%29.aspx
27 | // @cc_on docs: https://msdn.microsoft.com/library/8ka90k2e%28v=vs.94%29.aspx
28 | var jscriptVersion = new Function('/*@cc_on return @_jscript_version; @*/')() // jshint ignore:line
29 | if (jscriptVersion === undefined) {
30 | return 11 // IE11+ not in emulation mode
31 | }
32 | if (jscriptVersion < 9) {
33 | return 8 // IE8 (or lower; haven't tested on IE<8)
34 | }
35 | return jscriptVersion // IE9 or IE10 in any mode, or IE11 in non-IE11 mode
36 | }
37 |
38 | var ua = window.navigator.userAgent
39 | if (ua.indexOf('Opera') > -1 || ua.indexOf('Presto') > -1) {
40 | return // Opera, which might pretend to be IE
41 | }
42 | var emulated = emulatedIEMajorVersion()
43 | if (emulated === null) {
44 | return // Not IE
45 | }
46 | var nonEmulated = actualNonEmulatedIEMajorVersion()
47 |
48 | if (emulated !== nonEmulated) {
49 | window.alert('WARNING: You appear to be using IE' + nonEmulated + ' in IE' + emulated + ' emulation mode.\nIE emulation modes can behave significantly differently from ACTUAL older versions of IE.\nPLEASE DON\'T FILE BOOTSTRAP BUGS based on testing in IE emulation modes!')
50 | }
51 | })();
52 |
--------------------------------------------------------------------------------
/docs/js/ie10-viewport-bug-workaround.js:
--------------------------------------------------------------------------------
1 | /*!
2 | * IE10 viewport hack for Surface/desktop Windows 8 bug
3 | * Copyright 2014-2015 Twitter, Inc.
4 | * Licensed under MIT (https://github.com/twbs/bootstrap/blob/master/LICENSE)
5 | */
6 |
7 | // See the Getting Started docs for more information:
8 | // http://getbootstrap.com/getting-started/#support-ie10-width
9 |
10 | (function () {
11 | 'use strict';
12 |
13 | if (navigator.userAgent.match(/IEMobile\/10\.0/)) {
14 | var msViewportStyle = document.createElement('style')
15 | msViewportStyle.appendChild(
16 | document.createTextNode(
17 | '@-ms-viewport{width:auto!important}'
18 | )
19 | )
20 | document.querySelector('head').appendChild(msViewportStyle)
21 | }
22 |
23 | })();
24 |
--------------------------------------------------------------------------------
/docs/js/which-image.js:
--------------------------------------------------------------------------------
1 | /*
2 | * @FilePath: \SimSwap\docs\js\which-image.js
3 | * @Author: Ziang Liu
4 | * @Date: 2021-07-03 16:34:56
5 | * @LastEditors: AceSix
6 | * @LastEditTime: 2021-07-20 00:46:27
7 | * Copyright (C) 2021 SJTU. All rights reserved.
8 | */
9 |
10 |
11 |
12 | function select_source(number) {
13 | var items = ['anni', 'chenglong', 'zhoujielun', 'zhuyin'];
14 | var item_id = items[number];
15 |
16 | for (i = 0; i < 4; i++) {
17 | if (number == i) {
18 | document.getElementById(items[i]).style.borderWidth = '5px';
19 | document.getElementById(items[i]).style.borderColor = 'red';
20 | document.getElementById(items[i]).style.borderStyle = 'outset';
21 | } else {
22 | document.getElementById(items[i]).style.border = 'none';
23 | }
24 | }
25 | document.getElementById('jiroujinlun').src = './img/' + item_id + '.webp';
26 |
27 | }
28 |
29 | function select_source2(number) {
30 | var items = ['Iron_man', 'wuyifan'];
31 | var item_id = items[number];
32 |
33 | for (i = 0; i < 2; i++) {
34 | if (number == i) {
35 | document.getElementById(items[i]).style.borderWidth = '5px';
36 | document.getElementById(items[i]).style.borderColor = 'red';
37 | document.getElementById(items[i]).style.borderStyle = 'outset';
38 | } else {
39 | document.getElementById(items[i]).style.border = 'none';
40 | }
41 | }
42 | if (item_id=='Iron_man'){
43 | document.getElementById('mama').src = './img/mama_mask_short.webp';
44 | }
45 | else{
46 | document.getElementById('mama').src = './img/mama_mask_wuyifan_short.webp';
47 |
48 | }
49 |
50 | }
--------------------------------------------------------------------------------
/download-weights.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | wget -P ./arcface_model https://github.com/neuralchen/SimSwap/releases/download/1.0/arcface_checkpoint.tar
3 | wget https://github.com/neuralchen/SimSwap/releases/download/1.0/checkpoints.zip
4 | unzip ./checkpoints.zip -d ./checkpoints
5 | rm checkpoints.zip
6 | wget --no-check-certificate "https://sh23tw.dm.files.1drv.com/y4mmGiIkNVigkSwOKDcV3nwMJulRGhbtHdkheehR5TArc52UjudUYNXAEvKCii2O5LAmzGCGK6IfleocxuDeoKxDZkNzDRSt4ZUlEt8GlSOpCXAFEkBwaZimtWGDRbpIGpb_pz9Nq5jATBQpezBS6G_UtspWTkgrXHHxhviV2nWy8APPx134zOZrUIbkSF6xnsqzs3uZ_SEX_m9Rey0ykpx9w" -O antelope.zip
7 | mkdir -p insightface_func/models
8 | unzip ./antelope.zip -d ./insightface_func/models/
9 | rm antelope.zip
10 |
--------------------------------------------------------------------------------
/insightface_func/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/insightface_func/__init__.py
--------------------------------------------------------------------------------
/insightface_func/face_detect_crop_multi.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: Naiyuan liu
3 | Github: https://github.com/NNNNAI
4 | Date: 2021-11-23 17:03:58
5 | LastEditors: Naiyuan liu
6 | LastEditTime: 2021-11-24 16:45:41
7 | Description:
8 | '''
9 | from __future__ import division
10 | import collections
11 | import numpy as np
12 | import glob
13 | import os
14 | import os.path as osp
15 | import cv2
16 | from insightface.model_zoo import model_zoo
17 | from insightface_func.utils import face_align_ffhqandnewarc as face_align
18 |
19 | __all__ = ['Face_detect_crop', 'Face']
20 |
21 | Face = collections.namedtuple('Face', [
22 | 'bbox', 'kps', 'det_score', 'embedding', 'gender', 'age',
23 | 'embedding_norm', 'normed_embedding',
24 | 'landmark'
25 | ])
26 |
27 | Face.__new__.__defaults__ = (None, ) * len(Face._fields)
28 |
29 |
30 | class Face_detect_crop:
31 | def __init__(self, name, root='~/.insightface_func/models'):
32 | self.models = {}
33 | root = os.path.expanduser(root)
34 | onnx_files = glob.glob(osp.join(root, name, '*.onnx'))
35 | onnx_files = sorted(onnx_files)
36 | for onnx_file in onnx_files:
37 | if onnx_file.find('_selfgen_')>0:
38 | #print('ignore:', onnx_file)
39 | continue
40 | model = model_zoo.get_model(onnx_file)
41 | if model.taskname not in self.models:
42 | print('find model:', onnx_file, model.taskname)
43 | self.models[model.taskname] = model
44 | else:
45 | print('duplicated model task type, ignore:', onnx_file, model.taskname)
46 | del model
47 | assert 'detection' in self.models
48 | self.det_model = self.models['detection']
49 |
50 |
51 | def prepare(self, ctx_id, det_thresh=0.5, det_size=(640, 640), mode ='None'):
52 | self.det_thresh = det_thresh
53 | self.mode = mode
54 | assert det_size is not None
55 | print('set det-size:', det_size)
56 | self.det_size = det_size
57 | for taskname, model in self.models.items():
58 | if taskname=='detection':
59 | model.prepare(ctx_id, input_size=det_size)
60 | else:
61 | model.prepare(ctx_id)
62 |
63 | def get(self, img, crop_size, max_num=0):
64 | bboxes, kpss = self.det_model.detect(img,
65 | threshold=self.det_thresh,
66 | max_num=max_num,
67 | metric='default')
68 | if bboxes.shape[0] == 0:
69 | return None
70 | ret = []
71 | # for i in range(bboxes.shape[0]):
72 | # bbox = bboxes[i, 0:4]
73 | # det_score = bboxes[i, 4]
74 | # kps = None
75 | # if kpss is not None:
76 | # kps = kpss[i]
77 | # M, _ = face_align.estimate_norm(kps, crop_size, mode ='None')
78 | # align_img = cv2.warpAffine(img, M, (crop_size, crop_size), borderValue=0.0)
79 | align_img_list = []
80 | M_list = []
81 | for i in range(bboxes.shape[0]):
82 | kps = None
83 | if kpss is not None:
84 | kps = kpss[i]
85 | M, _ = face_align.estimate_norm(kps, crop_size, mode = self.mode)
86 | align_img = cv2.warpAffine(img, M, (crop_size, crop_size), borderValue=0.0)
87 | align_img_list.append(align_img)
88 | M_list.append(M)
89 |
90 | # det_score = bboxes[..., 4]
91 |
92 | # best_index = np.argmax(det_score)
93 |
94 | # kps = None
95 | # if kpss is not None:
96 | # kps = kpss[best_index]
97 | # M, _ = face_align.estimate_norm(kps, crop_size, mode ='None')
98 | # align_img = cv2.warpAffine(img, M, (crop_size, crop_size), borderValue=0.0)
99 |
100 | return align_img_list, M_list
101 |
--------------------------------------------------------------------------------
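`Face_detect_crop` wraps the insightface detector together with the FFHQ/newarc alignment templates. A minimal sketch of cropping every face in an image (it assumes the detection models were unzipped to `./insightface_func/models/antelope` as in `download-weights.sh`; `ctx_id=0` selects the first GPU, a negative value falls back to CPU):

```python
# Minimal sketch: detect and align all faces in an image with Face_detect_crop.
import cv2
from insightface_func.face_detect_crop_multi import Face_detect_crop

app = Face_detect_crop(name='antelope', root='./insightface_func/models')
app.prepare(ctx_id=0, det_thresh=0.6, det_size=(640, 640), mode='None')

img = cv2.imread('./demo_file/multi_people.jpg')
result = app.get(img, crop_size=224)          # returns None when no face is detected
if result is not None:
    align_img_list, M_list = result           # one aligned crop and one affine matrix per face
    for i, crop in enumerate(align_img_list):
        cv2.imwrite(f'face_{i}.jpg', crop)
```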
/insightface_func/face_detect_crop_single.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: Naiyuan liu
3 | Github: https://github.com/NNNNAI
4 | Date: 2021-11-23 17:03:58
5 | LastEditors: Naiyuan liu
6 | LastEditTime: 2021-11-24 16:46:04
7 | Description:
8 | '''
9 | from __future__ import division
10 | import collections
11 | import numpy as np
12 | import glob
13 | import os
14 | import os.path as osp
15 | import cv2
16 | from insightface.model_zoo import model_zoo
17 | from insightface_func.utils import face_align_ffhqandnewarc as face_align
18 |
19 | __all__ = ['Face_detect_crop', 'Face']
20 |
21 | Face = collections.namedtuple('Face', [
22 | 'bbox', 'kps', 'det_score', 'embedding', 'gender', 'age',
23 | 'embedding_norm', 'normed_embedding',
24 | 'landmark'
25 | ])
26 |
27 | Face.__new__.__defaults__ = (None, ) * len(Face._fields)
28 |
29 |
30 | class Face_detect_crop:
31 | def __init__(self, name, root='~/.insightface_func/models'):
32 | self.models = {}
33 | root = os.path.expanduser(root)
34 | onnx_files = glob.glob(osp.join(root, name, '*.onnx'))
35 | onnx_files = sorted(onnx_files)
36 | for onnx_file in onnx_files:
37 | if onnx_file.find('_selfgen_')>0:
38 | #print('ignore:', onnx_file)
39 | continue
40 | model = model_zoo.get_model(onnx_file)
41 | if model.taskname not in self.models:
42 | print('find model:', onnx_file, model.taskname)
43 | self.models[model.taskname] = model
44 | else:
45 | print('duplicated model task type, ignore:', onnx_file, model.taskname)
46 | del model
47 | assert 'detection' in self.models
48 | self.det_model = self.models['detection']
49 |
50 |
51 | def prepare(self, ctx_id, det_thresh=0.5, det_size=(640, 640), mode ='None'):
52 | self.det_thresh = det_thresh
53 | self.mode = mode
54 | assert det_size is not None
55 | print('set det-size:', det_size)
56 | self.det_size = det_size
57 | for taskname, model in self.models.items():
58 | if taskname=='detection':
59 | model.prepare(ctx_id, input_size=det_size)
60 | else:
61 | model.prepare(ctx_id)
62 |
63 | def get(self, img, crop_size, max_num=0):
64 | bboxes, kpss = self.det_model.detect(img,
65 | threshold=self.det_thresh,
66 | max_num=max_num,
67 | metric='default')
68 | if bboxes.shape[0] == 0:
69 | return None
70 | # ret = []
71 | # for i in range(bboxes.shape[0]):
72 | # bbox = bboxes[i, 0:4]
73 | # det_score = bboxes[i, 4]
74 | # kps = None
75 | # if kpss is not None:
76 | # kps = kpss[i]
77 | # M, _ = face_align.estimate_norm(kps, crop_size, mode ='None')
78 | # align_img = cv2.warpAffine(img, M, (crop_size, crop_size), borderValue=0.0)
79 | # for i in range(bboxes.shape[0]):
80 | # kps = None
81 | # if kpss is not None:
82 | # kps = kpss[i]
83 | # M, _ = face_align.estimate_norm(kps, crop_size, mode ='None')
84 | # align_img = cv2.warpAffine(img, M, (crop_size, crop_size), borderValue=0.0)
85 |
86 | det_score = bboxes[..., 4]
87 |
88 |         # select the face with the highest detection score
89 | best_index = np.argmax(det_score)
90 |
91 | kps = None
92 | if kpss is not None:
93 | kps = kpss[best_index]
94 | M, _ = face_align.estimate_norm(kps, crop_size, mode = self.mode)
95 | align_img = cv2.warpAffine(img, M, (crop_size, crop_size), borderValue=0.0)
96 |
97 | return [align_img], [M]
98 |
--------------------------------------------------------------------------------
/insightface_func/utils/face_align_ffhqandnewarc.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: Naiyuan liu
3 | Github: https://github.com/NNNNAI
4 | Date: 2021-11-15 19:42:42
5 | LastEditors: Naiyuan liu
6 | LastEditTime: 2021-11-15 20:01:47
7 | Description:
8 | '''
9 |
10 | import cv2
11 | import numpy as np
12 | from skimage import transform as trans
13 |
14 | src1 = np.array([[51.642, 50.115], [57.617, 49.990], [35.740, 69.007],
15 | [51.157, 89.050], [57.025, 89.702]],
16 | dtype=np.float32)
17 | #<--left
18 | src2 = np.array([[45.031, 50.118], [65.568, 50.872], [39.677, 68.111],
19 | [45.177, 86.190], [64.246, 86.758]],
20 | dtype=np.float32)
21 |
22 | #---frontal
23 | src3 = np.array([[39.730, 51.138], [72.270, 51.138], [56.000, 68.493],
24 | [42.463, 87.010], [69.537, 87.010]],
25 | dtype=np.float32)
26 |
27 | #-->right
28 | src4 = np.array([[46.845, 50.872], [67.382, 50.118], [72.737, 68.111],
29 | [48.167, 86.758], [67.236, 86.190]],
30 | dtype=np.float32)
31 |
32 | #-->right profile
33 | src5 = np.array([[54.796, 49.990], [60.771, 50.115], [76.673, 69.007],
34 | [55.388, 89.702], [61.257, 89.050]],
35 | dtype=np.float32)
36 |
37 | src = np.array([src1, src2, src3, src4, src5])
38 | src_map = src
39 |
40 | ffhq_src = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935],
41 | [201.26117, 371.41043], [313.08905, 371.15118]])
42 | ffhq_src = np.expand_dims(ffhq_src, axis=0)
43 |
44 | # arcface_src = np.array(
45 | # [[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366],
46 | # [41.5493, 92.3655], [70.7299, 92.2041]],
47 | # dtype=np.float32)
48 |
49 | # arcface_src = np.expand_dims(arcface_src, axis=0)
50 |
51 | # In[66]:
52 |
53 |
54 | # lmk is prediction; src is template
55 | def estimate_norm(lmk, image_size=112, mode='ffhq'):
56 | assert lmk.shape == (5, 2)
57 | tform = trans.SimilarityTransform()
58 | lmk_tran = np.insert(lmk, 2, values=np.ones(5), axis=1)
59 | min_M = []
60 | min_index = []
61 | min_error = float('inf')
62 | if mode == 'ffhq':
63 | # assert image_size == 112
64 | src = ffhq_src * image_size / 512
65 | else:
66 | src = src_map * image_size / 112
67 | for i in np.arange(src.shape[0]):
68 | tform.estimate(lmk, src[i])
69 | M = tform.params[0:2, :]
70 | results = np.dot(M, lmk_tran.T)
71 | results = results.T
72 | error = np.sum(np.sqrt(np.sum((results - src[i])**2, axis=1)))
73 | # print(error)
74 | if error < min_error:
75 | min_error = error
76 | min_M = M
77 | min_index = i
78 | return min_M, min_index
79 |
80 |
81 | def norm_crop(img, landmark, image_size=112, mode='ffhq'):
82 | if mode == 'Both':
83 | M_None, _ = estimate_norm(landmark, image_size, mode = 'newarc')
84 | M_ffhq, _ = estimate_norm(landmark, image_size, mode='ffhq')
85 | warped_None = cv2.warpAffine(img, M_None, (image_size, image_size), borderValue=0.0)
86 | warped_ffhq = cv2.warpAffine(img, M_ffhq, (image_size, image_size), borderValue=0.0)
87 | return warped_ffhq, warped_None
88 | else:
89 | M, pose_index = estimate_norm(landmark, image_size, mode)
90 | warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0)
91 | return warped
92 |
93 | def square_crop(im, S):
94 | if im.shape[0] > im.shape[1]:
95 | height = S
96 | width = int(float(im.shape[1]) / im.shape[0] * S)
97 | scale = float(S) / im.shape[0]
98 | else:
99 | width = S
100 | height = int(float(im.shape[0]) / im.shape[1] * S)
101 | scale = float(S) / im.shape[1]
102 | resized_im = cv2.resize(im, (width, height))
103 | det_im = np.zeros((S, S, 3), dtype=np.uint8)
104 | det_im[:resized_im.shape[0], :resized_im.shape[1], :] = resized_im
105 | return det_im, scale
106 |
107 |
108 | def transform(data, center, output_size, scale, rotation):
109 | scale_ratio = scale
110 | rot = float(rotation) * np.pi / 180.0
111 | #translation = (output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio)
112 | t1 = trans.SimilarityTransform(scale=scale_ratio)
113 | cx = center[0] * scale_ratio
114 | cy = center[1] * scale_ratio
115 | t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy))
116 | t3 = trans.SimilarityTransform(rotation=rot)
117 | t4 = trans.SimilarityTransform(translation=(output_size / 2,
118 | output_size / 2))
119 | t = t1 + t2 + t3 + t4
120 | M = t.params[0:2]
121 | cropped = cv2.warpAffine(data,
122 | M, (output_size, output_size),
123 | borderValue=0.0)
124 | return cropped, M
125 |
126 |
127 | def trans_points2d(pts, M):
128 | new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
129 | for i in range(pts.shape[0]):
130 | pt = pts[i]
131 | new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
132 | new_pt = np.dot(M, new_pt)
133 | #print('new_pt', new_pt.shape, new_pt)
134 | new_pts[i] = new_pt[0:2]
135 |
136 | return new_pts
137 |
138 |
139 | def trans_points3d(pts, M):
140 | scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1])
141 | #print(scale)
142 | new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
143 | for i in range(pts.shape[0]):
144 | pt = pts[i]
145 | new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
146 | new_pt = np.dot(M, new_pt)
147 | #print('new_pt', new_pt.shape, new_pt)
148 | new_pts[i][0:2] = new_pt[0:2]
149 | new_pts[i][2] = pts[i][2] * scale
150 |
151 | return new_pts
152 |
153 |
154 | def trans_points(pts, M):
155 | if pts.shape[1] == 2:
156 | return trans_points2d(pts, M)
157 | else:
158 | return trans_points3d(pts, M)
159 |
160 |
--------------------------------------------------------------------------------
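`estimate_norm` picks the best-fitting landmark template (the FFHQ-style template for `mode='ffhq'`, the five-pose "newarc" templates otherwise) and `norm_crop` applies the resulting similarity transform. A small sketch, using made-up landmark coordinates in place of a real detector output:

```python
# Minimal sketch: align a face from its five landmarks (eyes, nose, mouth corners).
import cv2
import numpy as np
from insightface_func.utils.face_align_ffhqandnewarc import norm_crop

img = cv2.imread('./demo_file/Iron_man.jpg')
lmk = np.array([[187., 206.], [270., 204.], [230., 258.],
                [196., 299.], [262., 297.]], dtype=np.float32)   # hypothetical (5, 2) landmarks
aligned = norm_crop(img, lmk, image_size=224, mode='ffhq')        # FFHQ-style 224x224 crop
cv2.imwrite('aligned_224.jpg', aligned)
```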
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .arcface_models import ArcMarginModel
2 | from .arcface_models import ResNet
3 | from .arcface_models import IRBlock
4 | from .arcface_models import SEBlock
--------------------------------------------------------------------------------
/models/arcface_models.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn.functional as F
4 | from torch import nn
5 | from torch.nn import Parameter
6 | from .config import device, num_classes
7 |
8 |
9 |
10 | class SEBlock(nn.Module):
11 | def __init__(self, channel, reduction=16):
12 | super(SEBlock, self).__init__()
13 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
14 | self.fc = nn.Sequential(
15 | nn.Linear(channel, channel // reduction),
16 | nn.PReLU(),
17 | nn.Linear(channel // reduction, channel),
18 | nn.Sigmoid()
19 | )
20 |
21 | def forward(self, x):
22 | b, c, _, _ = x.size()
23 | y = self.avg_pool(x).view(b, c)
24 | y = self.fc(y).view(b, c, 1, 1)
25 | return x * y
26 |
27 |
28 | class IRBlock(nn.Module):
29 | expansion = 1
30 |
31 | def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
32 | super(IRBlock, self).__init__()
33 | self.bn0 = nn.BatchNorm2d(inplanes)
34 | self.conv1 = conv3x3(inplanes, inplanes)
35 | self.bn1 = nn.BatchNorm2d(inplanes)
36 | self.prelu = nn.PReLU()
37 | self.conv2 = conv3x3(inplanes, planes, stride)
38 | self.bn2 = nn.BatchNorm2d(planes)
39 | self.downsample = downsample
40 | self.stride = stride
41 | self.use_se = use_se
42 | if self.use_se:
43 | self.se = SEBlock(planes)
44 |
45 | def forward(self, x):
46 | residual = x
47 | out = self.bn0(x)
48 | out = self.conv1(out)
49 | out = self.bn1(out)
50 | out = self.prelu(out)
51 |
52 | out = self.conv2(out)
53 | out = self.bn2(out)
54 | if self.use_se:
55 | out = self.se(out)
56 |
57 | if self.downsample is not None:
58 | residual = self.downsample(x)
59 |
60 | out += residual
61 | out = self.prelu(out)
62 |
63 | return out
64 |
65 |
66 | class ResNet(nn.Module):
67 |
68 | def __init__(self, block, layers, use_se=True):
69 | self.inplanes = 64
70 | self.use_se = use_se
71 | super(ResNet, self).__init__()
72 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, bias=False)
73 | self.bn1 = nn.BatchNorm2d(64)
74 | self.prelu = nn.PReLU()
75 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
76 | self.layer1 = self._make_layer(block, 64, layers[0])
77 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
78 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
79 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
80 | self.bn2 = nn.BatchNorm2d(512)
81 | self.dropout = nn.Dropout()
82 | self.fc = nn.Linear(512 * 7 * 7, 512)
83 | self.bn3 = nn.BatchNorm1d(512)
84 |
85 | for m in self.modules():
86 | if isinstance(m, nn.Conv2d):
87 | nn.init.xavier_normal_(m.weight)
88 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
89 | nn.init.constant_(m.weight, 1)
90 | nn.init.constant_(m.bias, 0)
91 | elif isinstance(m, nn.Linear):
92 | nn.init.xavier_normal_(m.weight)
93 | nn.init.constant_(m.bias, 0)
94 |
95 | def _make_layer(self, block, planes, blocks, stride=1):
96 | downsample = None
97 | if stride != 1 or self.inplanes != planes * block.expansion:
98 | downsample = nn.Sequential(
99 | nn.Conv2d(self.inplanes, planes * block.expansion,
100 | kernel_size=1, stride=stride, bias=False),
101 | nn.BatchNorm2d(planes * block.expansion),
102 | )
103 |
104 | layers = []
105 | layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
106 | self.inplanes = planes
107 | for i in range(1, blocks):
108 | layers.append(block(self.inplanes, planes, use_se=self.use_se))
109 |
110 | return nn.Sequential(*layers)
111 |
112 | def forward(self, x):
113 | x = self.conv1(x)
114 | x = self.bn1(x)
115 | x = self.prelu(x)
116 | x = self.maxpool(x)
117 |
118 | x = self.layer1(x)
119 | x = self.layer2(x)
120 | x = self.layer3(x)
121 | x = self.layer4(x)
122 |
123 | x = self.bn2(x)
124 | x = self.dropout(x)
125 | # feature = x
126 | x = x.view(x.size(0), -1)
127 | x = self.fc(x)
128 | x = self.bn3(x)
129 |
130 | return x
131 |
132 |
133 | class ArcMarginModel(nn.Module):
134 | def __init__(self, args):
135 | super(ArcMarginModel, self).__init__()
136 |
137 | self.weight = Parameter(torch.FloatTensor(num_classes, args.emb_size))
138 | nn.init.xavier_uniform_(self.weight)
139 |
140 | self.easy_margin = args.easy_margin
141 | self.m = args.margin_m
142 | self.s = args.margin_s
143 |
144 | self.cos_m = math.cos(self.m)
145 | self.sin_m = math.sin(self.m)
146 | self.th = math.cos(math.pi - self.m)
147 | self.mm = math.sin(math.pi - self.m) * self.m
148 |
149 | def forward(self, input, label):
150 | x = F.normalize(input)
151 | W = F.normalize(self.weight)
152 | cosine = F.linear(x, W)
153 | sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
154 | phi = cosine * self.cos_m - sine * self.sin_m # cos(theta + m)
155 | if self.easy_margin:
156 | phi = torch.where(cosine > 0, phi, cosine)
157 | else:
158 | phi = torch.where(cosine > self.th, phi, cosine - self.mm)
159 | one_hot = torch.zeros(cosine.size(), device=device)
160 | one_hot.scatter_(1, label.view(-1, 1).long(), 1)
161 | output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
162 | output *= self.s
163 | return output
--------------------------------------------------------------------------------
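Editor's note: a minimal sketch (not part of the repository) of how ArcMarginModel above produces margin-adjusted logits: s * cos(theta + m) for the ground-truth class and s * cos(theta) for the others. num_classes and device come from models/config.py; the args namespace is a stand-in for the real training options.

# Illustrative only.
import torch
from types import SimpleNamespace
from models.arcface_models import ArcMarginModel
from models.config import device

args = SimpleNamespace(emb_size=512, easy_margin=False, margin_m=0.5, margin_s=64.0)
head = ArcMarginModel(args).to(device)
embeddings = torch.randn(4, 512, device=device)       # a batch of face embeddings
labels = torch.randint(0, 10, (4,), device=device)    # hypothetical identity labels
logits = head(embeddings, labels)                      # shape: [4, num_classes]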
/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import sys
4 |
5 | class BaseModel(torch.nn.Module):
6 | def name(self):
7 | return 'BaseModel'
8 |
9 | def initialize(self, opt):
10 | self.opt = opt
11 | self.gpu_ids = opt.gpu_ids
12 | self.isTrain = opt.isTrain
13 | self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
14 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
15 |
16 | def set_input(self, input):
17 | self.input = input
18 |
19 | def forward(self):
20 | pass
21 |
22 | # used in test time, no backprop
23 | def test(self):
24 | pass
25 |
26 | def get_image_paths(self):
27 | pass
28 |
29 | def optimize_parameters(self):
30 | pass
31 |
32 | def get_current_visuals(self):
33 | return self.input
34 |
35 | def get_current_errors(self):
36 | return {}
37 |
38 | def save(self, label):
39 | pass
40 |
41 | # helper saving function that can be used by subclasses
42 | def save_network(self, network, network_label, epoch_label, gpu_ids=None):
43 | save_filename = '{}_net_{}.pth'.format(epoch_label, network_label)
44 | save_path = os.path.join(self.save_dir, save_filename)
45 | torch.save(network.cpu().state_dict(), save_path)
46 | if torch.cuda.is_available():
47 | network.cuda()
48 |
49 | def save_optim(self, network, network_label, epoch_label, gpu_ids=None):
50 | save_filename = '{}_optim_{}.pth'.format(epoch_label, network_label)
51 | save_path = os.path.join(self.save_dir, save_filename)
52 | torch.save(network.state_dict(), save_path)
53 |
54 |
55 | # helper loading function that can be used by subclasses
56 | def load_network(self, network, network_label, epoch_label, save_dir=''):
57 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
58 | if not save_dir:
59 | save_dir = self.save_dir
60 | save_path = os.path.join(save_dir, save_filename)
61 | if not os.path.isfile(save_path):
62 |             print('%s does not exist yet!' % save_path)
63 |             if network_label == 'G':
64 |                 raise FileNotFoundError('Generator checkpoint %s must exist!' % save_path)
65 | else:
66 | #network.load_state_dict(torch.load(save_path))
67 | try:
68 | network.load_state_dict(torch.load(save_path))
69 | except:
70 | pretrained_dict = torch.load(save_path)
71 | model_dict = network.state_dict()
72 | try:
73 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
74 | network.load_state_dict(pretrained_dict)
75 | if self.opt.verbose:
76 | print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label)
77 | except:
78 | print('Pretrained network %s has fewer layers; The following are not initialized:' % network_label)
79 | for k, v in pretrained_dict.items():
80 | if v.size() == model_dict[k].size():
81 | model_dict[k] = v
82 |
83 | if sys.version_info >= (3,0):
84 | not_initialized = set()
85 | else:
86 | from sets import Set
87 | not_initialized = Set()
88 |
89 | for k, v in model_dict.items():
90 | if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
91 | not_initialized.add(k.split('.')[0])
92 |
93 | print(sorted(not_initialized))
94 | network.load_state_dict(model_dict)
95 |
96 | # helper loading function that can be used by subclasses
97 | def load_optim(self, network, network_label, epoch_label, save_dir=''):
98 | save_filename = '%s_optim_%s.pth' % (epoch_label, network_label)
99 | if not save_dir:
100 | save_dir = self.save_dir
101 | save_path = os.path.join(save_dir, save_filename)
102 | if not os.path.isfile(save_path):
103 |             print('%s does not exist yet!' % save_path)
104 |             if network_label == 'G':
105 |                 raise FileNotFoundError('Generator optimizer checkpoint %s must exist!' % save_path)
106 | else:
107 | #network.load_state_dict(torch.load(save_path))
108 | try:
109 | network.load_state_dict(torch.load(save_path, map_location=torch.device("cpu")))
110 | except:
111 | pretrained_dict = torch.load(save_path, map_location=torch.device("cpu"))
112 | model_dict = network.state_dict()
113 | try:
114 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
115 | network.load_state_dict(pretrained_dict)
116 | if self.opt.verbose:
117 | print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label)
118 | except:
119 | print('Pretrained network %s has fewer layers; The following are not initialized:' % network_label)
120 | for k, v in pretrained_dict.items():
121 | if v.size() == model_dict[k].size():
122 | model_dict[k] = v
123 |
124 | if sys.version_info >= (3,0):
125 | not_initialized = set()
126 | else:
127 | from sets import Set
128 | not_initialized = Set()
129 |
130 | for k, v in model_dict.items():
131 | if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
132 | not_initialized.add(k.split('.')[0])
133 |
134 | print(sorted(not_initialized))
135 | network.load_state_dict(model_dict)
136 |
137 |     def update_learning_rate(self):
138 | pass
139 |
--------------------------------------------------------------------------------
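Editor's note: a minimal sketch (not part of the repository) of the checkpoint round trip that subclasses of BaseModel above rely on. ToyModel and the opt namespace are placeholders for the real model and option objects; only the fields read by BaseModel.initialize are provided.

# Illustrative only.
import os
import torch.nn as nn
from types import SimpleNamespace
from models.base_model import BaseModel

class ToyModel(BaseModel):
    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.netG = nn.Linear(8, 8)

opt = SimpleNamespace(gpu_ids=[], isTrain=True, verbose=False,
                      checkpoints_dir='./checkpoints', name='toy')
os.makedirs(os.path.join(opt.checkpoints_dir, opt.name), exist_ok=True)

m = ToyModel()
m.initialize(opt)
m.save_network(m.netG, 'G', 'latest')   # writes ./checkpoints/toy/latest_net_G.pth
m.load_network(m.netG, 'G', 'latest')   # restores the weights just written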
/models/config.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 |
5 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # sets device for model and PyTorch tensors
6 |
7 | # Model parameters
8 | image_w = 112
9 | image_h = 112
10 | channel = 3
11 | emb_size = 512
12 |
13 | # Training parameters
14 | num_workers = 1 # for data-loading; right now, only 1 works with h5py
15 | grad_clip = 5.  # clip gradients at an absolute value of 5
16 | print_freq = 100 # print training/validation stats every __ batches
17 | checkpoint = None # path to checkpoint, None if none
18 |
19 | # Data parameters
20 | num_classes = 93431
21 | num_samples = 5179510
22 | DATA_DIR = 'data'
23 | # faces_ms1m_folder = 'data/faces_ms1m_112x112'
24 | faces_ms1m_folder = 'data/ms1m-retinaface-t1'
25 | path_imgidx = os.path.join(faces_ms1m_folder, 'train.idx')
26 | path_imgrec = os.path.join(faces_ms1m_folder, 'train.rec')
27 | IMG_DIR = 'data/images'
28 | pickle_file = 'data/faces_ms1m_112x112.pickle'
29 |
--------------------------------------------------------------------------------
/models/fs_networks_fix.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 |
10 | class InstanceNorm(nn.Module):
11 | def __init__(self, epsilon=1e-8):
12 | """
13 | @notice: avoid in-place ops.
14 | https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
15 | """
16 | super(InstanceNorm, self).__init__()
17 | self.epsilon = epsilon
18 |
19 | def forward(self, x):
20 | x = x - torch.mean(x, (2, 3), True)
21 | tmp = torch.mul(x, x) # or x ** 2
22 | tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon)
23 | return x * tmp
24 |
25 | class ApplyStyle(nn.Module):
26 | """
27 | @ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb
28 | """
29 | def __init__(self, latent_size, channels):
30 | super(ApplyStyle, self).__init__()
31 | self.linear = nn.Linear(latent_size, channels * 2)
32 |
33 | def forward(self, x, latent):
34 | style = self.linear(latent) # style => [batch_size, n_channels*2]
35 | shape = [-1, 2, x.size(1), 1, 1]
36 | style = style.view(shape) # [batch_size, 2, n_channels, ...]
37 | #x = x * (style[:, 0] + 1.) + style[:, 1]
38 | x = x * (style[:, 0] * 1 + 1.) + style[:, 1] * 1
39 | return x
40 |
41 | class ResnetBlock_Adain(nn.Module):
42 | def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)):
43 | super(ResnetBlock_Adain, self).__init__()
44 |
45 | p = 0
46 | conv1 = []
47 | if padding_type == 'reflect':
48 | conv1 += [nn.ReflectionPad2d(1)]
49 | elif padding_type == 'replicate':
50 | conv1 += [nn.ReplicationPad2d(1)]
51 | elif padding_type == 'zero':
52 | p = 1
53 | else:
54 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
55 | conv1 += [nn.Conv2d(dim, dim, kernel_size=3, padding = p), InstanceNorm()]
56 | self.conv1 = nn.Sequential(*conv1)
57 | self.style1 = ApplyStyle(latent_size, dim)
58 | self.act1 = activation
59 |
60 | p = 0
61 | conv2 = []
62 | if padding_type == 'reflect':
63 | conv2 += [nn.ReflectionPad2d(1)]
64 | elif padding_type == 'replicate':
65 | conv2 += [nn.ReplicationPad2d(1)]
66 | elif padding_type == 'zero':
67 | p = 1
68 | else:
69 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
70 | conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()]
71 | self.conv2 = nn.Sequential(*conv2)
72 | self.style2 = ApplyStyle(latent_size, dim)
73 |
74 |
75 | def forward(self, x, dlatents_in_slice):
76 | y = self.conv1(x)
77 | y = self.style1(y, dlatents_in_slice)
78 | y = self.act1(y)
79 | y = self.conv2(y)
80 | y = self.style2(y, dlatents_in_slice)
81 | out = x + y
82 | return out
83 |
84 |
85 |
86 | class Generator_Adain_Upsample(nn.Module):
87 | def __init__(self, input_nc, output_nc, latent_size, n_blocks=6, deep=False,
88 | norm_layer=nn.BatchNorm2d,
89 | padding_type='reflect'):
90 | assert (n_blocks >= 0)
91 | super(Generator_Adain_Upsample, self).__init__()
92 |
93 | activation = nn.ReLU(True)
94 |
95 | self.deep = deep
96 |
97 | self.first_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 64, kernel_size=7, padding=0),
98 | norm_layer(64), activation)
99 | ### downsample
100 | self.down1 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
101 | norm_layer(128), activation)
102 | self.down2 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
103 | norm_layer(256), activation)
104 | self.down3 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
105 | norm_layer(512), activation)
106 |
107 | if self.deep:
108 | self.down4 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
109 | norm_layer(512), activation)
110 |
111 | ### resnet blocks
112 | BN = []
113 | for i in range(n_blocks):
114 | BN += [
115 | ResnetBlock_Adain(512, latent_size=latent_size, padding_type=padding_type, activation=activation)]
116 | self.BottleNeck = nn.Sequential(*BN)
117 |
118 | if self.deep:
119 | self.up4 = nn.Sequential(
120 | nn.Upsample(scale_factor=2, mode='bilinear',align_corners=False),
121 | nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
122 | nn.BatchNorm2d(512), activation
123 | )
124 | self.up3 = nn.Sequential(
125 | nn.Upsample(scale_factor=2, mode='bilinear',align_corners=False),
126 | nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1),
127 | nn.BatchNorm2d(256), activation
128 | )
129 | self.up2 = nn.Sequential(
130 | nn.Upsample(scale_factor=2, mode='bilinear',align_corners=False),
131 | nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
132 | nn.BatchNorm2d(128), activation
133 | )
134 | self.up1 = nn.Sequential(
135 | nn.Upsample(scale_factor=2, mode='bilinear',align_corners=False),
136 | nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
137 | nn.BatchNorm2d(64), activation
138 | )
139 | self.last_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, kernel_size=7, padding=0))
140 |
141 | def forward(self, input, dlatents):
142 | x = input # 3*224*224
143 |
144 | skip1 = self.first_layer(x)
145 | skip2 = self.down1(skip1)
146 | skip3 = self.down2(skip2)
147 | if self.deep:
148 | skip4 = self.down3(skip3)
149 | x = self.down4(skip4)
150 | else:
151 | x = self.down3(skip3)
152 | bot = []
153 | bot.append(x)
154 | features = []
155 | for i in range(len(self.BottleNeck)):
156 | x = self.BottleNeck[i](x, dlatents)
157 | bot.append(x)
158 |
159 | if self.deep:
160 | x = self.up4(x)
161 | features.append(x)
162 | x = self.up3(x)
163 | features.append(x)
164 | x = self.up2(x)
165 | features.append(x)
166 | x = self.up1(x)
167 | features.append(x)
168 | x = self.last_layer(x)
169 | # x = (x + 1) / 2
170 |
171 | # return x, bot, features, dlatents
172 | return x
--------------------------------------------------------------------------------
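Editor's note: a minimal sketch (not part of the repository) of a dummy forward pass through Generator_Adain_Upsample above. The random 512-d latent stands in for the ArcFace identity embedding injected through the AdaIN bottleneck blocks; the sizes follow the 224 crop configuration.

# Illustrative only.
import torch
from models.fs_networks_fix import Generator_Adain_Upsample

netG = Generator_Adain_Upsample(input_nc=3, output_nc=3, latent_size=512,
                                n_blocks=9, deep=False).eval()
target = torch.randn(1, 3, 224, 224)    # attribute (target) face crop
id_latent = torch.randn(1, 512)         # identity embedding of the source face
with torch.no_grad():
    swapped = netG(target, id_latent)   # -> [1, 3, 224, 224]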
/models/models.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn.functional as F
4 | from torch import nn
5 | from torch.nn import Parameter
6 | from .config import device, num_classes
7 | def conv3x3(in_planes, out_planes, stride=1):  # 3x3 conv with padding (used by IRBlock below)
8 |     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
9 | def create_model(opt):
10 | #from .pix2pixHD_model import Pix2PixHDModel, InferenceModel
11 | from .fs_model import fsModel
12 | model = fsModel()
13 |
14 | model.initialize(opt)
15 | if opt.verbose:
16 | print("model [%s] was created" % (model.name()))
17 |
18 | if opt.isTrain and len(opt.gpu_ids) and not opt.fp16:
19 | model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
20 |
21 | return model
22 |
23 |
24 |
25 | class SEBlock(nn.Module):
26 | def __init__(self, channel, reduction=16):
27 | super(SEBlock, self).__init__()
28 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
29 | self.fc = nn.Sequential(
30 | nn.Linear(channel, channel // reduction),
31 | nn.PReLU(),
32 | nn.Linear(channel // reduction, channel),
33 | nn.Sigmoid()
34 | )
35 |
36 | def forward(self, x):
37 | b, c, _, _ = x.size()
38 | y = self.avg_pool(x).view(b, c)
39 | y = self.fc(y).view(b, c, 1, 1)
40 | return x * y
41 |
42 |
43 | class IRBlock(nn.Module):
44 | expansion = 1
45 |
46 | def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
47 | super(IRBlock, self).__init__()
48 | self.bn0 = nn.BatchNorm2d(inplanes)
49 | self.conv1 = conv3x3(inplanes, inplanes)
50 | self.bn1 = nn.BatchNorm2d(inplanes)
51 | self.prelu = nn.PReLU()
52 | self.conv2 = conv3x3(inplanes, planes, stride)
53 | self.bn2 = nn.BatchNorm2d(planes)
54 | self.downsample = downsample
55 | self.stride = stride
56 | self.use_se = use_se
57 | if self.use_se:
58 | self.se = SEBlock(planes)
59 |
60 | def forward(self, x):
61 | residual = x
62 | out = self.bn0(x)
63 | out = self.conv1(out)
64 | out = self.bn1(out)
65 | out = self.prelu(out)
66 |
67 | out = self.conv2(out)
68 | out = self.bn2(out)
69 | if self.use_se:
70 | out = self.se(out)
71 |
72 | if self.downsample is not None:
73 | residual = self.downsample(x)
74 |
75 | out += residual
76 | out = self.prelu(out)
77 |
78 | return out
79 |
80 |
81 | class ResNet(nn.Module):
82 |
83 | def __init__(self, block, layers, use_se=True):
84 | self.inplanes = 64
85 | self.use_se = use_se
86 | super(ResNet, self).__init__()
87 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, bias=False)
88 | self.bn1 = nn.BatchNorm2d(64)
89 | self.prelu = nn.PReLU()
90 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
91 | self.layer1 = self._make_layer(block, 64, layers[0])
92 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
93 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
94 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
95 | self.bn2 = nn.BatchNorm2d(512)
96 | self.dropout = nn.Dropout()
97 | self.fc = nn.Linear(512 * 7 * 7, 512)
98 | self.bn3 = nn.BatchNorm1d(512)
99 |
100 | for m in self.modules():
101 | if isinstance(m, nn.Conv2d):
102 | nn.init.xavier_normal_(m.weight)
103 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
104 | nn.init.constant_(m.weight, 1)
105 | nn.init.constant_(m.bias, 0)
106 | elif isinstance(m, nn.Linear):
107 | nn.init.xavier_normal_(m.weight)
108 | nn.init.constant_(m.bias, 0)
109 |
110 | def _make_layer(self, block, planes, blocks, stride=1):
111 | downsample = None
112 | if stride != 1 or self.inplanes != planes * block.expansion:
113 | downsample = nn.Sequential(
114 | nn.Conv2d(self.inplanes, planes * block.expansion,
115 | kernel_size=1, stride=stride, bias=False),
116 | nn.BatchNorm2d(planes * block.expansion),
117 | )
118 |
119 | layers = []
120 | layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
121 | self.inplanes = planes
122 | for i in range(1, blocks):
123 | layers.append(block(self.inplanes, planes, use_se=self.use_se))
124 |
125 | return nn.Sequential(*layers)
126 |
127 | def forward(self, x):
128 | x = self.conv1(x)
129 | x = self.bn1(x)
130 | x = self.prelu(x)
131 | x = self.maxpool(x)
132 |
133 | x = self.layer1(x)
134 | x = self.layer2(x)
135 | x = self.layer3(x)
136 | x = self.layer4(x)
137 |
138 | x = self.bn2(x)
139 | x = self.dropout(x)
140 | x = x.view(x.size(0), -1)
141 | x = self.fc(x)
142 | x = self.bn3(x)
143 |
144 | return x
145 |
146 |
147 | class ArcMarginModel(nn.Module):
148 | def __init__(self, args):
149 | super(ArcMarginModel, self).__init__()
150 |
151 | self.weight = Parameter(torch.FloatTensor(num_classes, args.emb_size))
152 | nn.init.xavier_uniform_(self.weight)
153 |
154 | self.easy_margin = args.easy_margin
155 | self.m = args.margin_m
156 | self.s = args.margin_s
157 |
158 | self.cos_m = math.cos(self.m)
159 | self.sin_m = math.sin(self.m)
160 | self.th = math.cos(math.pi - self.m)
161 | self.mm = math.sin(math.pi - self.m) * self.m
162 |
163 | def forward(self, input, label):
164 | x = F.normalize(input)
165 | W = F.normalize(self.weight)
166 | cosine = F.linear(x, W)
167 | sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
168 | phi = cosine * self.cos_m - sine * self.sin_m # cos(theta + m)
169 | if self.easy_margin:
170 | phi = torch.where(cosine > 0, phi, cosine)
171 | else:
172 | phi = torch.where(cosine > self.th, phi, cosine - self.mm)
173 | one_hot = torch.zeros(cosine.size(), device=device)
174 | one_hot.scatter_(1, label.view(-1, 1).long(), 1)
175 | output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
176 | output *= self.s
177 | return output
178 |
--------------------------------------------------------------------------------
/models/projected_model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | #############################################################
4 | # File: fs_model_fix_idnorm_donggp_saveoptim copy.py
5 | # Created Date: Wednesday January 12th 2022
6 | # Author: Chen Xuanhong
7 | # Email: chenxuanhongzju@outlook.com
8 | # Last Modified: Saturday, 13th May 2023 9:56:35 am
9 | # Modified By: Chen Xuanhong
10 | # Copyright (c) 2022 Shanghai Jiao Tong University
11 | #############################################################
12 |
13 |
14 | import torch
15 | import torch.nn as nn
16 |
17 | from .base_model import BaseModel
18 | from .fs_networks_fix import Generator_Adain_Upsample
19 |
20 | from pg_modules.projected_discriminator import ProjectedDiscriminator
21 |
22 | def compute_grad2(d_out, x_in):
23 | batch_size = x_in.size(0)
24 | grad_dout = torch.autograd.grad(
25 | outputs=d_out.sum(), inputs=x_in,
26 | create_graph=True, retain_graph=True, only_inputs=True
27 | )[0]
28 | grad_dout2 = grad_dout.pow(2)
29 | assert(grad_dout2.size() == x_in.size())
30 | reg = grad_dout2.view(batch_size, -1).sum(1)
31 | return reg
32 |
33 | class fsModel(BaseModel):
34 | def name(self):
35 | return 'fsModel'
36 |
37 | def initialize(self, opt):
38 | BaseModel.initialize(self, opt)
39 | # if opt.resize_or_crop != 'none' or not opt.isTrain: # when training at full res this causes OOM
40 | self.isTrain = opt.isTrain
41 |
42 | # Generator network
43 | self.netG = Generator_Adain_Upsample(input_nc=3, output_nc=3, latent_size=512, n_blocks=9, deep=opt.Gdeep)
44 | self.netG.cuda()
45 |
46 | # Id network
47 | netArc_checkpoint = opt.Arc_path
48 | netArc_checkpoint = torch.load(netArc_checkpoint, map_location=torch.device("cpu"))
49 | self.netArc = netArc_checkpoint
50 | self.netArc = self.netArc.cuda()
51 | self.netArc.eval()
52 | self.netArc.requires_grad_(False)
53 | if not self.isTrain:
54 | pretrained_path = opt.checkpoints_dir
55 | self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
56 | return
57 | self.netD = ProjectedDiscriminator(diffaug=False, interp224=False, **{})
58 | # self.netD.feature_network.requires_grad_(False)
59 | self.netD.cuda()
60 |
61 |
62 | if self.isTrain:
63 | # define loss functions
64 | self.criterionFeat = nn.L1Loss()
65 | self.criterionRec = nn.L1Loss()
66 |
67 |
68 | # initialize optimizers
69 |
70 | # optimizer G
71 | params = list(self.netG.parameters())
72 | self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.99),eps=1e-8)
73 |
74 | # optimizer D
75 | params = list(self.netD.parameters())
76 | self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.99),eps=1e-8)
77 |
78 | # load networks
79 | if opt.continue_train:
80 | pretrained_path = '' if not self.isTrain else opt.load_pretrain
81 | # print (pretrained_path)
82 | self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
83 | self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
84 | self.load_optim(self.optimizer_G, 'G', opt.which_epoch, pretrained_path)
85 | self.load_optim(self.optimizer_D, 'D', opt.which_epoch, pretrained_path)
86 | torch.cuda.empty_cache()
87 |
88 | def cosin_metric(self, x1, x2):
89 | #return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))
90 | return torch.sum(x1 * x2, dim=1) / (torch.norm(x1, dim=1) * torch.norm(x2, dim=1))
91 |
92 |
93 |
94 | def save(self, which_epoch):
95 | self.save_network(self.netG, 'G', which_epoch)
96 | self.save_network(self.netD, 'D', which_epoch)
97 | self.save_optim(self.optimizer_G, 'G', which_epoch)
98 | self.save_optim(self.optimizer_D, 'D', which_epoch)
99 | '''if self.gen_features:
100 | self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)'''
101 |
102 | def update_fixed_params(self):
103 |         # after fixing the global generator for a number of iterations, also start finetuning it (legacy pix2pixHD helper; fsModel never defines gen_features/netE, so this method is unused here)
104 | params = list(self.netG.parameters())
105 | if self.gen_features:
106 | params += list(self.netE.parameters())
107 | self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
108 | if self.opt.verbose:
109 | print('------------ Now also finetuning global generator -----------')
110 |
111 | def update_learning_rate(self):
112 | lrd = self.opt.lr / self.opt.niter_decay
113 | lr = self.old_lr - lrd
114 | for param_group in self.optimizer_D.param_groups:
115 | param_group['lr'] = lr
116 | for param_group in self.optimizer_G.param_groups:
117 | param_group['lr'] = lr
118 | if self.opt.verbose:
119 | print('update learning rate: %f -> %f' % (self.old_lr, lr))
120 | self.old_lr = lr
121 |
122 |
123 |
--------------------------------------------------------------------------------
/models/projectionhead.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | class ProjectionHead(nn.Module):
4 | def __init__(self, proj_dim=256):
5 | super(ProjectionHead, self).__init__()
6 |
7 | self.proj = nn.Sequential(
8 | nn.Linear(proj_dim, proj_dim),
9 | nn.ReLU(),
10 | nn.Linear(proj_dim, proj_dim),
11 | )
12 |
13 | def forward(self, x):
14 | return self.proj(x)
--------------------------------------------------------------------------------
/options/base_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from util import util
4 | import torch
5 |
6 | class BaseOptions():
7 | def __init__(self):
8 | self.parser = argparse.ArgumentParser()
9 | self.initialized = False
10 |
11 | def initialize(self):
12 | # experiment specifics
13 | self.parser.add_argument('--name', type=str, default='people', help='name of the experiment. It decides where to store samples and models')
14 |         self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids, e.g. 0 or 0,1,2 or 0,2; use -1 for CPU')
15 | self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
16 | # self.parser.add_argument('--model', type=str, default='pix2pixHD', help='which model to use')
17 | self.parser.add_argument('--norm', type=str, default='batch', help='instance normalization or batch normalization')
18 | self.parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator')
19 | self.parser.add_argument('--data_type', default=32, type=int, choices=[8, 16, 32], help="Supported data type i.e. 8, 16, 32 bit")
20 | self.parser.add_argument('--verbose', action='store_true', default=False, help='toggles verbose')
21 | self.parser.add_argument('--fp16', action='store_true', default=False, help='train with AMP')
22 | self.parser.add_argument('--local_rank', type=int, default=0, help='local rank for distributed training')
23 |         self.parser.add_argument('--isTrain', type=bool, default=True, help='whether the model runs in training mode')
24 |
25 | # input/output sizes
26 | self.parser.add_argument('--batchSize', type=int, default=8, help='input batch size')
27 | self.parser.add_argument('--loadSize', type=int, default=1024, help='scale images to this size')
28 | self.parser.add_argument('--fineSize', type=int, default=512, help='then crop to this size')
29 | self.parser.add_argument('--label_nc', type=int, default=0, help='# of input label channels')
30 | self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')
31 | self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
32 |
33 | # for setting inputs
34 | self.parser.add_argument('--dataroot', type=str, default='./datasets/cityscapes/')
35 | self.parser.add_argument('--resize_or_crop', type=str, default='scale_width', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
36 | self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
37 |         self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
38 | self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data')
39 | self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
40 |
41 | # for displays
42 | self.parser.add_argument('--display_winsize', type=int, default=512, help='display window size')
43 | self.parser.add_argument('--tf_log', action='store_true', help='if specified, use tensorboard logging. Requires tensorflow installed')
44 |
45 | # for generator
46 | self.parser.add_argument('--netG', type=str, default='global', help='selects model to use for netG')
47 | self.parser.add_argument('--latent_size', type=int, default=512, help='latent size of Adain layer')
48 | self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
49 | self.parser.add_argument('--n_downsample_global', type=int, default=3, help='number of downsampling layers in netG')
50 | self.parser.add_argument('--n_blocks_global', type=int, default=6, help='number of residual blocks in the global generator network')
51 | self.parser.add_argument('--n_blocks_local', type=int, default=3, help='number of residual blocks in the local enhancer network')
52 | self.parser.add_argument('--n_local_enhancers', type=int, default=1, help='number of local enhancers to use')
53 | self.parser.add_argument('--niter_fix_global', type=int, default=0, help='number of epochs that we only train the outmost local enhancer')
54 |
55 | # for instance-wise features
56 | self.parser.add_argument('--no_instance', action='store_true', help='if specified, do *not* add instance map as input')
57 | self.parser.add_argument('--instance_feat', action='store_true', help='if specified, add encoded instance features as input')
58 | self.parser.add_argument('--label_feat', action='store_true', help='if specified, add encoded label features as input')
59 | self.parser.add_argument('--feat_num', type=int, default=3, help='vector length for encoded features')
60 | self.parser.add_argument('--load_features', action='store_true', help='if specified, load precomputed feature maps')
61 | self.parser.add_argument('--n_downsample_E', type=int, default=4, help='# of downsampling layers in encoder')
62 | self.parser.add_argument('--nef', type=int, default=16, help='# of encoder filters in the first conv layer')
63 | self.parser.add_argument('--n_clusters', type=int, default=10, help='number of clusters for features')
64 |         self.parser.add_argument('--image_size', type=int, default=224, help='input image size')
65 | self.parser.add_argument('--norm_G', type=str, default='spectralspadesyncbatch3x3', help='instance normalization or batch normalization')
66 |         self.parser.add_argument('--semantic_nc', type=int, default=3, help='number of semantic input channels')
67 | self.initialized = True
68 |
69 | def parse(self, save=True):
70 | if not self.initialized:
71 | self.initialize()
72 | self.opt = self.parser.parse_args()
73 | self.opt.isTrain = self.isTrain # train or test
74 |
75 | str_ids = self.opt.gpu_ids.split(',')
76 | self.opt.gpu_ids = []
77 | for str_id in str_ids:
78 | id = int(str_id)
79 | if id >= 0:
80 | self.opt.gpu_ids.append(id)
81 |
82 | # set gpu ids
83 | if len(self.opt.gpu_ids) > 0:
84 | torch.cuda.set_device(self.opt.gpu_ids[0])
85 |
86 | args = vars(self.opt)
87 |
88 | print('------------ Options -------------')
89 | for k, v in sorted(args.items()):
90 | print('%s: %s' % (str(k), str(v)))
91 | print('-------------- End ----------------')
92 |
93 | # save to the disk
94 | if self.opt.isTrain:
95 | expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
96 | util.mkdirs(expr_dir)
97 | if save and not self.opt.continue_train:
98 | file_name = os.path.join(expr_dir, 'opt.txt')
99 | with open(file_name, 'wt') as opt_file:
100 | opt_file.write('------------ Options -------------\n')
101 | for k, v in sorted(args.items()):
102 | opt_file.write('%s: %s\n' % (str(k), str(v)))
103 | opt_file.write('-------------- End ----------------\n')
104 | return self.opt
105 |
--------------------------------------------------------------------------------
/options/test_options.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: Naiyuan liu
3 | Github: https://github.com/NNNNAI
4 | Date: 2021-11-23 17:03:58
5 | LastEditors: Naiyuan liu
6 | LastEditTime: 2021-11-23 17:08:08
7 | Description:
8 | '''
9 | from .base_options import BaseOptions
10 |
11 | class TestOptions(BaseOptions):
12 | def initialize(self):
13 | BaseOptions.initialize(self)
14 | self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
15 | self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
16 | self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
17 | self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
18 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
19 | self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run')
20 | self.parser.add_argument('--cluster_path', type=str, default='features_clustered_010.npy', help='the path for clustered results of encoded features')
21 | self.parser.add_argument('--use_encoded_image', action='store_true', help='if specified, encode the real image to get the feature map')
22 | self.parser.add_argument("--export_onnx", type=str, help="export ONNX model to a given file")
23 | self.parser.add_argument("--engine", type=str, help="run serialized TRT engine")
24 | self.parser.add_argument("--onnx", type=str, help="run ONNX model via TRT")
25 |         self.parser.add_argument("--Arc_path", type=str, default='arcface_model/arcface_checkpoint.tar', help="path to the pretrained ArcFace checkpoint")
26 | self.parser.add_argument("--pic_a_path", type=str, default='G:/swap_data/ID/elon-musk-hero-image.jpeg', help="Person who provides identity information")
27 | self.parser.add_argument("--pic_b_path", type=str, default='./demo_file/multi_people.jpg', help="Person who provides information other than their identity")
28 | self.parser.add_argument("--pic_specific_path", type=str, default='./crop_224/zrf.jpg', help="The specific person to be swapped")
29 | self.parser.add_argument("--multisepcific_dir", type=str, default='./demo_file/multispecific', help="Dir for multi specific")
30 | self.parser.add_argument("--video_path", type=str, default='G:/swap_data/video/HSB_Demo_Trim.mp4', help="path for the video to swap")
31 | self.parser.add_argument("--temp_path", type=str, default='./temp_results', help="path to save temporarily images")
32 | self.parser.add_argument("--output_path", type=str, default='./output/', help="results path")
33 |         self.parser.add_argument('--id_thres', type=float, default=0.03, help='identity threshold for deciding whether a detected face matches the specific face')
34 | self.parser.add_argument('--no_simswaplogo', action='store_true', help='Remove the watermark')
35 | self.parser.add_argument('--use_mask', action='store_true', help='Use mask for better result')
36 |         self.parser.add_argument('--crop_size', type=int, default=512, help='crop size of the input face image')
37 |
38 | self.isTrain = False
--------------------------------------------------------------------------------
/options/train_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 | class TrainOptions(BaseOptions):
4 | def initialize(self):
5 | BaseOptions.initialize(self)
6 | # for displays
7 | self.parser.add_argument('--display_freq', type=int, default=99, help='frequency of showing training results on screen')
8 | self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
9 | self.parser.add_argument('--save_latest_freq', type=int, default=10000, help='frequency of saving the latest results')
10 | self.parser.add_argument('--save_epoch_freq', type=int, default=10000, help='frequency of saving checkpoints at the end of epochs')
11 | self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
12 | self.parser.add_argument('--debug', action='store_true', help='only do one epoch and displays at each iteration')
13 |
14 | # for training
15 | self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
16 | self.parser.add_argument('--load_pretrain', type=str, default='', help='load the pretrained model from the specified location')
17 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
18 | self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
19 | self.parser.add_argument('--niter', type=int, default=10000, help='# of iter at starting learning rate')
20 | self.parser.add_argument('--niter_decay', type=int, default=10000, help='# of iter to linearly decay learning rate to zero')
21 | self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
22 | self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
23 |
24 | # for discriminators
25 | self.parser.add_argument('--num_D', type=int, default=2, help='number of discriminators to use')
26 | self.parser.add_argument('--n_layers_D', type=int, default=4, help='only used if which_model_netD==n_layers')
27 | self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
28 | self.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')
29 | self.parser.add_argument('--lambda_id', type=float, default=20.0, help='weight for id loss')
30 | self.parser.add_argument('--lambda_rec', type=float, default=10.0, help='weight for reconstruction loss')
31 | self.parser.add_argument('--lambda_GP', type=float, default=10.0, help='weight for gradient penalty loss')
32 | self.parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss')
33 | self.parser.add_argument('--no_vgg_loss', action='store_true', help='if specified, do *not* use VGG feature matching loss')
34 | self.parser.add_argument('--gan_mode', type=str, default='hinge', help='(ls|original|hinge)')
35 | self.parser.add_argument('--pool_size', type=int, default=0, help='the size of image buffer that stores previously generated images')
36 | self.parser.add_argument('--times_G', type=int, default=1,
37 |                             help='number of times the generator is trained before training the discriminator')
38 |
39 | self.isTrain = True
40 |
--------------------------------------------------------------------------------
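Editor's note: a minimal sketch (not part of the repository) of how the option classes above are consumed. parse() reads sys.argv, prints every option and, when isTrain is set, writes them to checkpoints/<name>/opt.txt; the argv list below is only illustrative.

# Illustrative only.
import sys
from options.train_options import TrainOptions

sys.argv = ['train.py', '--name', 'simswap_demo', '--batchSize', '4', '--gpu_ids', '-1']
opt = TrainOptions().parse()
print(opt.name, opt.batchSize, opt.isTrain)   # simswap_demo 4 True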
/output/result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/output/result.jpg
--------------------------------------------------------------------------------
/parsing_model/resnet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torch.utils.model_zoo as modelzoo
8 |
9 | # from modules.bn import InPlaceABNSync as BatchNorm2d
10 |
11 | resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
12 |
13 |
14 | def conv3x3(in_planes, out_planes, stride=1):
15 | """3x3 convolution with padding"""
16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
17 | padding=1, bias=False)
18 |
19 |
20 | class BasicBlock(nn.Module):
21 | def __init__(self, in_chan, out_chan, stride=1):
22 | super(BasicBlock, self).__init__()
23 | self.conv1 = conv3x3(in_chan, out_chan, stride)
24 | self.bn1 = nn.BatchNorm2d(out_chan)
25 | self.conv2 = conv3x3(out_chan, out_chan)
26 | self.bn2 = nn.BatchNorm2d(out_chan)
27 | self.relu = nn.ReLU(inplace=True)
28 | self.downsample = None
29 | if in_chan != out_chan or stride != 1:
30 | self.downsample = nn.Sequential(
31 | nn.Conv2d(in_chan, out_chan,
32 | kernel_size=1, stride=stride, bias=False),
33 | nn.BatchNorm2d(out_chan),
34 | )
35 |
36 | def forward(self, x):
37 | residual = self.conv1(x)
38 | residual = F.relu(self.bn1(residual))
39 | residual = self.conv2(residual)
40 | residual = self.bn2(residual)
41 |
42 | shortcut = x
43 | if self.downsample is not None:
44 | shortcut = self.downsample(x)
45 |
46 | out = shortcut + residual
47 | out = self.relu(out)
48 | return out
49 |
50 |
51 | def create_layer_basic(in_chan, out_chan, bnum, stride=1):
52 | layers = [BasicBlock(in_chan, out_chan, stride=stride)]
53 | for i in range(bnum-1):
54 | layers.append(BasicBlock(out_chan, out_chan, stride=1))
55 | return nn.Sequential(*layers)
56 |
57 |
58 | class Resnet18(nn.Module):
59 | def __init__(self):
60 | super(Resnet18, self).__init__()
61 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
62 | bias=False)
63 | self.bn1 = nn.BatchNorm2d(64)
64 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
65 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
66 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
67 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
68 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
69 | self.init_weight()
70 |
71 | def forward(self, x):
72 | x = self.conv1(x)
73 | x = F.relu(self.bn1(x))
74 | x = self.maxpool(x)
75 |
76 | x = self.layer1(x)
77 | feat8 = self.layer2(x) # 1/8
78 | feat16 = self.layer3(feat8) # 1/16
79 | feat32 = self.layer4(feat16) # 1/32
80 | return feat8, feat16, feat32
81 |
82 | def init_weight(self):
83 | state_dict = modelzoo.load_url(resnet18_url)
84 | self_state_dict = self.state_dict()
85 | for k, v in state_dict.items():
86 | if 'fc' in k: continue
87 | self_state_dict.update({k: v})
88 | self.load_state_dict(self_state_dict)
89 |
90 | def get_params(self):
91 | wd_params, nowd_params = [], []
92 | for name, module in self.named_modules():
93 | if isinstance(module, (nn.Linear, nn.Conv2d)):
94 | wd_params.append(module.weight)
95 | if not module.bias is None:
96 | nowd_params.append(module.bias)
97 | elif isinstance(module, nn.BatchNorm2d):
98 | nowd_params += list(module.parameters())
99 | return wd_params, nowd_params
100 |
101 |
102 | if __name__ == "__main__":
103 | net = Resnet18()
104 | x = torch.randn(16, 3, 224, 224)
105 | out = net(x)
106 | print(out[0].size())
107 | print(out[1].size())
108 | print(out[2].size())
109 | net.get_params()
110 |
--------------------------------------------------------------------------------
/pg_modules/diffaug.py:
--------------------------------------------------------------------------------
1 | # Differentiable Augmentation for Data-Efficient GAN Training
2 | # Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
3 | # https://arxiv.org/pdf/2006.10738
4 |
5 | import torch
6 | import torch.nn.functional as F
7 |
8 |
9 | def DiffAugment(x, policy='', channels_first=True):
10 | if policy:
11 | if not channels_first:
12 | x = x.permute(0, 3, 1, 2)
13 | for p in policy.split(','):
14 | for f in AUGMENT_FNS[p]:
15 | x = f(x)
16 | if not channels_first:
17 | x = x.permute(0, 2, 3, 1)
18 | x = x.contiguous()
19 | return x
20 |
21 |
22 | def rand_brightness(x):
23 | x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
24 | return x
25 |
26 |
27 | def rand_saturation(x):
28 | x_mean = x.mean(dim=1, keepdim=True)
29 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
30 | return x
31 |
32 |
33 | def rand_contrast(x):
34 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
35 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
36 | return x
37 |
38 |
39 | def rand_translation(x, ratio=0.125):
40 | shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
41 | translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
42 | translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
43 | grid_batch, grid_x, grid_y = torch.meshgrid(
44 | torch.arange(x.size(0), dtype=torch.long, device=x.device),
45 | torch.arange(x.size(2), dtype=torch.long, device=x.device),
46 | torch.arange(x.size(3), dtype=torch.long, device=x.device),
47 | )
48 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
49 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
50 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
51 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
52 | return x
53 |
54 |
55 | def rand_cutout(x, ratio=0.2):
56 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
57 | offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
58 | offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
59 | grid_batch, grid_x, grid_y = torch.meshgrid(
60 | torch.arange(x.size(0), dtype=torch.long, device=x.device),
61 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
62 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
63 | )
64 | grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
65 | grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
66 | mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
67 | mask[grid_batch, grid_x, grid_y] = 0
68 | x = x * mask.unsqueeze(1)
69 | return x
70 |
71 |
72 | AUGMENT_FNS = {
73 | 'color': [rand_brightness, rand_saturation, rand_contrast],
74 | 'translation': [rand_translation],
75 | 'cutout': [rand_cutout],
76 | }
77 |
--------------------------------------------------------------------------------
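Editor's note: a minimal sketch (not part of the repository) of applying DiffAugment above to a batch before it reaches the discriminator. The policy string selects which AUGMENT_FNS groups run; every augmentation is differentiable and preserves the input shape.

# Illustrative only.
import torch
from pg_modules.diffaug import DiffAugment

fake = torch.rand(4, 3, 256, 256)   # NCHW batch in [0, 1]
aug = DiffAugment(fake, policy='color,translation,cutout')
assert aug.shape == fake.shape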
/pg_modules/projected_discriminator.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 |
6 | from pg_modules.blocks import DownBlock, DownBlockPatch, conv2d
7 | from pg_modules.projector import F_RandomProj
8 | from pg_modules.diffaug import DiffAugment
9 |
10 |
11 | class SingleDisc(nn.Module):
12 | def __init__(self, nc=None, ndf=None, start_sz=256, end_sz=8, head=None, separable=False, patch=False):
13 | super().__init__()
14 | channel_dict = {4: 512, 8: 512, 16: 256, 32: 128, 64: 64, 128: 64,
15 | 256: 32, 512: 16, 1024: 8}
16 |
17 | # interpolate for start sz that are not powers of two
18 | if start_sz not in channel_dict.keys():
19 | sizes = np.array(list(channel_dict.keys()))
20 | start_sz = sizes[np.argmin(abs(sizes - start_sz))]
21 | self.start_sz = start_sz
22 |
23 | # if given ndf, allocate all layers with the same ndf
24 | if ndf is None:
25 | nfc = channel_dict
26 | else:
27 | nfc = {k: ndf for k, v in channel_dict.items()}
28 |
29 | # for feature map discriminators with nfc not in channel_dict
30 | # this is the case for the pretrained backbone (midas.pretrained)
31 | if nc is not None and head is None:
32 | nfc[start_sz] = nc
33 |
34 | layers = []
35 |
36 | # Head if the initial input is the full modality
37 | if head:
38 | layers += [conv2d(nc, nfc[256], 3, 1, 1, bias=False),
39 | nn.LeakyReLU(0.2, inplace=True)]
40 |
41 | # Down Blocks
42 | DB = partial(DownBlockPatch, separable=separable) if patch else partial(DownBlock, separable=separable)
43 | while start_sz > end_sz:
44 | layers.append(DB(nfc[start_sz], nfc[start_sz//2]))
45 | start_sz = start_sz // 2
46 |
47 | layers.append(conv2d(nfc[end_sz], 1, 4, 1, 0, bias=False))
48 | self.main = nn.Sequential(*layers)
49 |
50 | def forward(self, x, c):
51 | return self.main(x)
52 |
53 |
54 | class SingleDiscCond(nn.Module):
55 | def __init__(self, nc=None, ndf=None, start_sz=256, end_sz=8, head=None, separable=False, patch=False, c_dim=1000, cmap_dim=64, embedding_dim=128):
56 | super().__init__()
57 | self.cmap_dim = cmap_dim
58 |
59 | # midas channels
60 | channel_dict = {4: 512, 8: 512, 16: 256, 32: 128, 64: 64, 128: 64,
61 | 256: 32, 512: 16, 1024: 8}
62 |
63 | # interpolate for start sz that are not powers of two
64 | if start_sz not in channel_dict.keys():
65 | sizes = np.array(list(channel_dict.keys()))
66 | start_sz = sizes[np.argmin(abs(sizes - start_sz))]
67 | self.start_sz = start_sz
68 |
69 | # if given ndf, allocate all layers with the same ndf
70 | if ndf is None:
71 | nfc = channel_dict
72 | else:
73 | nfc = {k: ndf for k, v in channel_dict.items()}
74 |
75 | # for feature map discriminators with nfc not in channel_dict
76 | # this is the case for the pretrained backbone (midas.pretrained)
77 | if nc is not None and head is None:
78 | nfc[start_sz] = nc
79 |
80 | layers = []
81 |
82 | # Head if the initial input is the full modality
83 | if head:
84 | layers += [conv2d(nc, nfc[256], 3, 1, 1, bias=False),
85 | nn.LeakyReLU(0.2, inplace=True)]
86 |
87 | # Down Blocks
88 | DB = partial(DownBlockPatch, separable=separable) if patch else partial(DownBlock, separable=separable)
89 | while start_sz > end_sz:
90 | layers.append(DB(nfc[start_sz], nfc[start_sz//2]))
91 | start_sz = start_sz // 2
92 | self.main = nn.Sequential(*layers)
93 |
94 | # additions for conditioning on class information
95 | self.cls = conv2d(nfc[end_sz], self.cmap_dim, 4, 1, 0, bias=False)
96 | self.embed = nn.Embedding(num_embeddings=c_dim, embedding_dim=embedding_dim)
97 | self.embed_proj = nn.Sequential(
98 | nn.Linear(self.embed.embedding_dim, self.cmap_dim),
99 | nn.LeakyReLU(0.2, inplace=True),
100 | )
101 |
102 | def forward(self, x, c):
103 | h = self.main(x)
104 | out = self.cls(h)
105 |
106 | # conditioning via projection
107 | cmap = self.embed_proj(self.embed(c.argmax(1))).unsqueeze(-1).unsqueeze(-1)
108 | out = (out * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
109 |
110 | return out
111 |
112 |
113 | class MultiScaleD(nn.Module):
114 | def __init__(
115 | self,
116 | channels,
117 | resolutions,
118 | num_discs=4,
119 | proj_type=2, # 0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing
120 | cond=0,
121 | separable=False,
122 | patch=False,
123 | **kwargs,
124 | ):
125 | super().__init__()
126 |
127 | assert num_discs in [1, 2, 3, 4]
128 |
129 | # the first disc is on the lowest level of the backbone
130 | self.disc_in_channels = channels[:num_discs]
131 | self.disc_in_res = resolutions[:num_discs]
132 | Disc = SingleDiscCond if cond else SingleDisc
133 |
134 | mini_discs = []
135 | for i, (cin, res) in enumerate(zip(self.disc_in_channels, self.disc_in_res)):
136 | start_sz = res if not patch else 16
137 |             mini_discs.append([str(i), Disc(nc=cin, start_sz=start_sz, end_sz=8, separable=separable, patch=patch)])
138 | self.mini_discs = nn.ModuleDict(mini_discs)
139 |
140 | def forward(self, features, c):
141 | all_logits = []
142 | for k, disc in self.mini_discs.items():
143 | res = disc(features[k], c).view(features[k].size(0), -1)
144 | all_logits.append(res)
145 |
146 | all_logits = torch.cat(all_logits, dim=1)
147 | return all_logits
148 |
149 |
150 | class ProjectedDiscriminator(torch.nn.Module):
151 | def __init__(
152 | self,
153 | diffaug=True,
154 | interp224=True,
155 | backbone_kwargs={},
156 | **kwargs
157 | ):
158 | super().__init__()
159 | self.diffaug = diffaug
160 | self.interp224 = interp224
161 | self.feature_network = F_RandomProj(**backbone_kwargs)
162 | self.discriminator = MultiScaleD(
163 | channels=self.feature_network.CHANNELS,
164 | resolutions=self.feature_network.RESOLUTIONS,
165 | **backbone_kwargs,
166 | )
167 |
168 | def train(self, mode=True):
169 | self.feature_network = self.feature_network.train(False)
170 | self.discriminator = self.discriminator.train(mode)
171 | return self
172 |
173 | def eval(self):
174 | return self.train(False)
175 |
176 | def get_feature(self, x):
177 | features = self.feature_network(x, get_features=True)
178 | return features
179 |
180 | def forward(self, x, c):
181 | # if self.diffaug:
182 | # x = DiffAugment(x, policy='color,translation,cutout')
183 |
184 | # if self.interp224:
185 | # x = F.interpolate(x, 224, mode='bilinear', align_corners=False)
186 |
187 | features,backbone_features = self.feature_network(x)
188 | logits = self.discriminator(features, c)
189 |
190 | return logits,backbone_features
191 |
192 |
--------------------------------------------------------------------------------
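Editor's note: a minimal sketch (not part of the repository) of scoring a batch with ProjectedDiscriminator above. F_RandomProj downloads the tf_efficientnet_lite0 backbone through timm on first use; the class label c is ignored by the unconditional SingleDisc heads, so None is passed here.

# Illustrative only.
import torch
from pg_modules.projected_discriminator import ProjectedDiscriminator

netD = ProjectedDiscriminator(diffaug=False, interp224=False).eval()
imgs = torch.rand(2, 3, 256, 256)
with torch.no_grad():
    logits, backbone_features = netD(imgs, c=None)
print(logits.shape)   # [2, total number of per-scale patch logits]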
/pg_modules/projector.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import timm
4 | from pg_modules.blocks import FeatureFusionBlock
5 |
6 |
7 | def _make_scratch_ccm(scratch, in_channels, cout, expand=False):
8 | # shapes
9 | out_channels = [cout, cout*2, cout*4, cout*8] if expand else [cout]*4
10 |
11 | scratch.layer0_ccm = nn.Conv2d(in_channels[0], out_channels[0], kernel_size=1, stride=1, padding=0, bias=True)
12 | scratch.layer1_ccm = nn.Conv2d(in_channels[1], out_channels[1], kernel_size=1, stride=1, padding=0, bias=True)
13 | scratch.layer2_ccm = nn.Conv2d(in_channels[2], out_channels[2], kernel_size=1, stride=1, padding=0, bias=True)
14 | scratch.layer3_ccm = nn.Conv2d(in_channels[3], out_channels[3], kernel_size=1, stride=1, padding=0, bias=True)
15 |
16 | scratch.CHANNELS = out_channels
17 |
18 | return scratch
19 |
20 |
21 | def _make_scratch_csm(scratch, in_channels, cout, expand):
22 | scratch.layer3_csm = FeatureFusionBlock(in_channels[3], nn.ReLU(False), expand=expand, lowest=True)
23 | scratch.layer2_csm = FeatureFusionBlock(in_channels[2], nn.ReLU(False), expand=expand)
24 | scratch.layer1_csm = FeatureFusionBlock(in_channels[1], nn.ReLU(False), expand=expand)
25 | scratch.layer0_csm = FeatureFusionBlock(in_channels[0], nn.ReLU(False))
26 |
27 | # last refinenet does not expand to save channels in higher dimensions
28 | scratch.CHANNELS = [cout, cout, cout*2, cout*4] if expand else [cout]*4
29 |
30 | return scratch
31 |
32 |
33 | def _make_efficientnet(model):
34 | pretrained = nn.Module()
35 | pretrained.layer0 = nn.Sequential(model.conv_stem, model.bn1, model.act1, *model.blocks[0:2])
36 | pretrained.layer1 = nn.Sequential(*model.blocks[2:3])
37 | pretrained.layer2 = nn.Sequential(*model.blocks[3:5])
38 | pretrained.layer3 = nn.Sequential(*model.blocks[5:9])
39 | return pretrained
40 |
41 |
42 | def calc_channels(pretrained, inp_res=224):
43 | channels = []
44 | tmp = torch.zeros(1, 3, inp_res, inp_res)
45 |
46 | # forward pass
47 | tmp = pretrained.layer0(tmp)
48 | channels.append(tmp.shape[1])
49 | tmp = pretrained.layer1(tmp)
50 | channels.append(tmp.shape[1])
51 | tmp = pretrained.layer2(tmp)
52 | channels.append(tmp.shape[1])
53 | tmp = pretrained.layer3(tmp)
54 | channels.append(tmp.shape[1])
55 |
56 | return channels
57 |
58 |
59 | def _make_projector(im_res, cout, proj_type, expand=False):
60 | assert proj_type in [0, 1, 2], "Invalid projection type"
61 |
62 | ### Build pretrained feature network
63 | model = timm.create_model('tf_efficientnet_lite0', pretrained=True)
64 | pretrained = _make_efficientnet(model)
65 |
66 | # Determine the resolution of the feature maps; this is later used to calculate the number
67 | # of down blocks in the discriminators. Interestingly, the best results are achieved
68 | # by fixing this to 256, i.e., we use the same number of down blocks per discriminator
69 | # independent of the dataset resolution.
70 | im_res = 256
71 | pretrained.RESOLUTIONS = [im_res//4, im_res//8, im_res//16, im_res//32]
72 | pretrained.CHANNELS = calc_channels(pretrained)
73 |
74 | if proj_type == 0: return pretrained, None
75 |
76 | ### Build CCM
77 | scratch = nn.Module()
78 | scratch = _make_scratch_ccm(scratch, in_channels=pretrained.CHANNELS, cout=cout, expand=expand)
79 | pretrained.CHANNELS = scratch.CHANNELS
80 |
81 | if proj_type == 1: return pretrained, scratch
82 |
83 | ### build CSM
84 | scratch = _make_scratch_csm(scratch, in_channels=scratch.CHANNELS, cout=cout, expand=expand)
85 |
86 | # CSM upsamples x2 so the feature map resolution doubles
87 | pretrained.RESOLUTIONS = [res*2 for res in pretrained.RESOLUTIONS]
88 | pretrained.CHANNELS = scratch.CHANNELS
89 |
90 | return pretrained, scratch
91 |
92 |
93 | class F_RandomProj(nn.Module):
94 | def __init__(
95 | self,
96 | im_res=256,
97 | cout=64,
98 | expand=True,
99 | proj_type=2, # 0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing
100 | **kwargs,
101 | ):
102 | super().__init__()
103 | self.proj_type = proj_type
104 | self.cout = cout
105 | self.expand = expand
106 |
107 | # build pretrained feature network and random decoder (scratch)
108 | self.pretrained, self.scratch = _make_projector(im_res=im_res, cout=self.cout, proj_type=self.proj_type, expand=self.expand)
109 | self.CHANNELS = self.pretrained.CHANNELS
110 | self.RESOLUTIONS = self.pretrained.RESOLUTIONS
111 |
112 | def forward(self, x, get_features=False):
113 | # predict feature maps
114 | out0 = self.pretrained.layer0(x)
115 | out1 = self.pretrained.layer1(out0)
116 | out2 = self.pretrained.layer2(out1)
117 | out3 = self.pretrained.layer3(out2)
118 |
119 | # start enumerating at the lowest layer (this is where we put the first discriminator)
120 | backbone_features = {
121 | '0': out0,
122 | '1': out1,
123 | '2': out2,
124 | '3': out3,
125 | }
126 | if get_features:
127 | return backbone_features
128 |
129 | if self.proj_type == 0: return backbone_features
130 |
131 | out0_channel_mixed = self.scratch.layer0_ccm(backbone_features['0'])
132 | out1_channel_mixed = self.scratch.layer1_ccm(backbone_features['1'])
133 | out2_channel_mixed = self.scratch.layer2_ccm(backbone_features['2'])
134 | out3_channel_mixed = self.scratch.layer3_ccm(backbone_features['3'])
135 |
136 | out = {
137 | '0': out0_channel_mixed,
138 | '1': out1_channel_mixed,
139 | '2': out2_channel_mixed,
140 | '3': out3_channel_mixed,
141 | }
142 |
143 | if self.proj_type == 1: return out
144 |
145 | # from bottom to top
146 | out3_scale_mixed = self.scratch.layer3_csm(out3_channel_mixed)
147 | out2_scale_mixed = self.scratch.layer2_csm(out3_scale_mixed, out2_channel_mixed)
148 | out1_scale_mixed = self.scratch.layer1_csm(out2_scale_mixed, out1_channel_mixed)
149 | out0_scale_mixed = self.scratch.layer0_csm(out1_scale_mixed, out0_channel_mixed)
150 |
151 | out = {
152 | '0': out0_scale_mixed,
153 | '1': out1_scale_mixed,
154 | '2': out2_scale_mixed,
155 | '3': out3_scale_mixed,
156 | }
157 |
158 | return out, backbone_features
159 |
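160 | 
161 | if __name__ == "__main__":
162 |     # Quick shape-check sketch, assuming the timm weights for tf_efficientnet_lite0 are
163 |     # available for download. With the default proj_type=2 the forward pass returns the
164 |     # cross-scale-mixed feature pyramid plus the raw backbone features.
165 |     proj = F_RandomProj()
166 |     x = torch.randn(1, 3, 224, 224)
167 |     out, backbone = proj(x)
168 |     for k in out:
169 |         print(k, tuple(out[k].shape), tuple(backbone[k].shape))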
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | import cog
2 | import tempfile
3 | from pathlib import Path
4 | import argparse
5 | import cv2
6 | import torch
7 | from PIL import Image
8 | import torch.nn.functional as F
9 | from torchvision import transforms
10 | from models.models import create_model
11 | from options.test_options import TestOptions
12 | from util.reverse2original import reverse2wholeimage
13 | from util.norm import SpecificNorm
14 | from test_wholeimage_swapmulti import _totensor
15 | from insightface_func.face_detect_crop_multi import Face_detect_crop as Face_detect_crop_multi
16 | from insightface_func.face_detect_crop_single import Face_detect_crop as Face_detect_crop_single
17 |
18 |
19 | class Predictor(cog.Predictor):
20 | def setup(self):
21 | self.transformer_Arcface = transforms.Compose([
22 | transforms.ToTensor(),
23 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
24 | ])
25 |
26 | @cog.input("source", type=Path, help="source image")
27 | @cog.input("target", type=Path, help="target image")
28 | @cog.input("mode", type=str, options=['single', 'all'], default='all',
29 | help="swap a single face (the one with highest confidence by face detection) or all faces in the target image")
30 | def predict(self, source, target, mode='all'):
31 |
32 | app = Face_detect_crop_multi(name='antelope', root='./insightface_func/models')
33 |
34 | if mode == 'single':
35 | app = Face_detect_crop_single(name='antelope', root='./insightface_func/models')
36 |
37 | app.prepare(ctx_id=0, det_thresh=0.6, det_size=(640, 640))
38 |
39 | options = TestOptions()
40 | options.initialize()
41 | opt = options.parser.parse_args(["--Arc_path", 'arcface_model/arcface_checkpoint.tar', "--pic_a_path", str(source),
42 | "--pic_b_path", str(target), "--isTrain", False, "--no_simswaplogo"])
43 |
44 | str_ids = opt.gpu_ids.split(',')
45 | opt.gpu_ids = []
46 | for str_id in str_ids:
47 | id = int(str_id)
48 | if id >= 0:
49 | opt.gpu_ids.append(id)
50 |
51 | # set gpu ids
52 | if len(opt.gpu_ids) > 0:
53 | torch.cuda.set_device(opt.gpu_ids[0])
54 |
55 | torch.nn.Module.dump_patches = True
56 | model = create_model(opt)
57 | model.eval()
58 |
59 | crop_size = opt.crop_size
60 | spNorm = SpecificNorm()
61 |
62 | with torch.no_grad():
63 | pic_a = opt.pic_a_path
64 | img_a_whole = cv2.imread(pic_a)
65 | img_a_align_crop, _ = app.get(img_a_whole, crop_size)
66 | img_a_align_crop_pil = Image.fromarray(cv2.cvtColor(img_a_align_crop[0], cv2.COLOR_BGR2RGB))
67 | img_a = self.transformer_Arcface(img_a_align_crop_pil)
68 | img_id = img_a.view(-1, img_a.shape[0], img_a.shape[1], img_a.shape[2])
69 |
70 | # convert numpy to tensor
71 | img_id = img_id.cuda()
72 |
73 | # create latent id
74 | img_id_downsample = F.interpolate(img_id, size=(112,112))
75 | latend_id = model.netArc(img_id_downsample)
76 | latend_id = F.normalize(latend_id, p=2, dim=1)
77 |
78 | ############## Forward Pass ######################
79 |
80 | pic_b = opt.pic_b_path
81 | img_b_whole = cv2.imread(pic_b)
82 | img_b_align_crop_list, b_mat_list = app.get(img_b_whole, crop_size)
83 |
84 | swap_result_list = []
85 | b_align_crop_tenor_list = []
86 |
87 | for b_align_crop in img_b_align_crop_list:
88 | b_align_crop_tenor = _totensor(cv2.cvtColor(b_align_crop, cv2.COLOR_BGR2RGB))[None, ...].cuda()
89 |
90 | swap_result = model(None, b_align_crop_tenor, latend_id, None, True)[0]
91 | swap_result_list.append(swap_result)
92 | b_align_crop_tenor_list.append(b_align_crop_tenor)
93 |
94 | net = None
95 |
96 | out_path = Path(tempfile.mkdtemp()) / "output.png"
97 |
98 | reverse2wholeimage(b_align_crop_tenor_list, swap_result_list, b_mat_list, crop_size, img_b_whole, None,
99 | str(out_path), opt.no_simswaplogo,
100 | pasring_model=net, use_mask=opt.use_mask, norm=spNorm)
101 | return out_path
102 |
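103 | # Rough usage sketch (assumptions, not verified against cog.yaml): with the standard Cog CLI,
104 | # this predictor would be driven from the repository root roughly as
105 | #   cog predict -i source=@my_source_face.jpg -i target=@my_target_photo.jpg -i mode=single
106 | # where both image paths are placeholders. As hard-coded above, the ArcFace checkpoint is
107 | # expected at arcface_model/arcface_checkpoint.tar and the antelope detector weights under
108 | # ./insightface_func/models.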
--------------------------------------------------------------------------------
/simswaplogo/simswaplogo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuralchen/SimSwap/bd7b7686a17f41dd11cfcd5d82f7e4c5eb94b780/simswaplogo/simswaplogo.png
--------------------------------------------------------------------------------
/simswaplogo/socialbook_logo.2020.357eed90add7705e54a8.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
65 |
--------------------------------------------------------------------------------
/test_one_image.py:
--------------------------------------------------------------------------------
1 |
2 | import cv2
3 | import torch
4 | import fractions
5 | import numpy as np
6 | from PIL import Image
7 | import torch.nn.functional as F
8 | from torchvision import transforms
9 | from models.models import create_model
10 | from options.test_options import TestOptions
11 |
12 |
13 | def lcm(a, b): return abs(a * b) / fractions.gcd(a, b) if a and b else 0
14 |
15 | transformer = transforms.Compose([
16 | transforms.ToTensor(),
17 | #transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
18 | ])
19 |
20 | transformer_Arcface = transforms.Compose([
21 | transforms.ToTensor(),
22 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
23 | ])
24 |
25 | detransformer = transforms.Compose([
26 | transforms.Normalize([0, 0, 0], [1/0.229, 1/0.224, 1/0.225]),
27 | transforms.Normalize([-0.485, -0.456, -0.406], [1, 1, 1])
28 | ])
29 | if __name__ == '__main__':
30 | opt = TestOptions().parse()
31 |
32 | start_epoch, epoch_iter = 1, 0
33 |
34 | torch.nn.Module.dump_patches = True
35 | model = create_model(opt)
36 | model.eval()
37 |
38 | with torch.no_grad():
39 |
40 | pic_a = opt.pic_a_path
41 | img_a = Image.open(pic_a).convert('RGB')
42 | img_a = transformer_Arcface(img_a)
43 | img_id = img_a.view(-1, img_a.shape[0], img_a.shape[1], img_a.shape[2])
44 |
45 | pic_b = opt.pic_b_path
46 |
47 | img_b = Image.open(pic_b).convert('RGB')
48 | img_b = transformer(img_b)
49 | img_att = img_b.view(-1, img_b.shape[0], img_b.shape[1], img_b.shape[2])
50 |
51 | # convert numpy to tensor
52 | img_id = img_id.cuda()
53 | img_att = img_att.cuda()
54 |
55 | #create latent id
56 | img_id_downsample = F.interpolate(img_id, size=(112,112))
57 | latend_id = model.netArc(img_id_downsample)
58 | latend_id = latend_id.detach().to('cpu')
59 | latend_id = latend_id/np.linalg.norm(latend_id,axis=1,keepdims=True)
60 | latend_id = latend_id.to('cuda')
61 |
62 |
63 | ############## Forward Pass ######################
64 | img_fake = model(img_id, img_att, latend_id, latend_id, True)
65 |
66 |
67 | for i in range(img_id.shape[0]):
68 | if i == 0:
69 | row1 = img_id[i]
70 | row2 = img_att[i]
71 | row3 = img_fake[i]
72 | else:
73 | row1 = torch.cat([row1, img_id[i]], dim=2)
74 | row2 = torch.cat([row2, img_att[i]], dim=2)
75 | row3 = torch.cat([row3, img_fake[i]], dim=2)
76 |
77 | #full = torch.cat([row1, row2, row3], dim=1).detach()
78 | full = row3.detach()
79 | full = full.permute(1, 2, 0)
80 | output = full.to('cpu')
81 | output = np.array(output)
82 | output = output[..., ::-1]
83 |
84 | output = output*255
85 |
86 | cv2.imwrite(opt.output_path + 'result.jpg', output)
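87 | 
88 | # Example invocation (a sketch; the image paths are placeholders and the flag names are
89 | # assumed from the opt attributes used above, so adjust them to your setup):
90 | #   python test_one_image.py --Arc_path arcface_model/arcface_checkpoint.tar \
91 | #       --pic_a_path ./source_face.jpg --pic_b_path ./target_face.jpg --output_path ./output/
92 | # pic_a provides the identity and pic_b the attributes/pose; unlike the whole-image scripts,
93 | # no face detector is run here, so both inputs should already be aligned face crops. The
94 | # swapped face is written to <output_path>result.jpg.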
--------------------------------------------------------------------------------
/test_video_swap_multispecific.py:
--------------------------------------------------------------------------------
1 |
2 | import cv2
3 | import torch
4 | import fractions
5 | from PIL import Image
6 | import torch.nn.functional as F
7 | from torchvision import transforms
8 | from models.models import create_model
9 | from options.test_options import TestOptions
10 | from insightface_func.face_detect_crop_multi import Face_detect_crop
11 | from util.videoswap_multispecific import video_swap
12 | import os
13 | import glob
14 |
15 | def lcm(a, b): return abs(a * b) / fractions.gcd(a, b) if a and b else 0
16 |
17 | transformer = transforms.Compose([
18 | transforms.ToTensor(),
19 | #transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
20 | ])
21 |
22 | transformer_Arcface = transforms.Compose([
23 | transforms.ToTensor(),
24 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
25 | ])
26 |
27 | # detransformer = transforms.Compose([
28 | # transforms.Normalize([0, 0, 0], [1/0.229, 1/0.224, 1/0.225]),
29 | # transforms.Normalize([-0.485, -0.456, -0.406], [1, 1, 1])
30 | # ])
31 |
32 |
33 | if __name__ == '__main__':
34 | opt = TestOptions().parse()
35 | pic_specific = opt.pic_specific_path
36 | start_epoch, epoch_iter = 1, 0
37 | crop_size = opt.crop_size
38 |
39 | multisepcific_dir = opt.multisepcific_dir
40 | torch.nn.Module.dump_patches = True
41 | if crop_size == 512:
42 | opt.which_epoch = 550000
43 | opt.name = '512'
44 | mode = 'ffhq'
45 | else:
46 | mode = 'None'
47 | model = create_model(opt)
48 | model.eval()
49 |
50 |
51 | app = Face_detect_crop(name='antelope', root='./insightface_func/models')
52 | app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640),mode=mode)
53 |
54 | # The specific person to be swapped (source)
55 |
56 | source_specific_id_nonorm_list = []
57 | source_path = os.path.join(multisepcific_dir,'SRC_*')
58 | source_specific_images_path = sorted(glob.glob(source_path))
59 | with torch.no_grad():
60 | for source_specific_image_path in source_specific_images_path:
61 | specific_person_whole = cv2.imread(source_specific_image_path)
62 | specific_person_align_crop, _ = app.get(specific_person_whole,crop_size)
63 | specific_person_align_crop_pil = Image.fromarray(cv2.cvtColor(specific_person_align_crop[0],cv2.COLOR_BGR2RGB))
64 | specific_person = transformer_Arcface(specific_person_align_crop_pil)
65 | specific_person = specific_person.view(-1, specific_person.shape[0], specific_person.shape[1], specific_person.shape[2])
66 | # convert numpy to tensor
67 | specific_person = specific_person.cuda()
68 | #create latent id
69 | specific_person_downsample = F.interpolate(specific_person, size=(112,112))
70 | specific_person_id_nonorm = model.netArc(specific_person_downsample)
71 | source_specific_id_nonorm_list.append(specific_person_id_nonorm.clone())
72 |
73 |
74 | # The person who provides id information (list)
75 | target_id_norm_list = []
76 | target_path = os.path.join(multisepcific_dir,'DST_*')
77 | target_images_path = sorted(glob.glob(target_path))
78 |
79 | for target_image_path in target_images_path:
80 | img_a_whole = cv2.imread(target_image_path)
81 | img_a_align_crop, _ = app.get(img_a_whole,crop_size)
82 | img_a_align_crop_pil = Image.fromarray(cv2.cvtColor(img_a_align_crop[0],cv2.COLOR_BGR2RGB))
83 | img_a = transformer_Arcface(img_a_align_crop_pil)
84 | img_id = img_a.view(-1, img_a.shape[0], img_a.shape[1], img_a.shape[2])
85 | # convert numpy to tensor
86 | img_id = img_id.cuda()
87 | #create latent id
88 | img_id_downsample = F.interpolate(img_id, size=(112,112))
89 | latend_id = model.netArc(img_id_downsample)
90 | latend_id = F.normalize(latend_id, p=2, dim=1)
91 | target_id_norm_list.append(latend_id.clone())
92 |
93 | assert len(target_id_norm_list) == len(source_specific_id_nonorm_list), "The number of images in the source and target directories must be the same!"
94 |
95 |
96 |
97 | video_swap(opt.video_path, target_id_norm_list,source_specific_id_nonorm_list, opt.id_thres, \
98 | model, app, opt.output_path,temp_results_dir=opt.temp_path,no_simswaplogo=opt.no_simswaplogo,use_mask=opt.use_mask,crop_size=crop_size)
99 |
100 |
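101 | # Input layout sketch: the code above pairs the sorted SRC_* and DST_* files inside the
102 | # multispecific directory, so faces in the video that match SRC_01 receive the identity from
103 | # DST_01, SRC_02 from DST_02, and so on; the two sets must contain the same number of images.
104 | # A hypothetical invocation (flag names assumed from the opt attributes used above, paths are
105 | # placeholders):
106 | #   python test_video_swap_multispecific.py --multisepcific_dir ./my_multispecific \
107 | #       --video_path ./input.mp4 --output_path ./output/swapped.mp4 --temp_path ./temp_results \
108 | #       --id_thres 0.03 --use_mask
109 | # Faces whose ArcFace embedding is not close enough to any SRC_* identity (per opt.id_thres)
110 | # are left untouched.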
--------------------------------------------------------------------------------
/test_video_swapmulti.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: Naiyuan liu
3 | Github: https://github.com/NNNNAI
4 | Date: 2021-11-23 17:03:58
5 | LastEditors: Naiyuan liu
6 | LastEditTime: 2021-11-24 19:00:34
7 | Description:
8 | '''
9 |
10 | import cv2
11 | import torch
12 | import fractions
13 | import numpy as np
14 | from PIL import Image
15 | import torch.nn.functional as F
16 | from torchvision import transforms
17 | from models.models import create_model
18 | from options.test_options import TestOptions
19 | from insightface_func.face_detect_crop_multi import Face_detect_crop
20 | from util.videoswap import video_swap
21 | import os
22 |
23 | def lcm(a, b): return abs(a * b) / fractions.gcd(a, b) if a and b else 0
24 |
25 | transformer = transforms.Compose([
26 | transforms.ToTensor(),
27 | #transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
28 | ])
29 |
30 | transformer_Arcface = transforms.Compose([
31 | transforms.ToTensor(),
32 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
33 | ])
34 |
35 | # detransformer = transforms.Compose([
36 | # transforms.Normalize([0, 0, 0], [1/0.229, 1/0.224, 1/0.225]),
37 | # transforms.Normalize([-0.485, -0.456, -0.406], [1, 1, 1])
38 | # ])
39 |
40 |
41 | if __name__ == '__main__':
42 | opt = TestOptions().parse()
43 |
44 | start_epoch, epoch_iter = 1, 0
45 | crop_size = opt.crop_size
46 |
47 | torch.nn.Module.dump_patches = True
48 |
49 | if crop_size == 512:
50 | opt.which_epoch = 550000
51 | opt.name = '512'
52 | mode = 'ffhq'
53 | else:
54 | mode = 'None'
55 | model = create_model(opt)
56 | model.eval()
57 |
58 | app = Face_detect_crop(name='antelope', root='./insightface_func/models')
59 | app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640),mode = mode)
60 |
61 | with torch.no_grad():
62 | pic_a = opt.pic_a_path
63 | # img_a = Image.open(pic_a).convert('RGB')
64 | img_a_whole = cv2.imread(pic_a)
65 | img_a_align_crop, _ = app.get(img_a_whole,crop_size)
66 | img_a_align_crop_pil = Image.fromarray(cv2.cvtColor(img_a_align_crop[0],cv2.COLOR_BGR2RGB))
67 | img_a = transformer_Arcface(img_a_align_crop_pil)
68 | img_id = img_a.view(-1, img_a.shape[0], img_a.shape[1], img_a.shape[2])
69 |
70 | # pic_b = opt.pic_b_path
71 | # img_b_whole = cv2.imread(pic_b)
72 | # img_b_align_crop, b_mat = app.get(img_b_whole,crop_size)
73 | # img_b_align_crop_pil = Image.fromarray(cv2.cvtColor(img_b_align_crop,cv2.COLOR_BGR2RGB))
74 | # img_b = transformer(img_b_align_crop_pil)
75 | # img_att = img_b.view(-1, img_b.shape[0], img_b.shape[1], img_b.shape[2])
76 |
77 | # convert numpy to tensor
78 | img_id = img_id.cuda()
79 | # img_att = img_att.cuda()
80 |
81 | #create latent id
82 | img_id_downsample = F.interpolate(img_id, size=(112,112))
83 | latend_id = model.netArc(img_id_downsample)
84 | latend_id = F.normalize(latend_id, p=2, dim=1)
85 |
86 | video_swap(opt.video_path, latend_id, model, app, opt.output_path,temp_results_dir=opt.temp_path,\
87 | no_simswaplogo=opt.no_simswaplogo,use_mask=opt.use_mask,crop_size=crop_size)
88 |
89 |
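90 | # Usage sketch: this variant swaps every face detected in every frame with the identity from
91 | # pic_a. A hypothetical invocation (flag names assumed from the opt attributes used above,
92 | # paths are placeholders):
93 | #   python test_video_swapmulti.py --pic_a_path ./source_face.jpg --video_path ./input.mp4 \
94 | #       --output_path ./output/swapped.mp4 --temp_path ./temp_results --no_simswaplogo
95 | # Use test_video_swapsingle.py instead to swap only the single face detected with the highest
96 | # confidence.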
--------------------------------------------------------------------------------
/test_video_swapsingle.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: Naiyuan liu
3 | Github: https://github.com/NNNNAI
4 | Date: 2021-11-23 17:03:58
5 | LastEditors: Naiyuan liu
6 | LastEditTime: 2021-11-24 19:00:38
7 | Description:
8 | '''
9 |
10 | import cv2
11 | import torch
12 | import fractions
13 | import numpy as np
14 | from PIL import Image
15 | import torch.nn.functional as F
16 | from torchvision import transforms
17 | from models.models import create_model
18 | from options.test_options import TestOptions
19 | from insightface_func.face_detect_crop_single import Face_detect_crop
20 | from util.videoswap import video_swap
21 | import os
22 |
23 | def lcm(a, b): return abs(a * b) / fractions.gcd(a, b) if a and b else 0
24 |
25 | transformer = transforms.Compose([
26 | transforms.ToTensor(),
27 | #transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
28 | ])
29 |
30 | transformer_Arcface = transforms.Compose([
31 | transforms.ToTensor(),
32 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
33 | ])
34 |
35 | # detransformer = transforms.Compose([
36 | # transforms.Normalize([0, 0, 0], [1/0.229, 1/0.224, 1/0.225]),
37 | # transforms.Normalize([-0.485, -0.456, -0.406], [1, 1, 1])
38 | # ])
39 |
40 |
41 | if __name__ == '__main__':
42 | opt = TestOptions().parse()
43 |
44 | start_epoch, epoch_iter = 1, 0
45 | crop_size = opt.crop_size
46 |
47 | torch.nn.Module.dump_patches = True
48 | if crop_size == 512:
49 | opt.which_epoch = 550000
50 | opt.name = '512'
51 | mode = 'ffhq'
52 | else:
53 | mode = 'None'
54 | model = create_model(opt)
55 | model.eval()
56 |
57 |
58 | app = Face_detect_crop(name='antelope', root='./insightface_func/models')
59 | app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640),mode=mode)
60 | with torch.no_grad():
61 | pic_a = opt.pic_a_path
62 | # img_a = Image.open(pic_a).convert('RGB')
63 | img_a_whole = cv2.imread(pic_a)
64 | img_a_align_crop, _ = app.get(img_a_whole,crop_size)
65 | img_a_align_crop_pil = Image.fromarray(cv2.cvtColor(img_a_align_crop[0],cv2.COLOR_BGR2RGB))
66 | img_a = transformer_Arcface(img_a_align_crop_pil)
67 | img_id = img_a.view(-1, img_a.shape[0], img_a.shape[1], img_a.shape[2])
68 |
69 | # pic_b = opt.pic_b_path
70 | # img_b_whole = cv2.imread(pic_b)
71 | # img_b_align_crop, b_mat = app.get(img_b_whole,crop_size)
72 | # img_b_align_crop_pil = Image.fromarray(cv2.cvtColor(img_b_align_crop,cv2.COLOR_BGR2RGB))
73 | # img_b = transformer(img_b_align_crop_pil)
74 | # img_att = img_b.view(-1, img_b.shape[0], img_b.shape[1], img_b.shape[2])
75 |
76 | # convert numpy to tensor
77 | img_id = img_id.cuda()
78 | # img_att = img_att.cuda()
79 |
80 | #create latent id
81 | img_id_downsample = F.interpolate(img_id, size=(112,112))
82 | latend_id = model.netArc(img_id_downsample)
83 | latend_id = F.normalize(latend_id, p=2, dim=1)
84 |
85 | video_swap(opt.video_path, latend_id, model, app, opt.output_path,temp_results_dir=opt.temp_path,\
86 | no_simswaplogo=opt.no_simswaplogo,use_mask=opt.use_mask,crop_size=crop_size)
87 |
88 |
--------------------------------------------------------------------------------
/test_video_swapspecific.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: Naiyuan liu
3 | Github: https://github.com/NNNNAI
4 | Date: 2021-11-23 17:03:58
5 | LastEditors: Naiyuan liu
6 | LastEditTime: 2021-11-24 19:00:42
7 | Description:
8 | '''
9 |
10 | import cv2
11 | import torch
12 | import fractions
13 | import numpy as np
14 | from PIL import Image
15 | import torch.nn.functional as F
16 | from torchvision import transforms
17 | from models.models import create_model
18 | from options.test_options import TestOptions
19 | from insightface_func.face_detect_crop_multi import Face_detect_crop
20 | from util.videoswap_specific import video_swap
21 | import os
22 |
23 | def lcm(a, b): return abs(a * b) / fractions.gcd(a, b) if a and b else 0
24 |
25 | transformer = transforms.Compose([
26 | transforms.ToTensor(),
27 | #transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
28 | ])
29 |
30 | transformer_Arcface = transforms.Compose([
31 | transforms.ToTensor(),
32 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
33 | ])
34 |
35 | # detransformer = transforms.Compose([
36 | # transforms.Normalize([0, 0, 0], [1/0.229, 1/0.224, 1/0.225]),
37 | # transforms.Normalize([-0.485, -0.456, -0.406], [1, 1, 1])
38 | # ])
39 |
40 |
41 | if __name__ == '__main__':
42 | opt = TestOptions().parse()
43 | pic_specific = opt.pic_specific_path
44 | start_epoch, epoch_iter = 1, 0
45 | crop_size = opt.crop_size
46 |
47 | torch.nn.Module.dump_patches = True
48 | if crop_size == 512:
49 | opt.which_epoch = 550000
50 | opt.name = '512'
51 | mode = 'ffhq'
52 | else:
53 | mode = 'None'
54 | model = create_model(opt)
55 | model.eval()
56 |
57 |
58 | app = Face_detect_crop(name='antelope', root='./insightface_func/models')
59 | app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640),mode=mode)
60 | with torch.no_grad():
61 | pic_a = opt.pic_a_path
62 | # img_a = Image.open(pic_a).convert('RGB')
63 | img_a_whole = cv2.imread(pic_a)
64 | img_a_align_crop, _ = app.get(img_a_whole,crop_size)
65 | img_a_align_crop_pil = Image.fromarray(cv2.cvtColor(img_a_align_crop[0],cv2.COLOR_BGR2RGB))
66 | img_a = transformer_Arcface(img_a_align_crop_pil)
67 | img_id = img_a.view(-1, img_a.shape[0], img_a.shape[1], img_a.shape[2])
68 |
69 | # pic_b = opt.pic_b_path
70 | # img_b_whole = cv2.imread(pic_b)
71 | # img_b_align_crop, b_mat = app.get(img_b_whole,crop_size)
72 | # img_b_align_crop_pil = Image.fromarray(cv2.cvtColor(img_b_align_crop,cv2.COLOR_BGR2RGB))
73 | # img_b = transformer(img_b_align_crop_pil)
74 | # img_att = img_b.view(-1, img_b.shape[0], img_b.shape[1], img_b.shape[2])
75 |
76 | # convert numpy to tensor
77 | img_id = img_id.cuda()
78 | # img_att = img_att.cuda()
79 |
80 | #create latent id
81 | img_id_downsample = F.interpolate(img_id, size=(112,112))
82 | latend_id = model.netArc(img_id_downsample)
83 | latend_id = F.normalize(latend_id, p=2, dim=1)
84 |
85 |
86 | # The specific person to be swapped
87 | specific_person_whole = cv2.imread(pic_specific)
88 | specific_person_align_crop, _ = app.get(specific_person_whole,crop_size)
89 | specific_person_align_crop_pil = Image.fromarray(cv2.cvtColor(specific_person_align_crop[0],cv2.COLOR_BGR2RGB))
90 | specific_person = transformer_Arcface(specific_person_align_crop_pil)
91 | specific_person = specific_person.view(-1, specific_person.shape[0], specific_person.shape[1], specific_person.shape[2])
92 | specific_person = specific_person.cuda()
93 | specific_person_downsample = F.interpolate(specific_person, size=(112,112))
94 | specific_person_id_nonorm = model.netArc(specific_person_downsample)
95 |
96 | video_swap(opt.video_path, latend_id,specific_person_id_nonorm, opt.id_thres, \
97 | model, app, opt.output_path,temp_results_dir=opt.temp_path,no_simswaplogo=opt.no_simswaplogo,use_mask=opt.use_mask,crop_size=crop_size)
98 |
99 |
--------------------------------------------------------------------------------
/test_wholeimage_swap_multispecific.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: Naiyuan liu
3 | Github: https://github.com/NNNNAI
4 | Date: 2021-11-23 17:03:58
5 | LastEditors: Naiyuan liu
6 | LastEditTime: 2021-11-24 19:19:22
7 | Description:
8 | '''
9 |
10 | import cv2
11 | import torch
12 | import fractions
13 | import numpy as np
14 | from PIL import Image
15 | import torch.nn.functional as F
16 | from torchvision import transforms
17 | from models.models import create_model
18 | from options.test_options import TestOptions
19 | from insightface_func.face_detect_crop_multi import Face_detect_crop
20 | from util.reverse2original import reverse2wholeimage
21 | import os
22 | from util.add_watermark import watermark_image
23 | import torch.nn as nn
24 | from util.norm import SpecificNorm
25 | import glob
26 | from parsing_model.model import BiSeNet
27 |
28 | def lcm(a, b): return abs(a * b) / fractions.gcd(a, b) if a and b else 0
29 |
30 | transformer_Arcface = transforms.Compose([
31 | transforms.ToTensor(),
32 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
33 | ])
34 |
35 | def _totensor(array):
36 | tensor = torch.from_numpy(array)
37 | img = tensor.transpose(0, 1).transpose(0, 2).contiguous()
38 | return img.float().div(255)
39 |
40 | def _toarctensor(array):
41 | tensor = torch.from_numpy(array)
42 | img = tensor.transpose(0, 1).transpose(0, 2).contiguous()
43 | return img.float().div(255)
44 |
45 | if __name__ == '__main__':
46 | opt = TestOptions().parse()
47 |
48 | start_epoch, epoch_iter = 1, 0
49 | crop_size = opt.crop_size
50 |
51 | multisepcific_dir = opt.multisepcific_dir
52 |
53 | torch.nn.Module.dump_patches = True
54 |
55 | if crop_size == 512:
56 | opt.which_epoch = 550000
57 | opt.name = '512'
58 | mode = 'ffhq'
59 | else:
60 | mode = 'None'
61 |
62 | logoclass = watermark_image('./simswaplogo/simswaplogo.png')
63 | model = create_model(opt)
64 | model.eval()
65 | mse = torch.nn.MSELoss().cuda()
66 |
67 | spNorm =SpecificNorm()
68 |
69 |
70 | app = Face_detect_crop(name='antelope', root='./insightface_func/models')
71 | app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640),mode = mode)
72 |
73 | with torch.no_grad():
74 | # The specific person to be swapped (source)
75 |
76 | source_specific_id_nonorm_list = []
77 | source_path = os.path.join(multisepcific_dir,'SRC_*')
78 | source_specific_images_path = sorted(glob.glob(source_path))
79 |
80 | for source_specific_image_path in source_specific_images_path:
81 | specific_person_whole = cv2.imread(source_specific_image_path)
82 | specific_person_align_crop, _ = app.get(specific_person_whole,crop_size)
83 | specific_person_align_crop_pil = Image.fromarray(cv2.cvtColor(specific_person_align_crop[0],cv2.COLOR_BGR2RGB))
84 | specific_person = transformer_Arcface(specific_person_align_crop_pil)
85 | specific_person = specific_person.view(-1, specific_person.shape[0], specific_person.shape[1], specific_person.shape[2])
86 | # convert numpy to tensor
87 | specific_person = specific_person.cuda()
88 | #create latent id
89 | specific_person_downsample = F.interpolate(specific_person, size=(112,112))
90 | specific_person_id_nonorm = model.netArc(specific_person_downsample)
91 | source_specific_id_nonorm_list.append(specific_person_id_nonorm.clone())
92 |
93 |
94 | # The person who provides id information (list)
95 | target_id_norm_list = []
96 | target_path = os.path.join(multisepcific_dir,'DST_*')
97 | target_images_path = sorted(glob.glob(target_path))
98 |
99 | for target_image_path in target_images_path:
100 | img_a_whole = cv2.imread(target_image_path)
101 | img_a_align_crop, _ = app.get(img_a_whole,crop_size)
102 | img_a_align_crop_pil = Image.fromarray(cv2.cvtColor(img_a_align_crop[0],cv2.COLOR_BGR2RGB))
103 | img_a = transformer_Arcface(img_a_align_crop_pil)
104 | img_id = img_a.view(-1, img_a.shape[0], img_a.shape[1], img_a.shape[2])
105 | # convert numpy to tensor
106 | img_id = img_id.cuda()
107 | #create latent id
108 | img_id_downsample = F.interpolate(img_id, size=(112,112))
109 | latend_id = model.netArc(img_id_downsample)
110 | latend_id = F.normalize(latend_id, p=2, dim=1)
111 | target_id_norm_list.append(latend_id.clone())
112 |
113 | assert len(target_id_norm_list) == len(source_specific_id_nonorm_list), "The number of images in the source and target directories must be the same!"
114 |
115 | ############## Forward Pass ######################
116 |
117 | pic_b = opt.pic_b_path
118 | img_b_whole = cv2.imread(pic_b)
119 |
120 | img_b_align_crop_list, b_mat_list = app.get(img_b_whole,crop_size)
121 | # detect_results = None
122 | swap_result_list = []
123 |
124 | id_compare_values = []
125 | b_align_crop_tenor_list = []
126 | for b_align_crop in img_b_align_crop_list:
127 |
128 | b_align_crop_tenor = _totensor(cv2.cvtColor(b_align_crop,cv2.COLOR_BGR2RGB))[None,...].cuda()
129 |
130 | b_align_crop_tenor_arcnorm = spNorm(b_align_crop_tenor)
131 | b_align_crop_tenor_arcnorm_downsample = F.interpolate(b_align_crop_tenor_arcnorm, size=(112,112))
132 | b_align_crop_id_nonorm = model.netArc(b_align_crop_tenor_arcnorm_downsample)
133 |
134 | id_compare_values.append([])
135 | for source_specific_id_nonorm_tmp in source_specific_id_nonorm_list:
136 | id_compare_values[-1].append(mse(b_align_crop_id_nonorm,source_specific_id_nonorm_tmp).detach().cpu().numpy())
137 | b_align_crop_tenor_list.append(b_align_crop_tenor)
138 |
139 | id_compare_values_array = np.array(id_compare_values).transpose(1,0)
140 | min_indexs = np.argmin(id_compare_values_array,axis=0)
141 | min_value = np.min(id_compare_values_array,axis=0)
142 |
143 | swap_result_list = []
144 | swap_result_matrix_list = []
145 | swap_result_ori_pic_list = []
146 |
147 | for tmp_index, min_index in enumerate(min_indexs):
148 | if min_value[tmp_index] < opt.id_thres:
149 | swap_result = model(None, b_align_crop_tenor_list[tmp_index], target_id_norm_list[min_index], None, True)[0]
150 | swap_result_list.append(swap_result)
151 | swap_result_matrix_list.append(b_mat_list[tmp_index])
152 | swap_result_ori_pic_list.append(b_align_crop_tenor_list[tmp_index])
153 | else:
154 | pass
155 |
156 | if len(swap_result_list) !=0:
157 |
158 | if opt.use_mask:
159 | n_classes = 19
160 | net = BiSeNet(n_classes=n_classes)
161 | net.cuda()
162 | save_pth = os.path.join('./parsing_model/checkpoint', '79999_iter.pth')
163 | net.load_state_dict(torch.load(save_pth))
164 | net.eval()
165 | else:
166 | net =None
167 |
168 | reverse2wholeimage(swap_result_ori_pic_list, swap_result_list, swap_result_matrix_list, crop_size, img_b_whole, logoclass,\
169 | os.path.join(opt.output_path, 'result_whole_swap_multispecific.jpg'), opt.no_simswaplogo,pasring_model =net,use_mask=opt.use_mask, norm = spNorm)
170 |
171 | print(' ')
172 |
173 | print('************ Done ! ************')
174 |
175 | else:
176 | print('The people you specified are not found in the picture: {}'.format(pic_b))
177 |
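178 | # A small worked example of the identity-matching step above (illustrative numbers only):
179 | # with 2 source identities (SRC_*) and 3 faces detected in pic_b, id_compare_values_array has
180 | # shape (num_sources, num_faces) = (2, 3) after the transpose, e.g.
181 | #     [[0.9, 0.2, 1.1],
182 | #      [0.3, 1.4, 1.2]]
183 | # so min_indexs = argmin over axis 0 = [1, 0, 0] and min_value = [0.3, 0.2, 1.1]. Every detected
184 | # face whose best MSE is below opt.id_thres is swapped with the DST_* identity at the matching
185 | # index; the remaining faces are left untouched.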
--------------------------------------------------------------------------------
/test_wholeimage_swapmulti.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: Naiyuan liu
3 | Github: https://github.com/NNNNAI
4 | Date: 2021-11-23 17:03:58
5 | LastEditors: Naiyuan liu
6 | LastEditTime: 2021-11-24 19:19:26
7 | Description:
8 | '''
9 |
10 | import cv2
11 | import torch
12 | import fractions
13 | import numpy as np
14 | from PIL import Image
15 | import torch.nn.functional as F
16 | from torchvision import transforms
17 | from models.models import create_model
18 | from options.test_options import TestOptions
19 | from insightface_func.face_detect_crop_multi import Face_detect_crop
20 | from util.reverse2original import reverse2wholeimage
21 | import os
22 | from util.add_watermark import watermark_image
23 | from util.norm import SpecificNorm
24 | from parsing_model.model import BiSeNet
25 |
26 | def lcm(a, b): return abs(a * b) / fractions.gcd(a, b) if a and b else 0
27 |
28 | transformer_Arcface = transforms.Compose([
29 | transforms.ToTensor(),
30 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
31 | ])
32 |
33 | def _totensor(array):
34 | tensor = torch.from_numpy(array)
35 | img = tensor.transpose(0, 1).transpose(0, 2).contiguous()
36 | return img.float().div(255)
37 |
38 | if __name__ == '__main__':
39 | opt = TestOptions().parse()
40 |
41 | start_epoch, epoch_iter = 1, 0
42 | crop_size = opt.crop_size
43 |
44 | torch.nn.Module.dump_patches = True
45 | if crop_size == 512:
46 | opt.which_epoch = 550000
47 | opt.name = '512'
48 | mode = 'ffhq'
49 | else:
50 | mode = 'None'
51 | logoclass = watermark_image('./simswaplogo/simswaplogo.png')
52 | model = create_model(opt)
53 | model.eval()
54 | spNorm =SpecificNorm()
55 |
56 | app = Face_detect_crop(name='antelope', root='./insightface_func/models')
57 | app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640),mode=mode)
58 |
59 | with torch.no_grad():
60 | pic_a = opt.pic_a_path
61 |
62 | img_a_whole = cv2.imread(pic_a)
63 | img_a_align_crop, _ = app.get(img_a_whole,crop_size)
64 | img_a_align_crop_pil = Image.fromarray(cv2.cvtColor(img_a_align_crop[0],cv2.COLOR_BGR2RGB))
65 | img_a = transformer_Arcface(img_a_align_crop_pil)
66 | img_id = img_a.view(-1, img_a.shape[0], img_a.shape[1], img_a.shape[2])
67 |
68 | # convert numpy to tensor
69 | img_id = img_id.cuda()
70 |
71 | #create latent id
72 | img_id_downsample = F.interpolate(img_id, size=(112,112))
73 | latend_id = model.netArc(img_id_downsample)
74 | latend_id = F.normalize(latend_id, p=2, dim=1)
75 |
76 |
77 | ############## Forward Pass ######################
78 |
79 | pic_b = opt.pic_b_path
80 | img_b_whole = cv2.imread(pic_b)
81 |
82 | img_b_align_crop_list, b_mat_list = app.get(img_b_whole,crop_size)
83 | # detect_results = None
84 | swap_result_list = []
85 | b_align_crop_tenor_list = []
86 |
87 | for b_align_crop in img_b_align_crop_list:
88 |
89 | b_align_crop_tenor = _totensor(cv2.cvtColor(b_align_crop,cv2.COLOR_BGR2RGB))[None,...].cuda()
90 |
91 | swap_result = model(None, b_align_crop_tenor, latend_id, None, True)[0]
92 | swap_result_list.append(swap_result)
93 | b_align_crop_tenor_list.append(b_align_crop_tenor)
94 |
95 |
96 | if opt.use_mask:
97 | n_classes = 19
98 | net = BiSeNet(n_classes=n_classes)
99 | net.cuda()
100 | save_pth = os.path.join('./parsing_model/checkpoint', '79999_iter.pth')
101 | net.load_state_dict(torch.load(save_pth))
102 | net.eval()
103 | else:
104 | net =None
105 |
106 | reverse2wholeimage(b_align_crop_tenor_list,swap_result_list, b_mat_list, crop_size, img_b_whole, logoclass, \
107 | os.path.join(opt.output_path, 'result_whole_swapmulti.jpg'),opt.no_simswaplogo,pasring_model =net,use_mask=opt.use_mask, norm = spNorm)
108 | print(' ')
109 |
110 | print('************ Done ! ************')
111 |
--------------------------------------------------------------------------------
/test_wholeimage_swapsingle.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: Naiyuan liu
3 | Github: https://github.com/NNNNAI
4 | Date: 2021-11-23 17:03:58
5 | LastEditors: Naiyuan liu
6 | LastEditTime: 2021-11-24 19:19:43
7 | Description:
8 | '''
9 |
10 | import cv2
11 | import torch
12 | import fractions
13 | import numpy as np
14 | from PIL import Image
15 | import torch.nn.functional as F
16 | from torchvision import transforms
17 | from models.models import create_model
18 | from options.test_options import TestOptions
19 | from insightface_func.face_detect_crop_single import Face_detect_crop
20 | from util.reverse2original import reverse2wholeimage
21 | import os
22 | from util.add_watermark import watermark_image
23 | from util.norm import SpecificNorm
24 | from parsing_model.model import BiSeNet
25 |
26 | def lcm(a, b): return abs(a * b) / fractions.gcd(a, b) if a and b else 0
27 |
28 | transformer_Arcface = transforms.Compose([
29 | transforms.ToTensor(),
30 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
31 | ])
32 |
33 | def _totensor(array):
34 | tensor = torch.from_numpy(array)
35 | img = tensor.transpose(0, 1).transpose(0, 2).contiguous()
36 | return img.float().div(255)
37 | if __name__ == '__main__':
38 | opt = TestOptions().parse()
39 |
40 | start_epoch, epoch_iter = 1, 0
41 | crop_size = opt.crop_size
42 |
43 | torch.nn.Module.dump_patches = True
44 | if crop_size == 512:
45 | opt.which_epoch = 550000
46 | opt.name = '512'
47 | mode = 'ffhq'
48 | else:
49 | mode = 'None'
50 | logoclass = watermark_image('./simswaplogo/simswaplogo.png')
51 | model = create_model(opt)
52 | model.eval()
53 |
54 | spNorm =SpecificNorm()
55 | app = Face_detect_crop(name='antelope', root='./insightface_func/models')
56 | app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640),mode=mode)
57 |
58 | with torch.no_grad():
59 | pic_a = opt.pic_a_path
60 |
61 | img_a_whole = cv2.imread(pic_a)
62 | img_a_align_crop, _ = app.get(img_a_whole,crop_size)
63 | img_a_align_crop_pil = Image.fromarray(cv2.cvtColor(img_a_align_crop[0],cv2.COLOR_BGR2RGB))
64 | img_a = transformer_Arcface(img_a_align_crop_pil)
65 | img_id = img_a.view(-1, img_a.shape[0], img_a.shape[1], img_a.shape[2])
66 |
67 | # convert numpy to tensor
68 | img_id = img_id.cuda()
69 |
70 | #create latent id
71 | img_id_downsample = F.interpolate(img_id, size=(112,112))
72 | latend_id = model.netArc(img_id_downsample)
73 | latend_id = F.normalize(latend_id, p=2, dim=1)
74 |
75 |
76 | ############## Forward Pass ######################
77 |
78 | pic_b = opt.pic_b_path
79 | img_b_whole = cv2.imread(pic_b)
80 |
81 | img_b_align_crop_list, b_mat_list = app.get(img_b_whole,crop_size)
82 | # detect_results = None
83 | swap_result_list = []
84 |
85 | b_align_crop_tenor_list = []
86 |
87 | for b_align_crop in img_b_align_crop_list:
88 |
89 | b_align_crop_tenor = _totensor(cv2.cvtColor(b_align_crop,cv2.COLOR_BGR2RGB))[None,...].cuda()
90 |
91 | swap_result = model(None, b_align_crop_tenor, latend_id, None, True)[0]
92 | swap_result_list.append(swap_result)
93 | b_align_crop_tenor_list.append(b_align_crop_tenor)
94 |
95 | if opt.use_mask:
96 | n_classes = 19
97 | net = BiSeNet(n_classes=n_classes)
98 | net.cuda()
99 | save_pth = os.path.join('./parsing_model/checkpoint', '79999_iter.pth')
100 | net.load_state_dict(torch.load(save_pth))
101 | net.eval()
102 | else:
103 | net =None
104 |
105 | reverse2wholeimage(b_align_crop_tenor_list, swap_result_list, b_mat_list, crop_size, img_b_whole, logoclass, \
106 | os.path.join(opt.output_path, 'result_whole_swapsingle.jpg'), opt.no_simswaplogo,pasring_model =net,use_mask=opt.use_mask, norm = spNorm)
107 |
108 | print(' ')
109 |
110 | print('************ Done ! ************')
111 |
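112 | # Example invocation (a sketch with placeholder paths; flag names are assumed from the opt
113 | # attributes used above):
114 | #   python test_wholeimage_swapsingle.py --Arc_path arcface_model/arcface_checkpoint.tar \
115 | #       --pic_a_path ./source_face.jpg --pic_b_path ./group_photo.jpg \
116 | #       --output_path ./output/ --use_mask
117 | # The identity from pic_a is injected into the most confidently detected face in pic_b, and the
118 | # composited result is saved as <output_path>/result_whole_swapsingle.jpg; --use_mask enables
119 | # the BiSeNet face-parsing mask for blending.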
--------------------------------------------------------------------------------
/test_wholeimage_swapspecific.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: Naiyuan liu
3 | Github: https://github.com/NNNNAI
4 | Date: 2021-11-23 17:03:58
5 | LastEditors: Naiyuan liu
6 | LastEditTime: 2021-11-24 19:19:47
7 | Description:
8 | '''
9 |
10 | import cv2
11 | import torch
12 | import fractions
13 | import numpy as np
14 | from PIL import Image
15 | import torch.nn.functional as F
16 | from torchvision import transforms
17 | from models.models import create_model
18 | from options.test_options import TestOptions
19 | from insightface_func.face_detect_crop_multi import Face_detect_crop
20 | from util.reverse2original import reverse2wholeimage
21 | import os
22 | from util.add_watermark import watermark_image
23 | import torch.nn as nn
24 | from util.norm import SpecificNorm
25 | from parsing_model.model import BiSeNet
26 |
27 | def lcm(a, b): return abs(a * b) / fractions.gcd(a, b) if a and b else 0
28 |
29 | transformer_Arcface = transforms.Compose([
30 | transforms.ToTensor(),
31 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
32 | ])
33 |
34 | def _totensor(array):
35 | tensor = torch.from_numpy(array)
36 | img = tensor.transpose(0, 1).transpose(0, 2).contiguous()
37 | return img.float().div(255)
38 |
39 | def _toarctensor(array):
40 | tensor = torch.from_numpy(array)
41 | img = tensor.transpose(0, 1).transpose(0, 2).contiguous()
42 | return img.float().div(255)
43 |
44 | if __name__ == '__main__':
45 | opt = TestOptions().parse()
46 |
47 | start_epoch, epoch_iter = 1, 0
48 | crop_size = opt.crop_size
49 |
50 | torch.nn.Module.dump_patches = True
51 | if crop_size == 512:
52 | opt.which_epoch = 550000
53 | opt.name = '512'
54 | mode = 'ffhq'
55 | else:
56 | mode = 'None'
57 | logoclass = watermark_image('./simswaplogo/simswaplogo.png')
58 | model = create_model(opt)
59 | model.eval()
60 | mse = torch.nn.MSELoss().cuda()
61 |
62 | spNorm =SpecificNorm()
63 |
64 |
65 | app = Face_detect_crop(name='antelope', root='./insightface_func/models')
66 | app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640),mode=mode)
67 |
68 | pic_a = opt.pic_a_path
69 | pic_specific = opt.pic_specific_path
70 |
71 | # The person who provides id information
72 | img_a_whole = cv2.imread(pic_a)
73 | img_a_align_crop, _ = app.get(img_a_whole,crop_size)
74 | img_a_align_crop_pil = Image.fromarray(cv2.cvtColor(img_a_align_crop[0],cv2.COLOR_BGR2RGB))
75 | img_a = transformer_Arcface(img_a_align_crop_pil)
76 | img_id = img_a.view(-1, img_a.shape[0], img_a.shape[1], img_a.shape[2])
77 |
78 | # convert numpy to tensor
79 | img_id = img_id.cuda()
80 |
81 | #create latent id
82 | img_id_downsample = F.interpolate(img_id, size=(112,112))
83 | latend_id = model.netArc(img_id_downsample)
84 | latend_id = F.normalize(latend_id, p=2, dim=1)
85 |
86 |
87 | # The specific person to be swapped
88 | specific_person_whole = cv2.imread(pic_specific)
89 | specific_person_align_crop, _ = app.get(specific_person_whole,crop_size)
90 | specific_person_align_crop_pil = Image.fromarray(cv2.cvtColor(specific_person_align_crop[0],cv2.COLOR_BGR2RGB))
91 | specific_person = transformer_Arcface(specific_person_align_crop_pil)
92 | specific_person = specific_person.view(-1, specific_person.shape[0], specific_person.shape[1], specific_person.shape[2])
93 |
94 | # convert numpy to tensor
95 | specific_person = specific_person.cuda()
96 |
97 | #create latent id
98 | specific_person_downsample = F.interpolate(specific_person, size=(112,112))
99 | specific_person_id_nonorm = model.netArc(specific_person_downsample)
100 | # specific_person_id_norm = F.normalize(specific_person_id_nonorm, p=2, dim=1)
101 |
102 | ############## Forward Pass ######################
103 |
104 | pic_b = opt.pic_b_path
105 | img_b_whole = cv2.imread(pic_b)
106 |
107 | img_b_align_crop_list, b_mat_list = app.get(img_b_whole,crop_size)
108 | # detect_results = None
109 | swap_result_list = []
110 |
111 | id_compare_values = []
112 | b_align_crop_tenor_list = []
113 | for b_align_crop in img_b_align_crop_list:
114 |
115 | b_align_crop_tenor = _totensor(cv2.cvtColor(b_align_crop,cv2.COLOR_BGR2RGB))[None,...].cuda()
116 |
117 | b_align_crop_tenor_arcnorm = spNorm(b_align_crop_tenor)
118 | b_align_crop_tenor_arcnorm_downsample = F.interpolate(b_align_crop_tenor_arcnorm, size=(112,112))
119 | b_align_crop_id_nonorm = model.netArc(b_align_crop_tenor_arcnorm_downsample)
120 |
121 | id_compare_values.append(mse(b_align_crop_id_nonorm,specific_person_id_nonorm).detach().cpu().numpy())
122 | b_align_crop_tenor_list.append(b_align_crop_tenor)
123 |
124 | id_compare_values_array = np.array(id_compare_values)
125 | min_index = np.argmin(id_compare_values_array)
126 | min_value = id_compare_values_array[min_index]
127 |
128 | if opt.use_mask:
129 | n_classes = 19
130 | net = BiSeNet(n_classes=n_classes)
131 | net.cuda()
132 | save_pth = os.path.join('./parsing_model/checkpoint', '79999_iter.pth')
133 | net.load_state_dict(torch.load(save_pth))
134 | net.eval()
135 | else:
136 | net =None
137 |
138 | if min_value < opt.id_thres:
139 |
140 | swap_result = model(None, b_align_crop_tenor_list[min_index], latend_id, None, True)[0]
141 |
142 | reverse2wholeimage([b_align_crop_tenor_list[min_index]], [swap_result], [b_mat_list[min_index]], crop_size, img_b_whole, logoclass, \
143 | os.path.join(opt.output_path, 'result_whole_swapspecific.jpg'), opt.no_simswaplogo,pasring_model =net,use_mask=opt.use_mask, norm = spNorm)
144 |
145 | print(' ')
146 |
147 | print('************ Done ! ************')
148 |
149 | else:
150 | print('The person you specified is not found in the picture: {}'.format(pic_b))
151 |
--------------------------------------------------------------------------------
/util/add_watermark.py:
--------------------------------------------------------------------------------
1 |
2 | import cv2
3 | import numpy as np
4 | from PIL import Image
5 | import math
6 | import numpy as np
7 | # import torch
8 | # from torchvision import transforms
9 |
10 | def rotate_image(image, angle, center = None, scale = 1.0):
11 | (h, w) = image.shape[:2]
12 |
13 | if center is None:
14 | center = (w / 2, h / 2)
15 |
16 | # Perform the rotation
17 | M = cv2.getRotationMatrix2D(center, angle, scale)
18 | rotated = cv2.warpAffine(image, M, (w, h))
19 |
20 | return rotated
21 |
22 | class watermark_image:
23 | def __init__(self, logo_path, size=0.3, oritation="DR", margin=(5,20,20,20), angle=15, rgb_weight=(0,1,1.5), input_frame_shape=None) -> None:
24 |
25 | logo_image = cv2.imread(logo_path, cv2.IMREAD_UNCHANGED)
26 | h,w,c = logo_image.shape
27 | if angle%360 != 0:
28 | new_h = w*math.sin(angle/180*math.pi) + h*math.cos(angle/180*math.pi)
29 | pad_h = int((new_h-h)//2)
30 |
31 | padding = np.zeros((pad_h, w, c), dtype=np.uint8)
32 | logo_image = cv2.vconcat([logo_image, padding])
33 | logo_image = cv2.vconcat([padding, logo_image])
34 |
35 | logo_image = rotate_image(logo_image, angle)
36 | print(logo_image.shape)
37 | self.logo_image = logo_image
38 |
39 | if self.logo_image.shape[2] < 4:
40 | print("No alpha channel found!")
41 | self.logo_image = self.__addAlpha__(self.logo_image) #add alpha channel
42 | self.size = size
43 | self.oritation = oritation
44 | self.margin = margin
45 | self.ori_shape = self.logo_image.shape
46 | self.resized = False
47 | self.rgb_weight = rgb_weight
48 |
49 | self.logo_image[:, :, 2] = self.logo_image[:, :, 2]*self.rgb_weight[0]
50 | self.logo_image[:, :, 1] = self.logo_image[:, :, 1]*self.rgb_weight[1]
51 | self.logo_image[:, :, 0] = self.logo_image[:, :, 0]*self.rgb_weight[2]
52 |
53 | if input_frame_shape is not None:
54 |
55 | logo_w = input_frame_shape[1] * self.size
56 | ratio = logo_w / self.ori_shape[1]
57 | logo_h = int(ratio * self.ori_shape[0])
58 | logo_w = int(logo_w)
59 |
60 | size = (logo_w, logo_h)
61 | self.logo_image = cv2.resize(self.logo_image, size, interpolation = cv2.INTER_CUBIC)
62 | self.resized = True
63 | if oritation == "UL":
64 | self.coor_h = self.margin[1]
65 | self.coor_w = self.margin[0]
66 | elif oritation == "UR":
67 | self.coor_h = self.margin[1]
68 | self.coor_w = input_frame_shape[1] - (logo_w + self.margin[2])
69 | elif oritation == "DL":
70 | self.coor_h = input_frame_shape[0] - (logo_h + self.margin[1])
71 | self.coor_w = self.margin[0]
72 | else:
73 | self.coor_h = input_frame_shape[0] - (logo_h + self.margin[3])
74 | self.coor_w = input_frame_shape[1] - (logo_w + self.margin[2])
75 | self.logo_w = logo_w
76 | self.logo_h = logo_h
77 | self.mask = self.logo_image[:,:,3]
78 | self.mask = cv2.bitwise_not(self.mask//255)
79 |
80 | def apply_frames(self, frame):
81 |
82 | if not self.resized:
83 | shape = frame.shape
84 | logo_w = shape[1] * self.size
85 | ratio = logo_w / self.ori_shape[1]
86 | logo_h = int(ratio * self.ori_shape[0])
87 | logo_w = int(logo_w)
88 |
89 | size = (logo_w, logo_h)
90 | self.logo_image = cv2.resize(self.logo_image, size, interpolation = cv2.INTER_CUBIC)
91 | self.resized = True
92 | if self.oritation == "UL":
93 | self.coor_h = self.margin[1]
94 | self.coor_w = self.margin[0]
95 | elif self.oritation == "UR":
96 | self.coor_h = self.margin[1]
97 | self.coor_w = shape[1] - (logo_w + self.margin[2])
98 | elif self.oritation == "DL":
99 | self.coor_h = shape[0] - (logo_h + self.margin[1])
100 | self.coor_w = self.margin[0]
101 | else:
102 | self.coor_h = shape[0] - (logo_h + self.margin[3])
103 | self.coor_w = shape[1] - (logo_w + self.margin[2])
104 | self.logo_w = logo_w
105 | self.logo_h = logo_h
106 | self.mask = self.logo_image[:,:,3]
107 | self.mask = cv2.bitwise_not(self.mask//255)
108 |
109 | original_frame = frame[self.coor_h:(self.coor_h+self.logo_h), self.coor_w:(self.coor_w+self.logo_w),:]
110 | blending_logo = cv2.add(self.logo_image[:,:,0:3],original_frame,mask = self.mask)
111 | frame[self.coor_h:(self.coor_h+self.logo_h), self.coor_w:(self.coor_w+self.logo_w),:] = blending_logo
112 | return frame
113 |
114 | def __addAlpha__(self, image):
115 | shape = image.shape
116 | alpha_channel = np.ones((shape[0],shape[1],1),np.uint8)*255
117 | return np.concatenate((image,alpha_channel),2)
118 |
119 |
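120 | if __name__ == "__main__":
121 |     # Standalone sketch: stamp the SimSwap logo onto a blank 720p frame. The logo path matches
122 |     # the one used by the test scripts, so run this from the repository root; the blank frame
123 |     # is just a stand-in for a decoded video frame.
124 |     wm = watermark_image('./simswaplogo/simswaplogo.png')
125 |     frame = np.zeros((720, 1280, 3), dtype=np.uint8)
126 |     stamped = wm.apply_frames(frame)
127 |     cv2.imwrite('watermark_preview.jpg', stamped)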
--------------------------------------------------------------------------------
/util/html.py:
--------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import *
3 | import os
4 |
5 |
6 | class HTML:
7 | def __init__(self, web_dir, title, refresh=0):
8 | self.title = title
9 | self.web_dir = web_dir
10 | self.img_dir = os.path.join(self.web_dir, 'images')
11 | if not os.path.exists(self.web_dir):
12 | os.makedirs(self.web_dir)
13 | if not os.path.exists(self.img_dir):
14 | os.makedirs(self.img_dir)
15 |
16 | self.doc = dominate.document(title=title)
17 | if refresh > 0:
18 | with self.doc.head:
19 | meta(http_equiv="refresh", content=str(refresh))
20 |
21 | def get_image_dir(self):
22 | return self.img_dir
23 |
24 | def add_header(self, str):
25 | with self.doc:
26 | h3(str)
27 |
28 | def add_table(self, border=1):
29 | self.t = table(border=border, style="table-layout: fixed;")
30 | self.doc.add(self.t)
31 |
32 | def add_images(self, ims, txts, links, width=512):
33 | self.add_table()
34 | with self.t:
35 | with tr():
36 | for im, txt, link in zip(ims, txts, links):
37 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
38 | with p():
39 | with a(href=os.path.join('images', link)):
40 | img(style="width:%dpx" % (width), src=os.path.join('images', im))
41 | br()
42 | p(txt)
43 |
44 | def save(self):
45 | html_file = '%s/index.html' % self.web_dir
46 | f = open(html_file, 'wt')
47 | f.write(self.doc.render())
48 | f.close()
49 |
50 |
51 | if __name__ == '__main__':
52 | html = HTML('web/', 'test_html')
53 | html.add_header('hello world')
54 |
55 | ims = []
56 | txts = []
57 | links = []
58 | for n in range(4):
59 | ims.append('image_%d.jpg' % n)
60 | txts.append('text_%d' % n)
61 | links.append('image_%d.jpg' % n)
62 | html.add_images(ims, txts, links)
63 | html.save()
64 |
--------------------------------------------------------------------------------
/util/image_pool.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | from torch.autograd import Variable
4 | class ImagePool():
5 | def __init__(self, pool_size):
6 | self.pool_size = pool_size
7 | if self.pool_size > 0:
8 | self.num_imgs = 0
9 | self.images = []
10 |
11 | def query(self, images):
12 | if self.pool_size == 0:
13 | return images
14 | return_images = []
15 | for image in images.data:
16 | image = torch.unsqueeze(image, 0)
17 | if self.num_imgs < self.pool_size:
18 | self.num_imgs = self.num_imgs + 1
19 | self.images.append(image)
20 | return_images.append(image)
21 | else:
22 | p = random.uniform(0, 1)
23 | if p > 0.5:
24 | random_id = random.randint(0, self.pool_size-1)
25 | tmp = self.images[random_id].clone()
26 | self.images[random_id] = image
27 | return_images.append(tmp)
28 | else:
29 | return_images.append(image)
30 | return_images = Variable(torch.cat(return_images, 0))
31 | return return_images
32 |
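33 | 
34 | if __name__ == "__main__":
35 |     # Sketch of the intended use in a GAN training loop (pix2pixHD-style history buffer):
36 |     # once the pool is full, fakes fed to the discriminator are drawn half the time from
37 |     # previously generated samples. `fake` here is just a stand-in for generator output.
38 |     pool = ImagePool(pool_size=4)
39 |     for _ in range(3):
40 |         fake = torch.randn(2, 3, 8, 8)
41 |         fake_for_D = pool.query(fake)
42 |         print(fake_for_D.shape)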
--------------------------------------------------------------------------------
/util/json_config.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 |
4 | def readConfig(path):
5 | with open(path,'r') as cf:
6 | nodelocaltionstr = cf.read()
7 | nodelocaltioninf = json.loads(nodelocaltionstr)
8 | if isinstance(nodelocaltioninf,str):
9 | nodelocaltioninf = json.loads(nodelocaltioninf)
10 | return nodelocaltioninf
11 |
12 | def writeConfig(path, info):
13 | with open(path, 'w') as cf:
14 | configjson = json.dumps(info, indent=4)
15 | cf.writelines(configjson)
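16 | 
17 | 
18 | if __name__ == "__main__":
19 |     # Round-trip sketch: write a small config dict and read it back.
20 |     writeConfig('example_config.json', {'crop_size': 224, 'use_mask': True})
21 |     print(readConfig('example_config.json'))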
--------------------------------------------------------------------------------
/util/logo_class.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | #############################################################
4 | # File: logo_class.py
5 | # Created Date: Tuesday June 29th 2021
6 | # Author: Chen Xuanhong
7 | # Email: chenxuanhongzju@outlook.com
8 | # Last Modified: Monday, 11th October 2021 12:39:55 am
9 | # Modified By: Chen Xuanhong
10 | # Copyright (c) 2021 Shanghai Jiao Tong University
11 | #############################################################
12 |
13 | class logo_class:
14 |
15 | @staticmethod
16 | def print_group_logo():
17 | logo_str = """
18 |
19 | ███╗ ██╗██████╗ ███████╗██╗ ██████╗ ███████╗ ██╗████████╗██╗ ██╗
20 | ████╗ ██║██╔══██╗██╔════╝██║██╔════╝ ██╔════╝ ██║╚══██╔══╝██║ ██║
21 | ██╔██╗ ██║██████╔╝███████╗██║██║ ███╗ ███████╗ ██║ ██║ ██║ ██║
22 | ██║╚██╗██║██╔══██╗╚════██║██║██║ ██║ ╚════██║██ ██║ ██║ ██║ ██║
23 | ██║ ╚████║██║ ██║███████║██║╚██████╔╝ ███████║╚█████╔╝ ██║ ╚██████╔╝
24 | ╚═╝ ╚═══╝╚═╝ ╚═╝╚══════╝╚═╝ ╚═════╝ ╚══════╝ ╚════╝ ╚═╝ ╚═════╝
25 | Neural Rendering Special Interesting Group of SJTU
26 |
27 | """
28 | print(logo_str)
29 |
30 | @staticmethod
31 | def print_start_training():
32 | logo_str = """
33 | _____ __ __ ______ _ _
34 | / ___/ / /_ ____ _ _____ / /_ /_ __/_____ ____ _ (_)____ (_)____ ____ _
35 | \__ \ / __// __ `// ___// __/ / / / ___// __ `// // __ \ / // __ \ / __ `/
36 | ___/ // /_ / /_/ // / / /_ / / / / / /_/ // // / / // // / / // /_/ /
37 | /____/ \__/ \__,_//_/ \__/ /_/ /_/ \__,_//_//_/ /_//_//_/ /_/ \__, /
38 | /____/
39 | """
40 | print(logo_str)
41 |
42 | if __name__=="__main__":
43 | # logo_class.print_group_logo()
44 | logo_class.print_start_training()
--------------------------------------------------------------------------------
/util/norm.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import numpy as np
3 | import torch
4 | class SpecificNorm(nn.Module):
5 | def __init__(self, epsilon=1e-8):
6 | """
7 | @notice: avoid in-place ops.
8 | https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
9 | """
10 | super(SpecificNorm, self).__init__()
11 | self.mean = np.array([0.485, 0.456, 0.406])
12 | self.mean = torch.from_numpy(self.mean).float().cuda()
13 | self.mean = self.mean.view([1, 3, 1, 1])
14 |
15 | self.std = np.array([0.229, 0.224, 0.225])
16 | self.std = torch.from_numpy(self.std).float().cuda()
17 | self.std = self.std.view([1, 3, 1, 1])
18 |
19 | def forward(self, x):
20 | mean = self.mean.expand([1, 3, x.shape[2], x.shape[3]])
21 | std = self.std.expand([1, 3, x.shape[2], x.shape[3]])
22 |
23 | x = (x - mean) / std
24 |
25 | return x
--------------------------------------------------------------------------------
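
SpecificNorm applies the ImageNet channel mean/std on the GPU; in the swap pipeline it is applied to aligned crops before they are downsampled to 112x112 for the ArcFace network. A minimal sketch: the buffers are created with `.cuda()`, so a CUDA device is required, and the crop size here is illustrative:

```python
import torch

from util.norm import SpecificNorm

# SpecificNorm builds its mean/std buffers with .cuda(), so a GPU is required.
sp_norm = SpecificNorm()
crop = torch.rand(1, 3, 224, 224).cuda()   # RGB crop in [0, 1]; size illustrative
normalized = sp_norm(crop)                 # (x - ImageNet mean) / ImageNet std
print(normalized.shape)                    # torch.Size([1, 3, 224, 224])
```
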
/util/plot.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 | import PIL
4 |
5 | def postprocess(x):
6 | """[0,1] to uint8."""
7 |
8 | x = np.clip(255 * x, 0, 255)
9 |     x = x.astype(np.uint8)  # np.cast was removed from NumPy; astype is equivalent
10 | return x
11 |
12 | def tile(X, rows, cols):
13 | """Tile images for display."""
14 | tiling = np.zeros((rows * X.shape[1], cols * X.shape[2], X.shape[3]), dtype = X.dtype)
15 | for i in range(rows):
16 | for j in range(cols):
17 | idx = i * cols + j
18 | if idx < X.shape[0]:
19 | img = X[idx,...]
20 | tiling[
21 | i*X.shape[1]:(i+1)*X.shape[1],
22 | j*X.shape[2]:(j+1)*X.shape[2],
23 | :] = img
24 | return tiling
25 |
26 |
27 | def plot_batch(X, out_path):
28 | """Save batch of images tiled."""
29 | n_channels = X.shape[3]
30 | if n_channels > 3:
31 | X = X[:,:,:,np.random.choice(n_channels, size = 3)]
32 | X = postprocess(X)
33 | rc = math.sqrt(X.shape[0])
34 | rows = cols = math.ceil(rc)
35 | canvas = tile(X, rows, cols)
36 | canvas = np.squeeze(canvas)
37 | PIL.Image.fromarray(canvas).save(out_path)
--------------------------------------------------------------------------------
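
plot_batch expects a channel-last float batch in [0, 1] and writes a square grid of the images to disk. A minimal sketch with random data; the output filename is illustrative:

```python
import numpy as np

from util.plot import plot_batch

# Nine channel-last images in [0, 1] tile into a 3x3 grid on disk.
batch = np.random.rand(9, 64, 64, 3).astype(np.float32)
plot_batch(batch, "batch_preview.png")     # output filename is illustrative
```
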
/util/reverse2original.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | # import time
4 | import torch
5 | from torch.nn import functional as F
6 | import torch.nn as nn
7 |
8 |
9 | def encode_segmentation_rgb(segmentation, no_neck=True):
10 | parse = segmentation
11 |
12 | face_part_ids = [1, 2, 3, 4, 5, 6, 10, 12, 13] if no_neck else [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 13, 14]
13 | mouth_id = 11
14 | # hair_id = 17
15 | face_map = np.zeros([parse.shape[0], parse.shape[1]])
16 | mouth_map = np.zeros([parse.shape[0], parse.shape[1]])
17 | # hair_map = np.zeros([parse.shape[0], parse.shape[1]])
18 |
19 | for valid_id in face_part_ids:
20 | valid_index = np.where(parse==valid_id)
21 | face_map[valid_index] = 255
22 | valid_index = np.where(parse==mouth_id)
23 | mouth_map[valid_index] = 255
24 | # valid_index = np.where(parse==hair_id)
25 | # hair_map[valid_index] = 255
26 | #return np.stack([face_map, mouth_map,hair_map], axis=2)
27 | return np.stack([face_map, mouth_map], axis=2)
28 |
29 |
30 | class SoftErosion(nn.Module):
31 | def __init__(self, kernel_size=15, threshold=0.6, iterations=1):
32 | super(SoftErosion, self).__init__()
33 | r = kernel_size // 2
34 | self.padding = r
35 | self.iterations = iterations
36 | self.threshold = threshold
37 |
38 | # Create kernel
39 | y_indices, x_indices = torch.meshgrid(torch.arange(0., kernel_size), torch.arange(0., kernel_size))
40 | dist = torch.sqrt((x_indices - r) ** 2 + (y_indices - r) ** 2)
41 | kernel = dist.max() - dist
42 | kernel /= kernel.sum()
43 | kernel = kernel.view(1, 1, *kernel.shape)
44 | self.register_buffer('weight', kernel)
45 |
46 | def forward(self, x):
47 | x = x.float()
48 | for i in range(self.iterations - 1):
49 | x = torch.min(x, F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding))
50 | x = F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding)
51 |
52 | mask = x >= self.threshold
53 | x[mask] = 1.0
54 | x[~mask] /= x[~mask].max()
55 |
56 | return x, mask
57 |
58 |
59 | def postprocess(swapped_face, target, target_mask,smooth_mask):
60 | # target_mask = cv2.resize(target_mask, (self.size, self.size))
61 |
62 | mask_tensor = torch.from_numpy(target_mask.copy().transpose((2, 0, 1))).float().mul_(1/255.0).cuda()
63 | face_mask_tensor = mask_tensor[0] + mask_tensor[1]
64 |
65 | soft_face_mask_tensor, _ = smooth_mask(face_mask_tensor.unsqueeze_(0).unsqueeze_(0))
66 | soft_face_mask_tensor.squeeze_()
67 |
68 | soft_face_mask = soft_face_mask_tensor.cpu().numpy()
69 | soft_face_mask = soft_face_mask[:, :, np.newaxis]
70 |
71 | result = swapped_face * soft_face_mask + target * (1 - soft_face_mask)
72 | result = result[:,:,::-1]# .astype(np.uint8)
73 | return result
74 |
75 | def reverse2wholeimage(b_align_crop_tenor_list,swaped_imgs, mats, crop_size, oriimg, logoclass, save_path = '', \
76 | no_simswaplogo = False,pasring_model =None,norm = None, use_mask = False):
77 |
78 | target_image_list = []
79 | img_mask_list = []
80 | if use_mask:
81 | smooth_mask = SoftErosion(kernel_size=17, threshold=0.9, iterations=7).cuda()
82 | else:
83 | pass
84 |
85 | # print(len(swaped_imgs))
86 | # print(mats)
87 | # print(len(b_align_crop_tenor_list))
88 | for swaped_img, mat ,source_img in zip(swaped_imgs, mats,b_align_crop_tenor_list):
89 | swaped_img = swaped_img.cpu().detach().numpy().transpose((1, 2, 0))
90 | img_white = np.full((crop_size,crop_size), 255, dtype=float)
91 |
92 |         # invert the affine transformation matrix
93 | mat_rev = np.zeros([2,3])
94 | div1 = mat[0][0]*mat[1][1]-mat[0][1]*mat[1][0]
95 | mat_rev[0][0] = mat[1][1]/div1
96 | mat_rev[0][1] = -mat[0][1]/div1
97 | mat_rev[0][2] = -(mat[0][2]*mat[1][1]-mat[0][1]*mat[1][2])/div1
98 | div2 = mat[0][1]*mat[1][0]-mat[0][0]*mat[1][1]
99 | mat_rev[1][0] = mat[1][0]/div2
100 | mat_rev[1][1] = -mat[0][0]/div2
101 | mat_rev[1][2] = -(mat[0][2]*mat[1][0]-mat[0][0]*mat[1][2])/div2
102 |
103 | orisize = (oriimg.shape[1], oriimg.shape[0])
104 | if use_mask:
105 | source_img_norm = norm(source_img)
106 | source_img_512 = F.interpolate(source_img_norm,size=(512,512))
107 | out = pasring_model(source_img_512)[0]
108 | parsing = out.squeeze(0).detach().cpu().numpy().argmax(0)
109 | vis_parsing_anno = parsing.copy().astype(np.uint8)
110 | tgt_mask = encode_segmentation_rgb(vis_parsing_anno)
111 | if tgt_mask.sum() >= 5000:
112 | # face_mask_tensor = tgt_mask[...,0] + tgt_mask[...,1]
113 | target_mask = cv2.resize(tgt_mask, (crop_size, crop_size))
114 | # print(source_img)
115 | target_image_parsing = postprocess(swaped_img, source_img[0].cpu().detach().numpy().transpose((1, 2, 0)), target_mask,smooth_mask)
116 |
117 |
118 | target_image = cv2.warpAffine(target_image_parsing, mat_rev, orisize)
119 | # target_image_parsing = cv2.warpAffine(swaped_img, mat_rev, orisize)
120 | else:
121 | target_image = cv2.warpAffine(swaped_img, mat_rev, orisize)[..., ::-1]
122 | else:
123 | target_image = cv2.warpAffine(swaped_img, mat_rev, orisize)
124 | # source_image = cv2.warpAffine(source_img, mat_rev, orisize)
125 |
126 | img_white = cv2.warpAffine(img_white, mat_rev, orisize)
127 |
128 |
129 | img_white[img_white>20] =255
130 |
131 | img_mask = img_white
132 |
133 | # if use_mask:
134 | # kernel = np.ones((40,40),np.uint8)
135 | # img_mask = cv2.erode(img_mask,kernel,iterations = 1)
136 | # else:
137 | kernel = np.ones((40,40),np.uint8)
138 | img_mask = cv2.erode(img_mask,kernel,iterations = 1)
139 | kernel_size = (20, 20)
140 | blur_size = tuple(2*i+1 for i in kernel_size)
141 | img_mask = cv2.GaussianBlur(img_mask, blur_size, 0)
142 |
143 | # kernel = np.ones((10,10),np.uint8)
144 | # img_mask = cv2.erode(img_mask,kernel,iterations = 1)
145 |
146 |
147 |
148 | img_mask /= 255
149 |
150 | img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1])
151 |
152 |         # parsing mask
153 |
154 | # target_image_parsing = postprocess(target_image, source_image, tgt_mask)
155 |
156 | if use_mask:
157 |             target_image = np.array(target_image, dtype=float) * 255  # np.float alias removed in NumPy >= 1.24
158 |         else:
159 |             target_image = np.array(target_image, dtype=float)[..., ::-1] * 255
160 |
161 |
162 | img_mask_list.append(img_mask)
163 | target_image_list.append(target_image)
164 |
165 |
166 | # target_image /= 255
167 | # target_image = 0
168 |     img = np.array(oriimg, dtype=float)
169 | for img_mask, target_image in zip(img_mask_list, target_image_list):
170 | img = img_mask * target_image + (1-img_mask) * img
171 |
172 | final_img = img.astype(np.uint8)
173 | if not no_simswaplogo:
174 | final_img = logoclass.apply_frames(final_img)
175 | cv2.imwrite(save_path, final_img)
176 |
--------------------------------------------------------------------------------
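
reverse2wholeimage inverts each 2x3 alignment matrix by hand (via cofactors) before warping the swapped crop back onto the original frame. A small self-contained check, not part of the repository, that the manual inversion agrees with cv2.invertAffineTransform; the rotation matrix here is arbitrary:

```python
import cv2
import numpy as np

# The manual cofactor inversion used in reverse2wholeimage is equivalent to
# cv2.invertAffineTransform; verify on an arbitrary 2x3 affine matrix.
mat = cv2.getRotationMatrix2D((112, 112), 30, 0.9)

mat_rev = np.zeros([2, 3])
div1 = mat[0][0] * mat[1][1] - mat[0][1] * mat[1][0]
mat_rev[0][0] = mat[1][1] / div1
mat_rev[0][1] = -mat[0][1] / div1
mat_rev[0][2] = -(mat[0][2] * mat[1][1] - mat[0][1] * mat[1][2]) / div1
div2 = mat[0][1] * mat[1][0] - mat[0][0] * mat[1][1]
mat_rev[1][0] = mat[1][0] / div2
mat_rev[1][1] = -mat[0][0] / div2
mat_rev[1][2] = -(mat[0][2] * mat[1][0] - mat[0][0] * mat[1][2]) / div2

assert np.allclose(mat_rev, cv2.invertAffineTransform(mat))
```
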
/util/save_heatmap.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | #############################################################
4 | # File: save_heatmap.py
5 | # Created Date: Friday January 15th 2021
6 | # Author: Chen Xuanhong
7 | # Email: chenxuanhongzju@outlook.com
8 | # Last Modified: Wednesday, 19th January 2022 1:22:47 am
9 | # Modified By: Chen Xuanhong
10 | # Copyright (c) 2021 Shanghai Jiao Tong University
11 | #############################################################
12 |
13 | import os
14 | import shutil
15 | import seaborn as sns
16 | import matplotlib.pyplot as plt
17 | import cv2
18 | import numpy as np
19 |
20 | def SaveHeatmap(heatmaps, path, row=-1, dpi=72):
21 | """
22 | The input tensor must be B X 1 X H X W
23 | """
24 | batch_size = heatmaps.shape[0]
25 | temp_path = ".temp/"
26 | if not os.path.exists(temp_path):
27 | os.makedirs(temp_path)
28 | final_img = None
29 | if row < 1:
30 | col = batch_size
31 | row = 1
32 | else:
33 | col = batch_size // row
34 | if row * col = col:
51 | col_i = 0
52 | row_i += 1
53 | cv2.imwrite(path,final_img)
54 |
55 | if __name__ == "__main__":
56 | random_map = np.random.randn(16,1,10,10)
57 | SaveHeatmap(random_map,"./wocao.png",1)
58 |
--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 | import numpy as np
6 | import os
7 |
8 | # Converts a Tensor into a Numpy array
9 | # |imtype|: the desired type of the converted numpy array
10 | def tensor2im(image_tensor, imtype=np.uint8, normalize=True):
11 | if isinstance(image_tensor, list):
12 | image_numpy = []
13 | for i in range(len(image_tensor)):
14 | image_numpy.append(tensor2im(image_tensor[i], imtype, normalize))
15 | return image_numpy
16 | image_numpy = image_tensor.cpu().float().numpy()
17 | if normalize:
18 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
19 | else:
20 | image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
21 | image_numpy = np.clip(image_numpy, 0, 255)
22 | if image_numpy.shape[2] == 1 or image_numpy.shape[2] > 3:
23 | image_numpy = image_numpy[:,:,0]
24 | return image_numpy.astype(imtype)
25 |
26 | # Converts a one-hot tensor into a colorful label map
27 | def tensor2label(label_tensor, n_label, imtype=np.uint8):
28 | if n_label == 0:
29 | return tensor2im(label_tensor, imtype)
30 | label_tensor = label_tensor.cpu().float()
31 | if label_tensor.size()[0] > 1:
32 | label_tensor = label_tensor.max(0, keepdim=True)[1]
33 | label_tensor = Colorize(n_label)(label_tensor)
34 | label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0))
35 | return label_numpy.astype(imtype)
36 |
37 | def save_image(image_numpy, image_path):
38 | image_pil = Image.fromarray(image_numpy)
39 | image_pil.save(image_path)
40 |
41 | def mkdirs(paths):
42 | if isinstance(paths, list) and not isinstance(paths, str):
43 | for path in paths:
44 | mkdir(path)
45 | else:
46 | mkdir(paths)
47 |
48 | def mkdir(path):
49 | if not os.path.exists(path):
50 | os.makedirs(path)
51 |
52 | ###############################################################################
53 | # Code from
54 | # https://github.com/ycszen/pytorch-seg/blob/master/transform.py
55 | # Modified so it complies with the Cityscapes label map colors
56 | ###############################################################################
57 | def uint82bin(n, count=8):
58 | """returns the binary of integer n, count refers to amount of bits"""
59 | return ''.join([str((n >> y) & 1) for y in range(count-1, -1, -1)])
60 |
61 | def labelcolormap(N):
62 | if N == 35: # cityscape
63 | cmap = np.array([( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), (111, 74, 0), ( 81, 0, 81),
64 | (128, 64,128), (244, 35,232), (250,170,160), (230,150,140), ( 70, 70, 70), (102,102,156), (190,153,153),
65 | (180,165,180), (150,100,100), (150,120, 90), (153,153,153), (153,153,153), (250,170, 30), (220,220, 0),
66 | (107,142, 35), (152,251,152), ( 70,130,180), (220, 20, 60), (255, 0, 0), ( 0, 0,142), ( 0, 0, 70),
67 | ( 0, 60,100), ( 0, 0, 90), ( 0, 0,110), ( 0, 80,100), ( 0, 0,230), (119, 11, 32), ( 0, 0,142)],
68 | dtype=np.uint8)
69 | else:
70 | cmap = np.zeros((N, 3), dtype=np.uint8)
71 | for i in range(N):
72 | r, g, b = 0, 0, 0
73 | id = i
74 | for j in range(7):
75 | str_id = uint82bin(id)
76 | r = r ^ (np.uint8(str_id[-1]) << (7-j))
77 | g = g ^ (np.uint8(str_id[-2]) << (7-j))
78 | b = b ^ (np.uint8(str_id[-3]) << (7-j))
79 | id = id >> 3
80 | cmap[i, 0] = r
81 | cmap[i, 1] = g
82 | cmap[i, 2] = b
83 | return cmap
84 |
85 | class Colorize(object):
86 | def __init__(self, n=35):
87 | self.cmap = labelcolormap(n)
88 | self.cmap = torch.from_numpy(self.cmap[:n])
89 |
90 | def __call__(self, gray_image):
91 | size = gray_image.size()
92 | color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0)
93 |
94 | for label in range(0, len(self.cmap)):
95 | mask = (label == gray_image[0]).cpu()
96 | color_image[0][mask] = self.cmap[label][0]
97 | color_image[1][mask] = self.cmap[label][1]
98 | color_image[2][mask] = self.cmap[label][2]
99 |
100 | return color_image
101 |
--------------------------------------------------------------------------------
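
tensor2im assumes a CHW tensor in [-1, 1] (the generator's output range) and returns an HWC uint8 array ready for saving. A minimal sketch; the directory and filename are illustrative:

```python
import torch

from util.util import mkdir, save_image, tensor2im

# A CHW tensor in [-1, 1] becomes an HxWx3 uint8 array; paths are illustrative.
fake = torch.rand(3, 224, 224) * 2 - 1
image_numpy = tensor2im(fake)              # shape (224, 224, 3), dtype uint8
mkdir("./output_demo")
save_image(image_numpy, "./output_demo/sample.jpg")
```
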
/util/videoswap.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: Naiyuan liu
3 | Github: https://github.com/NNNNAI
4 | Date: 2021-11-23 17:03:58
5 | LastEditors: Naiyuan liu
6 | LastEditTime: 2021-11-24 19:19:52
7 | Description:
8 | '''
9 | import os
10 | import cv2
11 | import glob
12 | import torch
13 | import shutil
14 | import numpy as np
15 | from tqdm import tqdm
16 | from util.reverse2original import reverse2wholeimage
17 | import moviepy.editor as mp
18 | from moviepy.editor import AudioFileClip, VideoFileClip
19 | from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
20 | import time
21 | from util.add_watermark import watermark_image
22 | from util.norm import SpecificNorm
23 | from parsing_model.model import BiSeNet
24 |
25 | def _totensor(array):
26 | tensor = torch.from_numpy(array)
27 | img = tensor.transpose(0, 1).transpose(0, 2).contiguous()
28 | return img.float().div(255)
29 |
30 | def video_swap(video_path, id_vetor, swap_model, detect_model, save_path, temp_results_dir='./temp_results', crop_size=224, no_simswaplogo = False,use_mask =False):
31 | video_forcheck = VideoFileClip(video_path)
32 | if video_forcheck.audio is None:
33 | no_audio = True
34 | else:
35 | no_audio = False
36 |
37 | del video_forcheck
38 |
39 | if not no_audio:
40 | video_audio_clip = AudioFileClip(video_path)
41 |
42 | video = cv2.VideoCapture(video_path)
43 | logoclass = watermark_image('./simswaplogo/simswaplogo.png')
44 | ret = True
45 | frame_index = 0
46 |
47 | frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
48 |
49 | # video_WIDTH = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
50 |
51 | # video_HEIGHT = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
52 |
53 | fps = video.get(cv2.CAP_PROP_FPS)
54 | if os.path.exists(temp_results_dir):
55 | shutil.rmtree(temp_results_dir)
56 |
57 | spNorm =SpecificNorm()
58 | if use_mask:
59 | n_classes = 19
60 | net = BiSeNet(n_classes=n_classes)
61 | net.cuda()
62 | save_pth = os.path.join('./parsing_model/checkpoint', '79999_iter.pth')
63 | net.load_state_dict(torch.load(save_pth))
64 | net.eval()
65 | else:
66 | net =None
67 |
68 | # while ret:
69 | for frame_index in tqdm(range(frame_count)):
70 | ret, frame = video.read()
71 | if ret:
72 | detect_results = detect_model.get(frame,crop_size)
73 |
74 | if detect_results is not None:
75 | # print(frame_index)
76 | if not os.path.exists(temp_results_dir):
77 | os.mkdir(temp_results_dir)
78 | frame_align_crop_list = detect_results[0]
79 | frame_mat_list = detect_results[1]
80 | swap_result_list = []
81 | frame_align_crop_tenor_list = []
82 | for frame_align_crop in frame_align_crop_list:
83 |
84 | # BGR TO RGB
85 | # frame_align_crop_RGB = frame_align_crop[...,::-1]
86 |
87 | frame_align_crop_tenor = _totensor(cv2.cvtColor(frame_align_crop,cv2.COLOR_BGR2RGB))[None,...].cuda()
88 |
89 | swap_result = swap_model(None, frame_align_crop_tenor, id_vetor, None, True)[0]
90 | cv2.imwrite(os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)), frame)
91 | swap_result_list.append(swap_result)
92 | frame_align_crop_tenor_list.append(frame_align_crop_tenor)
93 |
94 |
95 |
96 | reverse2wholeimage(frame_align_crop_tenor_list,swap_result_list, frame_mat_list, crop_size, frame, logoclass,\
97 | os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)),no_simswaplogo,pasring_model =net,use_mask=use_mask, norm = spNorm)
98 |
99 | else:
100 | if not os.path.exists(temp_results_dir):
101 | os.mkdir(temp_results_dir)
102 | frame = frame.astype(np.uint8)
103 | if not no_simswaplogo:
104 | frame = logoclass.apply_frames(frame)
105 | cv2.imwrite(os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)), frame)
106 | else:
107 | break
108 |
109 | video.release()
110 |
111 | # image_filename_list = []
112 | path = os.path.join(temp_results_dir,'*.jpg')
113 | image_filenames = sorted(glob.glob(path))
114 |
115 | clips = ImageSequenceClip(image_filenames,fps = fps)
116 |
117 | if not no_audio:
118 | clips = clips.set_audio(video_audio_clip)
119 |
120 |
121 | clips.write_videofile(save_path,audio_codec='aac')
122 |
123 |
--------------------------------------------------------------------------------
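
video_swap writes every processed frame to temp_results as frame_XXXXXXX.jpg and then rebuilds the clip with moviepy, re-attaching the source audio when present. A sketch of just that reassembly step, assuming the moviepy 1.x API the file imports; the paths and fps value are illustrative, and the source clip is assumed to contain an audio track:

```python
import glob
import os

from moviepy.editor import AudioFileClip
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip

# Rebuild an mp4 from the frames video_swap wrote to temp_results and
# re-attach the audio of the source video (paths and fps are illustrative).
frames = sorted(glob.glob(os.path.join("./temp_results", "*.jpg")))
clip = ImageSequenceClip(frames, fps=25)
clip = clip.set_audio(AudioFileClip("./demo_file/multi_people_1080p.mp4"))
clip.write_videofile("./output/result.mp4", audio_codec="aac")
```
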
/util/videoswap_multispecific.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import glob
4 | import torch
5 | import shutil
6 | import numpy as np
7 | from tqdm import tqdm
8 | from util.reverse2original import reverse2wholeimage
9 | import moviepy.editor as mp
10 | from moviepy.editor import AudioFileClip, VideoFileClip
11 | from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
12 | import time
13 | from util.add_watermark import watermark_image
14 | from util.norm import SpecificNorm
15 | import torch.nn.functional as F
16 | from parsing_model.model import BiSeNet
17 |
18 | def _totensor(array):
19 | tensor = torch.from_numpy(array)
20 | img = tensor.transpose(0, 1).transpose(0, 2).contiguous()
21 | return img.float().div(255)
22 |
23 | def video_swap(video_path, target_id_norm_list,source_specific_id_nonorm_list,id_thres, swap_model, detect_model, save_path, temp_results_dir='./temp_results', crop_size=224, no_simswaplogo = False,use_mask =False):
24 | video_forcheck = VideoFileClip(video_path)
25 | if video_forcheck.audio is None:
26 | no_audio = True
27 | else:
28 | no_audio = False
29 |
30 | del video_forcheck
31 |
32 | if not no_audio:
33 | video_audio_clip = AudioFileClip(video_path)
34 |
35 | video = cv2.VideoCapture(video_path)
36 | logoclass = watermark_image('./simswaplogo/simswaplogo.png')
37 | ret = True
38 | frame_index = 0
39 |
40 | frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
41 |
42 | # video_WIDTH = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
43 |
44 | # video_HEIGHT = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
45 |
46 | fps = video.get(cv2.CAP_PROP_FPS)
47 | if os.path.exists(temp_results_dir):
48 | shutil.rmtree(temp_results_dir)
49 |
50 | spNorm =SpecificNorm()
51 | mse = torch.nn.MSELoss().cuda()
52 |
53 | if use_mask:
54 | n_classes = 19
55 | net = BiSeNet(n_classes=n_classes)
56 | net.cuda()
57 | save_pth = os.path.join('./parsing_model/checkpoint', '79999_iter.pth')
58 | net.load_state_dict(torch.load(save_pth))
59 | net.eval()
60 | else:
61 | net =None
62 |
63 | # while ret:
64 | for frame_index in tqdm(range(frame_count)):
65 | ret, frame = video.read()
66 | if ret:
67 | detect_results = detect_model.get(frame,crop_size)
68 |
69 | if detect_results is not None:
70 | # print(frame_index)
71 | if not os.path.exists(temp_results_dir):
72 | os.mkdir(temp_results_dir)
73 | frame_align_crop_list = detect_results[0]
74 | frame_mat_list = detect_results[1]
75 |
76 | id_compare_values = []
77 | frame_align_crop_tenor_list = []
78 | for frame_align_crop in frame_align_crop_list:
79 |
80 | # BGR TO RGB
81 | # frame_align_crop_RGB = frame_align_crop[...,::-1]
82 |
83 | frame_align_crop_tenor = _totensor(cv2.cvtColor(frame_align_crop,cv2.COLOR_BGR2RGB))[None,...].cuda()
84 |
85 | frame_align_crop_tenor_arcnorm = spNorm(frame_align_crop_tenor)
86 | frame_align_crop_tenor_arcnorm_downsample = F.interpolate(frame_align_crop_tenor_arcnorm, size=(112,112))
87 | frame_align_crop_crop_id_nonorm = swap_model.netArc(frame_align_crop_tenor_arcnorm_downsample)
88 | id_compare_values.append([])
89 | for source_specific_id_nonorm_tmp in source_specific_id_nonorm_list:
90 | id_compare_values[-1].append(mse(frame_align_crop_crop_id_nonorm,source_specific_id_nonorm_tmp).detach().cpu().numpy())
91 | frame_align_crop_tenor_list.append(frame_align_crop_tenor)
92 |
93 | id_compare_values_array = np.array(id_compare_values).transpose(1,0)
94 | min_indexs = np.argmin(id_compare_values_array,axis=0)
95 | min_value = np.min(id_compare_values_array,axis=0)
96 |
97 | swap_result_list = []
98 | swap_result_matrix_list = []
99 | swap_result_ori_pic_list = []
100 | for tmp_index, min_index in enumerate(min_indexs):
101 | if min_value[tmp_index] < id_thres:
102 | swap_result = swap_model(None, frame_align_crop_tenor_list[tmp_index], target_id_norm_list[min_index], None, True)[0]
103 | swap_result_list.append(swap_result)
104 | swap_result_matrix_list.append(frame_mat_list[tmp_index])
105 | swap_result_ori_pic_list.append(frame_align_crop_tenor_list[tmp_index])
106 | else:
107 | pass
108 |
109 |
110 |
111 | if len(swap_result_list) !=0:
112 |
113 | reverse2wholeimage(swap_result_ori_pic_list,swap_result_list, swap_result_matrix_list, crop_size, frame, logoclass,\
114 | os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)),no_simswaplogo,pasring_model =net,use_mask=use_mask, norm = spNorm)
115 | else:
116 | if not os.path.exists(temp_results_dir):
117 | os.mkdir(temp_results_dir)
118 | frame = frame.astype(np.uint8)
119 | if not no_simswaplogo:
120 | frame = logoclass.apply_frames(frame)
121 | cv2.imwrite(os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)), frame)
122 |
123 | else:
124 | if not os.path.exists(temp_results_dir):
125 | os.mkdir(temp_results_dir)
126 | frame = frame.astype(np.uint8)
127 | if not no_simswaplogo:
128 | frame = logoclass.apply_frames(frame)
129 | cv2.imwrite(os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)), frame)
130 | else:
131 | break
132 |
133 | video.release()
134 |
135 | # image_filename_list = []
136 | path = os.path.join(temp_results_dir,'*.jpg')
137 | image_filenames = sorted(glob.glob(path))
138 |
139 | clips = ImageSequenceClip(image_filenames,fps = fps)
140 |
141 | if not no_audio:
142 | clips = clips.set_audio(video_audio_clip)
143 |
144 |
145 | clips.write_videofile(save_path,audio_codec='aac')
146 |
147 |
--------------------------------------------------------------------------------
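
The matching step above compares each detected face's ArcFace embedding against every source identity by MSE and swaps only those faces whose best match falls below id_thres. A toy, self-contained illustration of just that selection step; every number is made up:

```python
import numpy as np

# Toy version of the per-frame matching: rows are detected faces, columns are
# the MSE of that face's ArcFace embedding against each source identity.
id_thres = 0.03                                   # illustrative threshold
id_compare_values = [[0.051, 0.012],              # face 0 vs. sources 0 and 1
                     [0.020, 0.047],              # face 1 vs. sources 0 and 1
                     [0.090, 0.085]]              # face 2 matches nobody

values = np.array(id_compare_values).transpose(1, 0)  # (num_sources, num_faces)
min_indexs = np.argmin(values, axis=0)                # best source per face
min_value = np.min(values, axis=0)

for face_idx, src_idx in enumerate(min_indexs):
    if min_value[face_idx] < id_thres:
        print(f"face {face_idx}: swap in target identity {src_idx}")
    else:
        print(f"face {face_idx}: no match below id_thres, face left as-is")
```
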
/util/videoswap_specific.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import glob
4 | import torch
5 | import shutil
6 | import numpy as np
7 | from tqdm import tqdm
8 | from util.reverse2original import reverse2wholeimage
9 | import moviepy.editor as mp
10 | from moviepy.editor import AudioFileClip, VideoFileClip
11 | from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
12 | import time
13 | from util.add_watermark import watermark_image
14 | from util.norm import SpecificNorm
15 | import torch.nn.functional as F
16 | from parsing_model.model import BiSeNet
17 |
18 | def _totensor(array):
19 | tensor = torch.from_numpy(array)
20 | img = tensor.transpose(0, 1).transpose(0, 2).contiguous()
21 | return img.float().div(255)
22 |
23 | def video_swap(video_path, id_vetor,specific_person_id_nonorm,id_thres, swap_model, detect_model, save_path, temp_results_dir='./temp_results', crop_size=224, no_simswaplogo = False,use_mask =False):
24 | video_forcheck = VideoFileClip(video_path)
25 | if video_forcheck.audio is None:
26 | no_audio = True
27 | else:
28 | no_audio = False
29 |
30 | del video_forcheck
31 |
32 | if not no_audio:
33 | video_audio_clip = AudioFileClip(video_path)
34 |
35 | video = cv2.VideoCapture(video_path)
36 | logoclass = watermark_image('./simswaplogo/simswaplogo.png')
37 | ret = True
38 | frame_index = 0
39 |
40 | frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
41 |
42 | # video_WIDTH = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
43 |
44 | # video_HEIGHT = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
45 |
46 | fps = video.get(cv2.CAP_PROP_FPS)
47 | if os.path.exists(temp_results_dir):
48 | shutil.rmtree(temp_results_dir)
49 |
50 | spNorm =SpecificNorm()
51 | mse = torch.nn.MSELoss().cuda()
52 |
53 | if use_mask:
54 | n_classes = 19
55 | net = BiSeNet(n_classes=n_classes)
56 | net.cuda()
57 | save_pth = os.path.join('./parsing_model/checkpoint', '79999_iter.pth')
58 | net.load_state_dict(torch.load(save_pth))
59 | net.eval()
60 | else:
61 | net =None
62 |
63 | # while ret:
64 | for frame_index in tqdm(range(frame_count)):
65 | ret, frame = video.read()
66 | if ret:
67 | detect_results = detect_model.get(frame,crop_size)
68 |
69 | if detect_results is not None:
70 | # print(frame_index)
71 | if not os.path.exists(temp_results_dir):
72 | os.mkdir(temp_results_dir)
73 | frame_align_crop_list = detect_results[0]
74 | frame_mat_list = detect_results[1]
75 |
76 | id_compare_values = []
77 | frame_align_crop_tenor_list = []
78 | for frame_align_crop in frame_align_crop_list:
79 |
80 | # BGR TO RGB
81 | # frame_align_crop_RGB = frame_align_crop[...,::-1]
82 |
83 | frame_align_crop_tenor = _totensor(cv2.cvtColor(frame_align_crop,cv2.COLOR_BGR2RGB))[None,...].cuda()
84 |
85 | frame_align_crop_tenor_arcnorm = spNorm(frame_align_crop_tenor)
86 | frame_align_crop_tenor_arcnorm_downsample = F.interpolate(frame_align_crop_tenor_arcnorm, size=(112,112))
87 | frame_align_crop_crop_id_nonorm = swap_model.netArc(frame_align_crop_tenor_arcnorm_downsample)
88 |
89 | id_compare_values.append(mse(frame_align_crop_crop_id_nonorm,specific_person_id_nonorm).detach().cpu().numpy())
90 | frame_align_crop_tenor_list.append(frame_align_crop_tenor)
91 | id_compare_values_array = np.array(id_compare_values)
92 | min_index = np.argmin(id_compare_values_array)
93 | min_value = id_compare_values_array[min_index]
94 | if min_value < id_thres:
95 | swap_result = swap_model(None, frame_align_crop_tenor_list[min_index], id_vetor, None, True)[0]
96 |
97 | reverse2wholeimage([frame_align_crop_tenor_list[min_index]], [swap_result], [frame_mat_list[min_index]], crop_size, frame, logoclass,\
98 | os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)),no_simswaplogo,pasring_model =net,use_mask= use_mask, norm = spNorm)
99 | else:
100 | if not os.path.exists(temp_results_dir):
101 | os.mkdir(temp_results_dir)
102 | frame = frame.astype(np.uint8)
103 | if not no_simswaplogo:
104 | frame = logoclass.apply_frames(frame)
105 | cv2.imwrite(os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)), frame)
106 |
107 | else:
108 | if not os.path.exists(temp_results_dir):
109 | os.mkdir(temp_results_dir)
110 | frame = frame.astype(np.uint8)
111 | if not no_simswaplogo:
112 | frame = logoclass.apply_frames(frame)
113 | cv2.imwrite(os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)), frame)
114 | else:
115 | break
116 |
117 | video.release()
118 |
119 | # image_filename_list = []
120 | path = os.path.join(temp_results_dir,'*.jpg')
121 | image_filenames = sorted(glob.glob(path))
122 |
123 | clips = ImageSequenceClip(image_filenames,fps = fps)
124 |
125 | if not no_audio:
126 | clips = clips.set_audio(video_audio_clip)
127 |
128 |
129 | clips.write_videofile(save_path,audio_codec='aac')
130 |
131 |
--------------------------------------------------------------------------------
/util/visualizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import ntpath
4 | import time
5 | from . import util
6 | from . import html
7 | import scipy.misc
8 | try:
9 | from StringIO import StringIO # Python 2.7
10 | except ImportError:
11 | from io import BytesIO # Python 3.x
12 |
13 | class Visualizer():
14 | def __init__(self, opt):
15 | # self.opt = opt
16 | self.tf_log = opt.tf_log
17 | self.use_html = opt.isTrain and not opt.no_html
18 | self.win_size = opt.display_winsize
19 | self.name = opt.name
20 | if self.tf_log:
21 | import tensorflow as tf
22 | self.tf = tf
23 | self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, 'logs')
24 | self.writer = tf.summary.FileWriter(self.log_dir)
25 |
26 | if self.use_html:
27 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
28 | self.img_dir = os.path.join(self.web_dir, 'images')
29 | print('create web directory %s...' % self.web_dir)
30 | util.mkdirs([self.web_dir, self.img_dir])
31 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
32 | with open(self.log_name, "a") as log_file:
33 | now = time.strftime("%c")
34 | log_file.write('================ Training Loss (%s) ================\n' % now)
35 |
36 | # |visuals|: dictionary of images to display or save
37 | def display_current_results(self, visuals, epoch, step):
38 | if self.tf_log: # show images in tensorboard output
39 | img_summaries = []
40 | for label, image_numpy in visuals.items():
41 | # Write the image to a string
42 | try:
43 | s = StringIO()
44 | except:
45 | s = BytesIO()
46 | scipy.misc.toimage(image_numpy).save(s, format="jpeg")
47 | # Create an Image object
48 | img_sum = self.tf.Summary.Image(encoded_image_string=s.getvalue(), height=image_numpy.shape[0], width=image_numpy.shape[1])
49 | # Create a Summary value
50 | img_summaries.append(self.tf.Summary.Value(tag=label, image=img_sum))
51 |
52 | # Create and write Summary
53 | summary = self.tf.Summary(value=img_summaries)
54 | self.writer.add_summary(summary, step)
55 |
56 | if self.use_html: # save images to a html file
57 | for label, image_numpy in visuals.items():
58 | if isinstance(image_numpy, list):
59 | for i in range(len(image_numpy)):
60 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s_%d.jpg' % (epoch, label, i))
61 | util.save_image(image_numpy[i], img_path)
62 | else:
63 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.jpg' % (epoch, label))
64 | util.save_image(image_numpy, img_path)
65 |
66 | # update website
67 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=30)
68 | for n in range(epoch, 0, -1):
69 | webpage.add_header('epoch [%d]' % n)
70 | ims = []
71 | txts = []
72 | links = []
73 |
74 | for label, image_numpy in visuals.items():
75 | if isinstance(image_numpy, list):
76 | for i in range(len(image_numpy)):
77 | img_path = 'epoch%.3d_%s_%d.jpg' % (n, label, i)
78 | ims.append(img_path)
79 | txts.append(label+str(i))
80 | links.append(img_path)
81 | else:
82 | img_path = 'epoch%.3d_%s.jpg' % (n, label)
83 | ims.append(img_path)
84 | txts.append(label)
85 | links.append(img_path)
86 | if len(ims) < 10:
87 | webpage.add_images(ims, txts, links, width=self.win_size)
88 | else:
89 | num = int(round(len(ims)/2.0))
90 | webpage.add_images(ims[:num], txts[:num], links[:num], width=self.win_size)
91 | webpage.add_images(ims[num:], txts[num:], links[num:], width=self.win_size)
92 | webpage.save()
93 |
94 | # errors: dictionary of error labels and values
95 | def plot_current_errors(self, errors, step):
96 | if self.tf_log:
97 | for tag, value in errors.items():
98 | summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)])
99 | self.writer.add_summary(summary, step)
100 |
101 | # errors: same format as |errors| of plotCurrentErrors
102 | def print_current_errors(self, epoch, i, errors, t):
103 | message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t)
104 | for k, v in errors.items():
105 | if v != 0:
106 | message += '%s: %.3f ' % (k, v)
107 |
108 | print(message)
109 | with open(self.log_name, "a") as log_file:
110 | log_file.write('%s\n' % message)
111 |
112 | # save image to the disk
113 | def save_images(self, webpage, visuals, image_path):
114 | image_dir = webpage.get_image_dir()
115 | short_path = ntpath.basename(image_path[0])
116 | name = os.path.splitext(short_path)[0]
117 |
118 | webpage.add_header(name)
119 | ims = []
120 | txts = []
121 | links = []
122 |
123 | for label, image_numpy in visuals.items():
124 | image_name = '%s_%s.jpg' % (name, label)
125 | save_path = os.path.join(image_dir, image_name)
126 | util.save_image(image_numpy, save_path)
127 |
128 | ims.append(image_name)
129 | txts.append(label)
130 | links.append(image_name)
131 | webpage.add_images(ims, txts, links, width=self.win_size)
132 |
--------------------------------------------------------------------------------
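
Visualizer only reads a handful of option fields, so it can be exercised without the full options machinery. A sketch using a stand-in options object; all field values are illustrative, and since the module imports scipy.misc at load time this assumes a SciPy version that still ships that module:

```python
from types import SimpleNamespace

from util.visualizer import Visualizer

# Stand-in for the options object; only the fields Visualizer reads are set,
# and every value here is illustrative.
opt = SimpleNamespace(tf_log=False, isTrain=True, no_html=False,
                      display_winsize=512, name="simswap_demo",
                      checkpoints_dir="./checkpoints")
vis = Visualizer(opt)
vis.print_current_errors(epoch=1, i=100,
                         errors={"G_ID": 0.42, "D_Fake": 0.13}, t=0.35)
```
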