├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── NOTICE ├── README.md ├── face-alignment ├── LICENSE ├── __init__.py ├── conda │ ├── conda_upload.sh │ └── meta.yaml ├── examples │ ├── demo.ipynb │ └── detect_landmarks_in_image.py ├── face_alignment │ ├── __init__.py │ ├── api.py │ ├── detection │ │ ├── __init__.py │ │ ├── blazeface │ │ │ ├── __init__.py │ │ │ ├── blazeface_detector.py │ │ │ ├── detect.py │ │ │ ├── net_blazeface.py │ │ │ └── utils.py │ │ ├── core.py │ │ ├── dlib │ │ │ ├── __init__.py │ │ │ └── dlib_detector.py │ │ ├── folder │ │ │ ├── __init__.py │ │ │ └── folder_detector.py │ │ └── sfd │ │ │ ├── __init__.py │ │ │ ├── bbox.py │ │ │ ├── detect.py │ │ │ ├── net_s3fd.py │ │ │ └── sfd_detector.py │ ├── models.py │ └── utils.py ├── requirements.txt ├── setup.cfg ├── setup.py ├── test │ ├── assets │ │ └── aflw-test.jpg │ ├── facealignment_test.py │ ├── smoke_test.py │ └── test_utils.py └── tox.ini ├── images ├── controllable_gan.gif ├── dog_control.png ├── face_control.png ├── painting_control.png └── pinting_style.png ├── notebooks ├── __init__.py └── gan_control_inference_example.ipynb ├── resources ├── README.md └── ffhq_1K_attributes_samples_df.pkl └── src ├── __init__.py └── gan_control ├── __init__.py ├── configs ├── afhq.json ├── controller_configs │ ├── afhq │ │ └── default_w_latent_controller.json │ ├── debug.json │ ├── ffhq │ │ ├── age_w_latent_controller.json │ │ ├── default_w_latent_controller.json │ │ ├── expression3d_w_latent_controller.json │ │ ├── gamma_w_latent_controller.json │ │ ├── hair_w_latent_controller.json │ │ ├── merged_attr_w_latent_controller.json │ │ ├── orientation_w_latent_controller.json │ │ └── w_latent_controller_no_negative.json │ └── metfaces │ │ └── default_w_latent_controller.json ├── ffhq.json └── metfaces.json ├── datasets ├── __init__.py ├── afhq_dataset.py ├── dataframe_dataset.py ├── ffhq_dataset.py ├── image_net_classes.py ├── merged_dataframe_dataset.py └── metfaces_dataset.py ├── evaluation ├── __init__.py ├── age.py ├── expression.py ├── extract_recon_3d │ ├── __init__.py │ ├── disentanglement_dataloader.py │ ├── disentanglement_score.py │ └── extract_recon_3d.py ├── face_alignment_utils │ ├── __init__.py │ └── face_alignment_utils.py ├── gan_evaluation │ ├── __init__.py │ └── error_bar_plot.py ├── generation.py ├── hair.py ├── inference_class.py ├── orientation.py ├── recon_3d.py ├── separability.py └── tracker.py ├── fid_utils ├── __init__.py ├── calc_inception.py ├── evaluate_fid.py ├── fid.py ├── inception.py └── overwrite_inception.py ├── inception_stats └── README.md ├── inference ├── __init__.py ├── controller.py └── inference.py ├── losses ├── __init__.py ├── arc_face │ ├── __init__.py │ ├── arc_face_criterion.py │ ├── arc_face_model.py │ └── arc_face_skeleton.py ├── deep_expectation_age │ ├── __init__.py │ ├── deep_age_criterion.py │ ├── deep_age_model.py │ └── deep_age_skeleton.py ├── deep_head_pose │ ├── __init__.py │ ├── hopenet_criterion.py │ ├── hopenet_model.py │ └── hopenet_skeleton.py ├── dogfacenet │ ├── __init__.py │ ├── dogfacenet_criterion.py │ ├── dogfacenet_skeleton.py │ └── models │ │ ├── __init__.py │ │ ├── h5_model.py │ │ ├── pb.py │ │ └── pytorch_dogfacenet_model.py ├── face3dmm_recon │ ├── __init__.py │ ├── face3dmm_criterion.py │ ├── face3dmm_skeleton.py │ └── models │ │ ├── __init__.py │ │ ├── pb.py │ │ ├── pytorch_3d_recon_model.py │ │ ├── resnet.py │ │ └── tf_model.py ├── facial_features_esr │ ├── __init__.py │ ├── esr9_criterion.py │ ├── esr9_model.py │ └── esr9_skeleton.py ├── hair_loss │ ├── __init__.py │ ├── hair_criterion.py │ ├── hair_model.py │ └── hair_skeleton.py ├── imagenet │ ├── __init__.py │ ├── imagenet_criterion.py │ └── imagenet_skeleton.py ├── loss_model.py └── stayle │ ├── __init__.py │ ├── style_criterion.py │ └── style_skeleton.py ├── make_attributes_df.py ├── models ├── __init__.py ├── controller_model.py ├── gan_model.py └── pytorch_upfirdn2d.py ├── pretrained_models └── README.md ├── projection ├── __init__.py ├── lpips │ ├── __init__.py │ ├── base_model.py │ ├── dist_model.py │ ├── lpips.py │ ├── networks_basic.py │ ├── pretrained_networks.py │ └── weights │ │ ├── v0.0 │ │ ├── alex.pth │ │ ├── squeeze.pth │ │ └── vgg.pth │ │ └── v0.1 │ │ ├── alex.pth │ │ ├── squeeze.pth │ │ └── vgg.pth └── projection.py ├── train_controller.py ├── train_generator.py ├── trainers ├── __init__.py ├── controller_trainer.py ├── generator_trainer.py ├── non_leaking.py └── utils.py └── utils ├── __init__.py ├── file_utils.py ├── hopenet_utils.py ├── logging_utils.py ├── mini_batch_multi_split_utils.py ├── mini_batch_random_multi_split_utils.py ├── mini_batch_utils.py ├── pandas_utils.py ├── pil_images_utils.py ├── ploting_utils.py ├── spherical_harmonics_utils.py └── tensor_transforms.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | -------------------------------------------------------------------------------- /face-alignment/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2017, Adrian Bulat 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /face-alignment/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/face-alignment/__init__.py -------------------------------------------------------------------------------- /face-alignment/conda/conda_upload.sh: -------------------------------------------------------------------------------- 1 | PCG_NAME=face_alignment 2 | USER=1adrianb 3 | 4 | mkdir ~/conda-build 5 | conda config --set anaconda_upload no 6 | conda build conda/ 7 | anaconda -t $CONDA_UPLOAD_TOKEN upload -u $USER /home/travis/miniconda/envs/test-environment/conda-bld/noarch/face_alignment-1.1.1-py_1.tar.bz2 --force -------------------------------------------------------------------------------- /face-alignment/conda/meta.yaml: -------------------------------------------------------------------------------- 1 | {% set version = "1.1.1" %} 2 | 3 | package: 4 | name: face_alignment 5 | version: {{ version }} 6 | 7 | source: 8 | path: .. 9 | 10 | build: 11 | number: 1 12 | noarch: python 13 | script: python setup.py install --single-version-externally-managed --record=record.txt 14 | 15 | requirements: 16 | build: 17 | - setuptools 18 | - python 19 | run: 20 | - python 21 | - pytorch 22 | - numpy 23 | - scikit-image 24 | - scipy 25 | - opencv 26 | - tqdm 27 | 28 | about: 29 | home: https://github.com/1adrianb/face-alignment 30 | license: BSD 31 | license_file: LICENSE 32 | summary: A 2D and 3D face alignment libray in python 33 | 34 | extra: 35 | recipe-maintainers: 36 | - 1adrianb 37 | -------------------------------------------------------------------------------- /face-alignment/examples/detect_landmarks_in_image.py: -------------------------------------------------------------------------------- 1 | import face_alignment 2 | import matplotlib.pyplot as plt 3 | from mpl_toolkits.mplot3d import Axes3D 4 | from skimage import io 5 | import collections 6 | 7 | 8 | # Run the 3D face alignment on a test image, without CUDA. 9 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._3D, device='cpu', flip_input=True) 10 | 11 | try: 12 | input_img = io.imread('../test/assets/aflw-test.jpg') 13 | except FileNotFoundError: 14 | input_img = io.imread('test/assets/aflw-test.jpg') 15 | 16 | preds = fa.get_landmarks(input_img)[-1] 17 | 18 | # 2D-Plot 19 | plot_style = dict(marker='o', 20 | markersize=4, 21 | linestyle='-', 22 | lw=2) 23 | 24 | pred_type = collections.namedtuple('prediction_type', ['slice', 'color']) 25 | pred_types = {'face': pred_type(slice(0, 17), (0.682, 0.780, 0.909, 0.5)), 26 | 'eyebrow1': pred_type(slice(17, 22), (1.0, 0.498, 0.055, 0.4)), 27 | 'eyebrow2': pred_type(slice(22, 27), (1.0, 0.498, 0.055, 0.4)), 28 | 'nose': pred_type(slice(27, 31), (0.345, 0.239, 0.443, 0.4)), 29 | 'nostril': pred_type(slice(31, 36), (0.345, 0.239, 0.443, 0.4)), 30 | 'eye1': pred_type(slice(36, 42), (0.596, 0.875, 0.541, 0.3)), 31 | 'eye2': pred_type(slice(42, 48), (0.596, 0.875, 0.541, 0.3)), 32 | 'lips': pred_type(slice(48, 60), (0.596, 0.875, 0.541, 0.3)), 33 | 'teeth': pred_type(slice(60, 68), (0.596, 0.875, 0.541, 0.4)) 34 | } 35 | 36 | fig = plt.figure(figsize=plt.figaspect(.5)) 37 | ax = fig.add_subplot(1, 2, 1) 38 | ax.imshow(input_img) 39 | 40 | for pred_type in pred_types.values(): 41 | ax.plot(preds[pred_type.slice, 0], 42 | preds[pred_type.slice, 1], 43 | color=pred_type.color, **plot_style) 44 | 45 | ax.axis('off') 46 | 47 | # 3D-Plot 48 | ax = fig.add_subplot(1, 2, 2, projection='3d') 49 | surf = ax.scatter(preds[:, 0] * 1.2, 50 | preds[:, 1], 51 | preds[:, 2], 52 | c='cyan', 53 | alpha=1.0, 54 | edgecolor='b') 55 | 56 | for pred_type in pred_types.values(): 57 | ax.plot3D(preds[pred_type.slice, 0] * 1.2, 58 | preds[pred_type.slice, 1], 59 | preds[pred_type.slice, 2], color='blue') 60 | 61 | ax.view_init(elev=90., azim=90.) 62 | ax.set_xlim(ax.get_xlim()[::-1]) 63 | plt.show() 64 | -------------------------------------------------------------------------------- /face-alignment/face_alignment/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | __author__ = """Adrian Bulat""" 4 | __email__ = 'adrian.bulat@nottingham.ac.uk' 5 | __version__ = '1.1.1' 6 | 7 | from .api import FaceAlignment, LandmarksType, NetworkSize 8 | -------------------------------------------------------------------------------- /face-alignment/face_alignment/detection/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import FaceDetector -------------------------------------------------------------------------------- /face-alignment/face_alignment/detection/blazeface/__init__.py: -------------------------------------------------------------------------------- 1 | from .blazeface_detector import BlazeFaceDetector as FaceDetector 2 | -------------------------------------------------------------------------------- /face-alignment/face_alignment/detection/blazeface/blazeface_detector.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | from torch.utils.model_zoo import load_url 4 | 5 | from ..core import FaceDetector 6 | 7 | from .net_blazeface import BlazeFace 8 | from .detect import * 9 | 10 | import requests 11 | import io 12 | 13 | 14 | def load_numpy_from_url(url): 15 | response = requests.get(url) 16 | response.raise_for_status() 17 | data = np.load(io.BytesIO(response.content)) # Works! 18 | return data 19 | 20 | models_urls = { 21 | 'blazeface_weights': 'https://github.com/hollance/BlazeFace-PyTorch/blob/master/blazeface.pth?raw=true', 22 | 'blazeface_anchors': 'https://github.com/hollance/BlazeFace-PyTorch/blob/master/anchors.npy?raw=true' 23 | } 24 | 25 | 26 | class BlazeFaceDetector(FaceDetector): 27 | def __init__(self, device, path_to_detector=None, path_to_anchor=None, verbose=False): 28 | super(BlazeFaceDetector, self).__init__(device, verbose) 29 | 30 | # Initialise the face detector 31 | if path_to_detector is None: 32 | model_weights = load_url(models_urls['blazeface_weights']) 33 | model_anchors = load_numpy_from_url(models_urls['blazeface_anchors']) 34 | else: 35 | model_weights = torch.load(path_to_detector) 36 | model_anchors = np.load(path_to_anchor) 37 | 38 | self.face_detector = BlazeFace() 39 | self.face_detector.load_state_dict(model_weights) 40 | self.face_detector.load_anchors_from_npy(model_anchors, device) 41 | 42 | # Optionally change the thresholds: 43 | self.face_detector.min_score_thresh = 0.5 44 | self.face_detector.min_suppression_threshold = 0.3 45 | 46 | self.face_detector.to(device) 47 | self.face_detector.eval() 48 | 49 | def detect_from_image(self, tensor_or_path): 50 | image = self.tensor_or_path_to_ndarray(tensor_or_path) 51 | 52 | bboxlist = detect(self.face_detector, image, device=self.device)[0] 53 | 54 | return bboxlist 55 | 56 | def detect_from_batch(self, tensor): 57 | bboxlists = batch_detect(self.face_detector, tensor, device=self.device) 58 | return bboxlists 59 | 60 | @property 61 | def reference_scale(self): 62 | return 195 63 | 64 | @property 65 | def reference_x_shift(self): 66 | return 0 67 | 68 | @property 69 | def reference_y_shift(self): 70 | return 0 71 | -------------------------------------------------------------------------------- /face-alignment/face_alignment/detection/blazeface/detect.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | import os 5 | import sys 6 | import cv2 7 | import random 8 | import datetime 9 | import math 10 | import argparse 11 | import numpy as np 12 | 13 | import scipy.io as sio 14 | import zipfile 15 | 16 | from .utils import * 17 | # from .net_blazeface import s3fd 18 | 19 | 20 | def detect(net, img, device): 21 | H, W, C = img.shape 22 | orig_size = min(H, W) 23 | img, (xshift, yshift) = resize_and_crop_image(img, 128) 24 | preds = net.predict_on_image(img) 25 | 26 | if 0 == len(preds): 27 | return np.zeros((1, 1, 5)) 28 | 29 | shift = np.array([xshift, yshift] * 2) 30 | scores = preds[:, -1:] 31 | 32 | # TODO: ugly 33 | # reverses, x and y to adapt with face-alignment code 34 | locs = np.concatenate((preds[:, 1:2], preds[:, 0:1], preds[:, 3:4], preds[:, 2:3]), axis=1) 35 | return [np.concatenate((locs * orig_size + shift, scores), axis=1)] 36 | 37 | 38 | def batch_detect(net, img_batch, device): 39 | """ 40 | Inputs: 41 | - img_batch: a numpy array of shape (Batch size, Channels, Height, Width) 42 | """ 43 | B, C, H, W = img_batch.shape 44 | orig_size = min(H, W) 45 | 46 | # BB, HH, WW = img_batch.shape 47 | # if img_batch 48 | if isinstance(img_batch, torch.Tensor): 49 | img_batch = img_batch.cpu().numpy() 50 | img_batch = img_batch.transpose((0, 2, 3, 1)) 51 | 52 | imgs, (xshift, yshift) = resize_and_crop_batch(img_batch, 128) 53 | preds = net.predict_on_batch(imgs) 54 | bboxlists = [] 55 | for pred in preds: 56 | shift = np.array([xshift, yshift] * 2) 57 | scores = pred[:, -1:] 58 | locs = np.concatenate((pred[:, 1:2], pred[:, 0:1], pred[:, 3:4], pred[:, 2:3]), axis=1) 59 | bboxlists.append(np.concatenate((locs * orig_size + shift, scores), axis=1)) 60 | 61 | if 0 == len(bboxlists): 62 | bboxlists = np.zeros((1, 1, 5)) 63 | 64 | return bboxlists 65 | 66 | 67 | def flip_detect(net, img, device): 68 | img = cv2.flip(img, 1) 69 | b = detect(net, img, device) 70 | 71 | bboxlist = np.zeros(b.shape) 72 | bboxlist[:, 0] = img.shape[1] - b[:, 2] 73 | bboxlist[:, 1] = b[:, 1] 74 | bboxlist[:, 2] = img.shape[1] - b[:, 0] 75 | bboxlist[:, 3] = b[:, 3] 76 | bboxlist[:, 4] = b[:, 4] 77 | return bboxlist 78 | 79 | 80 | def pts_to_bb(pts): 81 | min_x, min_y = np.min(pts, axis=0) 82 | max_x, max_y = np.max(pts, axis=0) 83 | return np.array([min_x, min_y, max_x, max_y]) 84 | -------------------------------------------------------------------------------- /face-alignment/face_alignment/detection/blazeface/utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | def image_resize(image, width=None, height=None, inter=cv2.INTER_AREA): 6 | # initialize the dimensions of the image to be resized and 7 | # grab the image size 8 | dim = None 9 | (h, w) = image.shape[:2] 10 | 11 | # if both the width and height are None, then return the 12 | # original image 13 | if width is None and height is None: 14 | return image 15 | 16 | # check to see if the width is None 17 | if width is None: 18 | # calculate the ratio of the height and construct the 19 | # dimensions 20 | r = height / float(h) 21 | dim = (int(w * r), height) 22 | 23 | # otherwise, the height is None 24 | else: 25 | # calculate the ratio of the width and construct the 26 | # dimensions 27 | r = width / float(w) 28 | dim = (width, int(h * r)) 29 | 30 | # resize the image 31 | resized = cv2.resize(image, dim, interpolation=inter) 32 | 33 | # return the resized image 34 | return resized 35 | 36 | 37 | def resize_and_crop_image(image, dim): 38 | if image.shape[0] > image.shape[1]: 39 | img = image_resize(image, width=dim) 40 | yshift, xshift = (image.shape[0] - image.shape[1]) // 2, 0 41 | y_start = (img.shape[0] - img.shape[1]) // 2 42 | y_end = y_start + dim 43 | return img[y_start:y_end, :, :], (xshift, yshift) 44 | else: 45 | img = image_resize(image, height=dim) 46 | yshift, xshift = 0, (image.shape[1] - image.shape[0]) // 2 47 | x_start = (img.shape[1] - img.shape[0]) // 2 48 | x_end = x_start + dim 49 | return img[:, x_start:x_end, :], (xshift, yshift) 50 | 51 | 52 | def resize_and_crop_batch(frames, dim): 53 | """ 54 | Center crop + resize to (dim x dim) 55 | inputs: 56 | - frames: list of images (numpy arrays) 57 | - dim: output dimension size 58 | """ 59 | smframes = [] 60 | xshift, yshift = 0, 0 61 | for i in range(len(frames)): 62 | smframe, (xshift, yshift) = resize_and_crop_image(frames[i], dim) 63 | smframes.append(smframe) 64 | smframes = np.stack(smframes) 65 | return smframes, (xshift, yshift) 66 | -------------------------------------------------------------------------------- /face-alignment/face_alignment/detection/core.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import glob 3 | from tqdm import tqdm 4 | import numpy as np 5 | import torch 6 | import cv2 7 | from skimage import io 8 | 9 | 10 | class FaceDetector(object): 11 | """An abstract class representing a face detector. 12 | 13 | Any other face detection implementation must subclass it. All subclasses 14 | must implement ``detect_from_image``, that return a list of detected 15 | bounding boxes. Optionally, for speed considerations detect from path is 16 | recommended. 17 | """ 18 | 19 | def __init__(self, device, verbose): 20 | self.device = device 21 | self.verbose = verbose 22 | 23 | if verbose: 24 | if 'cpu' in device: 25 | logger = logging.getLogger(__name__) 26 | logger.warning("Detection running on CPU, this may be potentially slow.") 27 | 28 | if 'cpu' not in device and 'cuda' not in device: 29 | if verbose: 30 | logger.error("Expected values for device are: {cpu, cuda} but got: %s", device) 31 | raise ValueError 32 | 33 | def detect_from_image(self, tensor_or_path): 34 | """Detects faces in a given image. 35 | 36 | This function detects the faces present in a provided BGR(usually) 37 | image. The input can be either the image itself or the path to it. 38 | 39 | Arguments: 40 | tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path 41 | to an image or the image itself. 42 | 43 | Example:: 44 | 45 | >>> path_to_image = 'data/image_01.jpg' 46 | ... detected_faces = detect_from_image(path_to_image) 47 | [A list of bounding boxes (x1, y1, x2, y2)] 48 | >>> image = cv2.imread(path_to_image) 49 | ... detected_faces = detect_from_image(image) 50 | [A list of bounding boxes (x1, y1, x2, y2)] 51 | 52 | """ 53 | raise NotImplementedError 54 | 55 | def detect_from_batch(self, tensor): 56 | """Detects faces in a given image. 57 | 58 | This function detects the faces present in a provided BGR(usually) 59 | image. The input can be either the image itself or the path to it. 60 | 61 | Arguments: 62 | tensor {torch.tensor} -- image batch tensor. 63 | 64 | Example:: 65 | 66 | >>> path_to_image = 'data/image_01.jpg' 67 | ... detected_faces = detect_from_image(path_to_image) 68 | [A list of bounding boxes (x1, y1, x2, y2)] 69 | >>> image = cv2.imread(path_to_image) 70 | ... detected_faces = detect_from_image(image) 71 | [A list of bounding boxes (x1, y1, x2, y2)] 72 | 73 | """ 74 | raise NotImplementedError 75 | 76 | def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True): 77 | """Detects faces from all the images present in a given directory. 78 | 79 | Arguments: 80 | path {string} -- a string containing a path that points to the folder containing the images 81 | 82 | Keyword Arguments: 83 | extensions {list} -- list of string containing the extensions to be 84 | consider in the following format: ``.extension_name`` (default: 85 | {['.jpg', '.png']}) recursive {bool} -- option wherever to scan the 86 | folder recursively (default: {False}) show_progress_bar {bool} -- 87 | display a progressbar (default: {True}) 88 | 89 | Example: 90 | >>> directory = 'data' 91 | ... detected_faces = detect_from_directory(directory) 92 | {A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]} 93 | 94 | """ 95 | if self.verbose: 96 | logger = logging.getLogger(__name__) 97 | 98 | if len(extensions) == 0: 99 | if self.verbose: 100 | logger.error("Expected at list one extension, but none was received.") 101 | raise ValueError 102 | 103 | if self.verbose: 104 | logger.info("Constructing the list of images.") 105 | additional_pattern = '/**/*' if recursive else '/*' 106 | files = [] 107 | for extension in extensions: 108 | files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive)) 109 | 110 | if self.verbose: 111 | logger.info("Finished searching for images. %s images found", len(files)) 112 | logger.info("Preparing to run the detection.") 113 | 114 | predictions = {} 115 | for image_path in tqdm(files, disable=not show_progress_bar): 116 | if self.verbose: 117 | logger.info("Running the face detector on image: %s", image_path) 118 | predictions[image_path] = self.detect_from_image(image_path) 119 | 120 | if self.verbose: 121 | logger.info("The detector was successfully run on all %s images", len(files)) 122 | 123 | return predictions 124 | 125 | @property 126 | def reference_scale(self): 127 | raise NotImplementedError 128 | 129 | @property 130 | def reference_x_shift(self): 131 | raise NotImplementedError 132 | 133 | @property 134 | def reference_y_shift(self): 135 | raise NotImplementedError 136 | 137 | @staticmethod 138 | def tensor_or_path_to_ndarray(tensor_or_path, rgb=True): 139 | """Convert path (represented as a string) or torch.tensor to a numpy.ndarray 140 | 141 | Arguments: 142 | tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself 143 | """ 144 | if isinstance(tensor_or_path, str): 145 | return cv2.imread(tensor_or_path) if not rgb else io.imread(tensor_or_path) 146 | elif torch.is_tensor(tensor_or_path): 147 | # Call cpu in case its coming from cuda 148 | return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy() 149 | elif isinstance(tensor_or_path, np.ndarray): 150 | return tensor_or_path[..., ::-1].copy() if not rgb else tensor_or_path 151 | else: 152 | raise TypeError 153 | -------------------------------------------------------------------------------- /face-alignment/face_alignment/detection/dlib/__init__.py: -------------------------------------------------------------------------------- 1 | from .dlib_detector import DlibDetector as FaceDetector -------------------------------------------------------------------------------- /face-alignment/face_alignment/detection/dlib/dlib_detector.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import dlib 4 | 5 | try: 6 | import urllib.request as request_file 7 | except BaseException: 8 | import urllib as request_file 9 | 10 | from ..core import FaceDetector 11 | from ...utils import appdata_dir 12 | 13 | 14 | class DlibDetector(FaceDetector): 15 | def __init__(self, device, path_to_detector=None, verbose=False): 16 | super().__init__(device, verbose) 17 | 18 | print('Warning: this detector is deprecated. Please use a different one, i.e.: S3FD.') 19 | base_path = os.path.join(appdata_dir('face_alignment'), "data") 20 | 21 | # Initialise the face detector 22 | if 'cuda' in device: 23 | if path_to_detector is None: 24 | path_to_detector = os.path.join( 25 | base_path, "mmod_human_face_detector.dat") 26 | 27 | if not os.path.isfile(path_to_detector): 28 | print("Downloading the face detection CNN. Please wait...") 29 | 30 | path_to_temp_detector = os.path.join( 31 | base_path, "mmod_human_face_detector.dat.download") 32 | 33 | if os.path.isfile(path_to_temp_detector): 34 | os.remove(os.path.join(path_to_temp_detector)) 35 | 36 | request_file.urlretrieve( 37 | "https://www.adrianbulat.com/downloads/dlib/mmod_human_face_detector.dat", 38 | os.path.join(path_to_temp_detector)) 39 | 40 | os.rename(os.path.join(path_to_temp_detector), os.path.join(path_to_detector)) 41 | 42 | self.face_detector = dlib.cnn_face_detection_model_v1(path_to_detector) 43 | else: 44 | self.face_detector = dlib.get_frontal_face_detector() 45 | 46 | def detect_from_image(self, tensor_or_path): 47 | image = self.tensor_or_path_to_ndarray(tensor_or_path, rgb=False) 48 | 49 | detected_faces = self.face_detector(cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)) 50 | 51 | if 'cuda' not in self.device: 52 | detected_faces = [[d.left(), d.top(), d.right(), d.bottom()] for d in detected_faces] 53 | else: 54 | detected_faces = [[d.rect.left(), d.rect.top(), d.rect.right(), d.rect.bottom()] for d in detected_faces] 55 | 56 | return detected_faces 57 | 58 | @property 59 | def reference_scale(self): 60 | return 195 61 | 62 | @property 63 | def reference_x_shift(self): 64 | return 0 65 | 66 | @property 67 | def reference_y_shift(self): 68 | return 0 69 | -------------------------------------------------------------------------------- /face-alignment/face_alignment/detection/folder/__init__.py: -------------------------------------------------------------------------------- 1 | from .folder_detector import FolderDetector as FaceDetector -------------------------------------------------------------------------------- /face-alignment/face_alignment/detection/folder/folder_detector.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | 5 | from ..core import FaceDetector 6 | 7 | 8 | class FolderDetector(FaceDetector): 9 | '''This is a simple helper module that assumes the faces were detected already 10 | (either previously or are provided as ground truth). 11 | 12 | The class expects to find the bounding boxes in the same format used by 13 | the rest of face detectors, mainly ``list[(x1,y1,x2,y2),...]``. 14 | For each image the detector will search for a file with the same name and with one of the 15 | following extensions: .npy, .t7 or .pth 16 | 17 | ''' 18 | 19 | def __init__(self, device, path_to_detector=None, verbose=False): 20 | super(FolderDetector, self).__init__(device, verbose) 21 | 22 | def detect_from_image(self, tensor_or_path): 23 | # Only strings supported 24 | if not isinstance(tensor_or_path, str): 25 | raise ValueError 26 | 27 | base_name = os.path.splitext(tensor_or_path)[0] 28 | 29 | if os.path.isfile(base_name + '.npy'): 30 | detected_faces = np.load(base_name + '.npy') 31 | elif os.path.isfile(base_name + '.t7'): 32 | detected_faces = torch.load(base_name + '.t7') 33 | elif os.path.isfile(base_name + '.pth'): 34 | detected_faces = torch.load(base_name + '.pth') 35 | else: 36 | raise FileNotFoundError 37 | 38 | if not isinstance(detected_faces, list): 39 | raise TypeError 40 | 41 | return detected_faces 42 | 43 | @property 44 | def reference_scale(self): 45 | return 195 46 | 47 | @property 48 | def reference_x_shift(self): 49 | return 0 50 | 51 | @property 52 | def reference_y_shift(self): 53 | return 0 54 | -------------------------------------------------------------------------------- /face-alignment/face_alignment/detection/sfd/__init__.py: -------------------------------------------------------------------------------- 1 | from .sfd_detector import SFDDetector as FaceDetector -------------------------------------------------------------------------------- /face-alignment/face_alignment/detection/sfd/bbox.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import sys 4 | import cv2 5 | import random 6 | import datetime 7 | import time 8 | import math 9 | import argparse 10 | import numpy as np 11 | import torch 12 | 13 | try: 14 | from iou import IOU 15 | except BaseException: 16 | # IOU cython speedup 10x 17 | def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2): 18 | sa = abs((ax2 - ax1) * (ay2 - ay1)) 19 | sb = abs((bx2 - bx1) * (by2 - by1)) 20 | x1, y1 = max(ax1, bx1), max(ay1, by1) 21 | x2, y2 = min(ax2, bx2), min(ay2, by2) 22 | w = x2 - x1 23 | h = y2 - y1 24 | if w < 0 or h < 0: 25 | return 0.0 26 | else: 27 | return 1.0 * w * h / (sa + sb - w * h) 28 | 29 | 30 | def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh): 31 | xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1 32 | dx, dy = (xc - axc) / aww, (yc - ayc) / ahh 33 | dw, dh = math.log(ww / aww), math.log(hh / ahh) 34 | return dx, dy, dw, dh 35 | 36 | 37 | def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh): 38 | xc, yc = dx * aww + axc, dy * ahh + ayc 39 | ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh 40 | x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2 41 | return x1, y1, x2, y2 42 | 43 | 44 | def nms(dets, thresh): 45 | if 0 == len(dets): 46 | return [] 47 | x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4] 48 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 49 | order = scores.argsort()[::-1] 50 | 51 | keep = [] 52 | while order.size > 0: 53 | i = order[0] 54 | keep.append(i) 55 | xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]]) 56 | xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]]) 57 | 58 | w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1) 59 | ovr = w * h / (areas[i] + areas[order[1:]] - w * h) 60 | 61 | inds = np.where(ovr <= thresh)[0] 62 | order = order[inds + 1] 63 | 64 | return keep 65 | 66 | 67 | def encode(matched, priors, variances): 68 | """Encode the variances from the priorbox layers into the ground truth boxes 69 | we have matched (based on jaccard overlap) with the prior boxes. 70 | Args: 71 | matched: (tensor) Coords of ground truth for each prior in point-form 72 | Shape: [num_priors, 4]. 73 | priors: (tensor) Prior boxes in center-offset form 74 | Shape: [num_priors,4]. 75 | variances: (list[float]) Variances of priorboxes 76 | Return: 77 | encoded boxes (tensor), Shape: [num_priors, 4] 78 | """ 79 | 80 | # dist b/t match center and prior's center 81 | g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2] 82 | # encode variance 83 | g_cxcy /= (variances[0] * priors[:, 2:]) 84 | # match wh / prior wh 85 | g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:] 86 | g_wh = torch.log(g_wh) / variances[1] 87 | # return target for smooth_l1_loss 88 | return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4] 89 | 90 | 91 | def decode(loc, priors, variances): 92 | """Decode locations from predictions using priors to undo 93 | the encoding we did for offset regression at train time. 94 | Args: 95 | loc (tensor): location predictions for loc layers, 96 | Shape: [num_priors,4] 97 | priors (tensor): Prior boxes in center-offset form. 98 | Shape: [num_priors,4]. 99 | variances: (list[float]) Variances of priorboxes 100 | Return: 101 | decoded bounding box predictions 102 | """ 103 | 104 | boxes = torch.cat(( 105 | priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], 106 | priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) 107 | boxes[:, :2] -= boxes[:, 2:] / 2 108 | boxes[:, 2:] += boxes[:, :2] 109 | return boxes 110 | -------------------------------------------------------------------------------- /face-alignment/face_alignment/detection/sfd/detect.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | import os 5 | import sys 6 | import cv2 7 | import random 8 | import datetime 9 | import math 10 | import argparse 11 | import numpy as np 12 | 13 | import scipy.io as sio 14 | import zipfile 15 | from .net_s3fd import s3fd 16 | from .bbox import * 17 | 18 | 19 | def detect(net, img, device): 20 | img = img - np.array([104, 117, 123]) 21 | img = img.transpose(2, 0, 1) 22 | # Creates a batch of 1 23 | img = img.reshape((1,) + img.shape) 24 | 25 | if 'cuda' in device: 26 | torch.backends.cudnn.benchmark = True 27 | 28 | img = torch.from_numpy(img).float().to(device) 29 | 30 | return batch_detect(net, img, device) 31 | 32 | 33 | def batch_detect(net, img_batch, device): 34 | """ 35 | Inputs: 36 | - img_batch: a torch.Tensor of shape (Batch size, Channels, Height, Width) 37 | """ 38 | 39 | if 'cuda' in device: 40 | torch.backends.cudnn.benchmark = True 41 | 42 | BB, CC, HH, WW = img_batch.size() 43 | 44 | with torch.no_grad(): 45 | olist = net(img_batch.float()) # patched uint8_t overflow error 46 | 47 | for i in range(len(olist) // 2): 48 | olist[i * 2] = F.softmax(olist[i * 2], dim=1) 49 | 50 | bboxlists = [] 51 | 52 | olist = [oelem.data.cpu() for oelem in olist] 53 | 54 | for j in range(BB): 55 | bboxlist = [] 56 | for i in range(len(olist) // 2): 57 | ocls, oreg = olist[i * 2], olist[i * 2 + 1] 58 | FB, FC, FH, FW = ocls.size() # feature map size 59 | stride = 2**(i + 2) # 4,8,16,32,64,128 60 | anchor = stride * 4 61 | poss = zip(*np.where(ocls[:, 1, :, :] > 0.05)) 62 | for Iindex, hindex, windex in poss: 63 | axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride 64 | score = ocls[j, 1, hindex, windex] 65 | loc = oreg[j, :, hindex, windex].contiguous().view(1, 4) 66 | priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]) 67 | variances = [0.1, 0.2] 68 | box = decode(loc, priors, variances) 69 | x1, y1, x2, y2 = box[0] * 1.0 70 | bboxlist.append([x1, y1, x2, y2, score]) 71 | 72 | bboxlists.append(bboxlist) 73 | 74 | bboxlists = np.array(bboxlists) 75 | 76 | if 0 == len(bboxlists): 77 | bboxlists = np.zeros((1, 1, 5)) 78 | 79 | return bboxlists 80 | 81 | 82 | def flip_detect(net, img, device): 83 | img = cv2.flip(img, 1) 84 | b = detect(net, img, device) 85 | 86 | bboxlist = np.zeros(b.shape) 87 | bboxlist[:, 0] = img.shape[1] - b[:, 2] 88 | bboxlist[:, 1] = b[:, 1] 89 | bboxlist[:, 2] = img.shape[1] - b[:, 0] 90 | bboxlist[:, 3] = b[:, 3] 91 | bboxlist[:, 4] = b[:, 4] 92 | return bboxlist 93 | 94 | 95 | def pts_to_bb(pts): 96 | min_x, min_y = np.min(pts, axis=0) 97 | max_x, max_y = np.max(pts, axis=0) 98 | return np.array([min_x, min_y, max_x, max_y]) 99 | -------------------------------------------------------------------------------- /face-alignment/face_alignment/detection/sfd/net_s3fd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class L2Norm(nn.Module): 7 | def __init__(self, n_channels, scale=1.0): 8 | super(L2Norm, self).__init__() 9 | self.n_channels = n_channels 10 | self.scale = scale 11 | self.eps = 1e-10 12 | self.weight = nn.Parameter(torch.Tensor(self.n_channels)) 13 | self.weight.data *= 0.0 14 | self.weight.data += self.scale 15 | 16 | def forward(self, x): 17 | norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps 18 | x = x / norm * self.weight.view(1, -1, 1, 1) 19 | return x 20 | 21 | 22 | class s3fd(nn.Module): 23 | def __init__(self): 24 | super(s3fd, self).__init__() 25 | self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) 26 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 27 | 28 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) 29 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) 30 | 31 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1) 32 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 33 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 34 | 35 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1) 36 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 37 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 38 | 39 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 40 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 41 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 42 | 43 | self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3) 44 | self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0) 45 | 46 | self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0) 47 | self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1) 48 | 49 | self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0) 50 | self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1) 51 | 52 | self.conv3_3_norm = L2Norm(256, scale=10) 53 | self.conv4_3_norm = L2Norm(512, scale=8) 54 | self.conv5_3_norm = L2Norm(512, scale=5) 55 | 56 | self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) 57 | self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) 58 | self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) 59 | self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) 60 | self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) 61 | self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) 62 | 63 | self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1) 64 | self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1) 65 | self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) 66 | self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) 67 | self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1) 68 | self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) 69 | 70 | def forward(self, x): 71 | h = F.relu(self.conv1_1(x)) 72 | h = F.relu(self.conv1_2(h)) 73 | h = F.max_pool2d(h, 2, 2) 74 | 75 | h = F.relu(self.conv2_1(h)) 76 | h = F.relu(self.conv2_2(h)) 77 | h = F.max_pool2d(h, 2, 2) 78 | 79 | h = F.relu(self.conv3_1(h)) 80 | h = F.relu(self.conv3_2(h)) 81 | h = F.relu(self.conv3_3(h)) 82 | f3_3 = h 83 | h = F.max_pool2d(h, 2, 2) 84 | 85 | h = F.relu(self.conv4_1(h)) 86 | h = F.relu(self.conv4_2(h)) 87 | h = F.relu(self.conv4_3(h)) 88 | f4_3 = h 89 | h = F.max_pool2d(h, 2, 2) 90 | 91 | h = F.relu(self.conv5_1(h)) 92 | h = F.relu(self.conv5_2(h)) 93 | h = F.relu(self.conv5_3(h)) 94 | f5_3 = h 95 | h = F.max_pool2d(h, 2, 2) 96 | 97 | h = F.relu(self.fc6(h)) 98 | h = F.relu(self.fc7(h)) 99 | ffc7 = h 100 | h = F.relu(self.conv6_1(h)) 101 | h = F.relu(self.conv6_2(h)) 102 | f6_2 = h 103 | h = F.relu(self.conv7_1(h)) 104 | h = F.relu(self.conv7_2(h)) 105 | f7_2 = h 106 | 107 | f3_3 = self.conv3_3_norm(f3_3) 108 | f4_3 = self.conv4_3_norm(f4_3) 109 | f5_3 = self.conv5_3_norm(f5_3) 110 | 111 | cls1 = self.conv3_3_norm_mbox_conf(f3_3) 112 | reg1 = self.conv3_3_norm_mbox_loc(f3_3) 113 | cls2 = self.conv4_3_norm_mbox_conf(f4_3) 114 | reg2 = self.conv4_3_norm_mbox_loc(f4_3) 115 | cls3 = self.conv5_3_norm_mbox_conf(f5_3) 116 | reg3 = self.conv5_3_norm_mbox_loc(f5_3) 117 | cls4 = self.fc7_mbox_conf(ffc7) 118 | reg4 = self.fc7_mbox_loc(ffc7) 119 | cls5 = self.conv6_2_mbox_conf(f6_2) 120 | reg5 = self.conv6_2_mbox_loc(f6_2) 121 | cls6 = self.conv7_2_mbox_conf(f7_2) 122 | reg6 = self.conv7_2_mbox_loc(f7_2) 123 | 124 | # max-out background label 125 | chunk = torch.chunk(cls1, 4, 1) 126 | bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2]) 127 | cls1 = torch.cat([bmax, chunk[3]], dim=1) 128 | 129 | return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6] 130 | -------------------------------------------------------------------------------- /face-alignment/face_alignment/detection/sfd/sfd_detector.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | from torch.utils.model_zoo import load_url 4 | 5 | from ..core import FaceDetector 6 | 7 | from .net_s3fd import s3fd 8 | from .bbox import * 9 | from .detect import * 10 | 11 | models_urls = { 12 | 's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth', 13 | } 14 | 15 | 16 | class SFDDetector(FaceDetector): 17 | def __init__(self, device, path_to_detector=None, verbose=False): 18 | super(SFDDetector, self).__init__(device, verbose) 19 | 20 | # Initialise the face detector 21 | if path_to_detector is None: 22 | model_weights = load_url(models_urls['s3fd']) 23 | else: 24 | model_weights = torch.load(path_to_detector) 25 | 26 | self.face_detector = s3fd() 27 | self.face_detector.load_state_dict(model_weights) 28 | self.face_detector.to(device) 29 | self.face_detector.eval() 30 | 31 | def detect_from_image(self, tensor_or_path): 32 | image = self.tensor_or_path_to_ndarray(tensor_or_path) 33 | 34 | bboxlist = detect(self.face_detector, image, device=self.device)[0] 35 | keep = nms(bboxlist, 0.3) 36 | bboxlist = bboxlist[keep, :] 37 | bboxlist = [x for x in bboxlist if x[-1] > 0.5] 38 | 39 | return bboxlist 40 | 41 | def detect_from_batch(self, tensor): 42 | bboxlists = batch_detect(self.face_detector, tensor, device=self.device) 43 | 44 | new_bboxlists = [] 45 | for i in range(bboxlists.shape[0]): 46 | bboxlist = bboxlists[i] 47 | keep = nms(bboxlist, 0.3) 48 | bboxlist = bboxlist[keep, :] 49 | bboxlist = [x for x in bboxlist if x[-1] > 0.5] 50 | new_bboxlists.append(bboxlist) 51 | 52 | return new_bboxlists 53 | 54 | @property 55 | def reference_scale(self): 56 | return 195 57 | 58 | @property 59 | def reference_x_shift(self): 60 | return 0 61 | 62 | @property 63 | def reference_y_shift(self): 64 | return 0 65 | -------------------------------------------------------------------------------- /face-alignment/requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python 2 | scipy>=0.17.0 3 | scikit-image 4 | -------------------------------------------------------------------------------- /face-alignment/setup.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 1.1.1 3 | commit = True 4 | tag = True 5 | 6 | [bumpversion:file:setup.py] 7 | search = version='{current_version}' 8 | replace = version='{new_version}' 9 | 10 | [bumpversion:file:face_alignment/__init__.py] 11 | search = __version__ = '{current_version}' 12 | replace = __version__ = '{new_version}' 13 | 14 | [metadata] 15 | description-file = README.md 16 | 17 | [bdist_wheel] 18 | universal = 1 19 | 20 | [flake8] 21 | exclude = 22 | .github, 23 | examples, 24 | docs, 25 | .tox, 26 | bin, 27 | dist, 28 | tools, 29 | *.egg-info, 30 | __init__.py, 31 | *.yml 32 | max-line-length = 160 -------------------------------------------------------------------------------- /face-alignment/setup.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | from os import path 4 | import re 5 | from setuptools import setup, find_packages 6 | # To use consisten encodings 7 | from codecs import open 8 | 9 | # Function from: https://github.com/pytorch/vision/blob/master/setup.py 10 | 11 | 12 | def read(*names, **kwargs): 13 | with io.open( 14 | os.path.join(os.path.dirname(__file__), *names), 15 | encoding=kwargs.get("encoding", "utf8") 16 | ) as fp: 17 | return fp.read() 18 | 19 | # Function from: https://github.com/pytorch/vision/blob/master/setup.py 20 | 21 | 22 | def find_version(*file_paths): 23 | version_file = read(*file_paths) 24 | version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", 25 | version_file, re.M) 26 | if version_match: 27 | return version_match.group(1) 28 | raise RuntimeError("Unable to find version string.") 29 | 30 | here = path.abspath(path.dirname(__file__)) 31 | 32 | # Get the long description from the README file 33 | with open(path.join(here, 'README.md'), encoding='utf-8') as readme_file: 34 | long_description = readme_file.read() 35 | 36 | VERSION = find_version('face_alignment', '__init__.py') 37 | 38 | requirements = [ 39 | 'torch', 40 | 'numpy', 41 | 'scipy>=0.17', 42 | 'scikit-image', 43 | 'opencv-python', 44 | 'tqdm', 45 | 'enum34;python_version<"3.4"' 46 | ] 47 | 48 | setup( 49 | name='face_alignment', 50 | version=VERSION, 51 | 52 | description="Detector 2D or 3D face landmarks from Python", 53 | long_description=long_description, 54 | long_description_content_type="text/markdown", 55 | 56 | # Author details 57 | author="Adrian Bulat", 58 | author_email="adrian.bulat@nottingham.ac.uk", 59 | url="https://github.com/1adrianb/face-alignment", 60 | 61 | # Package info 62 | packages=find_packages(exclude=('test',)), 63 | 64 | install_requires=requirements, 65 | license='BSD', 66 | zip_safe=True, 67 | 68 | classifiers=[ 69 | 'Development Status :: 5 - Production/Stable', 70 | 'Operating System :: OS Independent', 71 | 'License :: OSI Approved :: BSD License', 72 | 'Natural Language :: English', 73 | 74 | # Supported python versions 75 | 'Programming Language :: Python :: 2', 76 | 'Programming Language :: Python :: 2.7', 77 | 'Programming Language :: Python :: 3', 78 | 'Programming Language :: Python :: 3.3', 79 | 'Programming Language :: Python :: 3.4', 80 | 'Programming Language :: Python :: 3.5', 81 | 'Programming Language :: Python :: 3.6', 82 | ], 83 | ) 84 | -------------------------------------------------------------------------------- /face-alignment/test/assets/aflw-test.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/face-alignment/test/assets/aflw-test.jpg -------------------------------------------------------------------------------- /face-alignment/test/facealignment_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import face_alignment 3 | 4 | 5 | class Tester(unittest.TestCase): 6 | def test_predict_points(self): 7 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._3D, device='cpu') 8 | fa.get_landmarks('test/assets/aflw-test.jpg') 9 | 10 | if __name__ == '__main__': 11 | unittest.main() 12 | -------------------------------------------------------------------------------- /face-alignment/test/smoke_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import face_alignment 3 | -------------------------------------------------------------------------------- /face-alignment/test/test_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('.') 3 | import unittest 4 | from face_alignment.utils import * 5 | import numpy as np 6 | import torch 7 | 8 | 9 | class Tester(unittest.TestCase): 10 | def test_flip_is_label(self): 11 | # Generate the points 12 | heatmaps = torch.from_numpy(np.random.randint(1, high=250, size=(68, 64, 64)).astype('float32')) 13 | 14 | flipped_heatmaps = flip(flip(heatmaps.clone(), is_label=True), is_label=True) 15 | 16 | assert np.allclose(heatmaps.numpy(), flipped_heatmaps.numpy()) 17 | 18 | def test_flip_is_image(self): 19 | fake_image = torch.torch.rand(3, 256, 256) 20 | fliped_fake_image = flip(flip(fake_image.clone())) 21 | 22 | assert np.allclose(fake_image.numpy(), fliped_fake_image.numpy()) 23 | 24 | def test_getpreds(self): 25 | pts = torch.from_numpy(np.random.randint(1, high=63, size=(68, 2)).astype('float32')) 26 | 27 | heatmaps = np.zeros((68, 256, 256)) 28 | for i in range(68): 29 | if pts[i, 0] > 0: 30 | heatmaps[i] = draw_gaussian(heatmaps[i], pts[i], 2) 31 | heatmaps = torch.from_numpy(np.expand_dims(heatmaps, axis=0)) 32 | 33 | preds, _ = get_preds_fromhm(heatmaps) 34 | 35 | assert np.allclose(pts.numpy(), preds.numpy(), atol=5) 36 | 37 | def test_create_heatmaps(self): 38 | reference_scale = 195 39 | target_landmarks = torch.randint(0, 255, (1, 68, 2)).type(torch.float) # simulated dataset 40 | bb = create_bounding_box(target_landmarks) 41 | centers = torch.stack([bb[:, 2] - (bb[:, 2] - bb[:, 0]) / 2.0, bb[:, 3] - (bb[:, 3] - bb[:, 1]) / 2.0], dim=1) 42 | centers[:, 1] = centers[:, 1] - (bb[:, 3] - bb[:, 1]) * 0.12 # Not sure where 0.12 comes from 43 | scales = (bb[:, 2] - bb[:, 0] + bb[:, 3] - bb[:, 1]) / reference_scale 44 | heatmaps = create_target_heatmap(target_landmarks, centers, scales) 45 | preds = get_preds_fromhm(heatmaps, centers.squeeze(), scales.squeeze())[1] 46 | 47 | assert np.allclose(preds.numpy(), target_landmarks.numpy(), atol=5) 48 | 49 | if __name__ == '__main__': 50 | unittest.main() 51 | -------------------------------------------------------------------------------- /face-alignment/tox.ini: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | ignore = E305,E402,E721,F401,F403,F405,F821,F841,F999,W503 -------------------------------------------------------------------------------- /images/controllable_gan.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/images/controllable_gan.gif -------------------------------------------------------------------------------- /images/dog_control.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/images/dog_control.png -------------------------------------------------------------------------------- /images/face_control.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/images/face_control.png -------------------------------------------------------------------------------- /images/painting_control.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/images/painting_control.png -------------------------------------------------------------------------------- /images/pinting_style.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/images/pinting_style.png -------------------------------------------------------------------------------- /notebooks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/notebooks/__init__.py -------------------------------------------------------------------------------- /resources/README.md: -------------------------------------------------------------------------------- 1 | Save pre-trained GANs here -------------------------------------------------------------------------------- /resources/ffhq_1K_attributes_samples_df.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/resources/ffhq_1K_attributes_samples_df.pkl -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/src/__init__.py -------------------------------------------------------------------------------- /src/gan_control/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/src/gan_control/__init__.py -------------------------------------------------------------------------------- /src/gan_control/configs/controller_configs/afhq/default_w_latent_controller.json: -------------------------------------------------------------------------------- 1 | { 2 | "save_name": "afhq_dog001dog005ori01_class_last_layer", 3 | "results_dir": "/mnt/md4/orville/Alon/research_desk/control_models/latent_rec_l1", 4 | "model_config": { 5 | "latent_size": 512, 6 | "size": 512, 7 | "lr_mlp": 0.01, 8 | "n_mlp": 4, 9 | "in_dim": 3, 10 | "mid_dim": 512, 11 | "loss": "orientation_loss" 12 | }, 13 | "training_config": { 14 | "debug": false, 15 | "rec_loss": "l1", 16 | "generator_dir": "/mnt/md4/orville/Alon/research_desk/gan_models/afhq512v4/dog001dog005ori01_class_last_layer_20201012-140053", 17 | "iter": 800000, 18 | "start_iter": 0, 19 | "batch": 128, 20 | "reg_every": 4, 21 | "lr": 0.002, 22 | "parallel": true, 23 | "generate_controls": "sampled_df", 24 | "controller_type": "latent_w", 25 | "sampled_df_path": "/mnt/md4/orville/Alon/research_desk/attributes_dfs/afhq_samples200K_dog001dog005ori01_class_last_layer_20201012-140053_df.pkl", 26 | "min_evaluate_interval": 5000, 27 | "save_images_interval": 5000, 28 | "save_nets_interval": 20000, 29 | "losses": ["latent_rec", "latent_adv_", "attribute_rec_"], 30 | "attribute_rec_w": 0.01 31 | }, 32 | "data_config": { 33 | "data_set_name": "ffhq", 34 | "path": "/mnt/md4/orville/Alon/res/research_gan/ffhq-dataset/images1024x1024", 35 | "workers": 32 36 | }, 37 | "evaluation_config": { 38 | "sample_batch": 16 39 | }, 40 | "tensorboard_config": { 41 | "enabled": true 42 | }, 43 | "monitor_config": { 44 | "enabled": false 45 | }, 46 | "ckpt_config": { 47 | "enabled": false, 48 | "ckpt": "no_ckpt" 49 | } 50 | } -------------------------------------------------------------------------------- /src/gan_control/configs/controller_configs/debug.json: -------------------------------------------------------------------------------- 1 | { 2 | "save_name": "hair_controller", 3 | "results_dir": "/mnt/md4/orville/Alon/res_gan_controllers/ffhq_gans_256X256", 4 | "model_config": { 5 | "latent_size": 512, 6 | "size": 256, 7 | "lr_mlp": 0.01, 8 | "n_mlp": 8, 9 | "in_dim": 3, 10 | "mid_dim": 256, 11 | "loss": "hair_loss" 12 | }, 13 | "training_config": { 14 | "debug": false, 15 | "generator_dir": "/mnt/md4/orville/Alon/res_gan/ffhq_gans_256X256_14_7/id_hair_ffhq_20200714-120438/", 16 | "iter": 800000, 17 | "start_iter": 0, 18 | "batch": 32, 19 | "reg_every": 4, 20 | "lr": 0.002, 21 | "parallel": true, 22 | "generate_controls": "sampled", 23 | 24 | "min_evaluate_interval": 500, 25 | "save_images_interval": 2000, 26 | "save_nets_interval": 10000 27 | }, 28 | "data_config": { 29 | "data_set_name": "ffhq", 30 | "path": "/mnt/md4/orville/Alon/res/research_gan/ffhq-dataset/images1024x1024", 31 | "workers": 32 32 | }, 33 | "evaluation_config": { 34 | "sample_batch": 16 35 | }, 36 | "tensorboard_config": { 37 | "enabled": true 38 | }, 39 | "monitor_config": { 40 | "enabled": false 41 | }, 42 | "ckpt_config": { 43 | "enabled": false, 44 | "ckpt": "no_ckpt" 45 | } 46 | } -------------------------------------------------------------------------------- /src/gan_control/configs/controller_configs/ffhq/age_w_latent_controller.json: -------------------------------------------------------------------------------- 1 | { 2 | "save_name": "id25_or02_ex01_gamma04_hair01_age02", 3 | "results_dir": "/mnt/md4/orville/Alon/research_desk/control_models/latent_rec_l1", 4 | "model_config": { 5 | "latent_size": 512, 6 | "size": 512, 7 | "lr_mlp": 0.01, 8 | "n_mlp": 4, 9 | "in_dim": 1, 10 | "mid_dim": 512, 11 | "loss": "age_loss" 12 | }, 13 | "training_config": { 14 | "debug": false, 15 | "rec_loss": "l1", 16 | "generator_dir": "/mnt/md4/orville/Alon/res_gan/controllable_ffhq512/id25_or02_ex01_gamma04_hair01_age02_20200817-091636", 17 | "iter": 800000, 18 | "start_iter": 0, 19 | "batch": 128, 20 | "reg_every": 4, 21 | "lr": 0.002, 22 | "parallel": true, 23 | "generate_controls": "sampled_df", 24 | "controller_type": "latent_w", 25 | "sampled_df_path": "/mnt/md4/orville/Alon/research_desk/attributes_dfs/samples200K_id25_or02_ex01_gamma04_hair01_age02_20200817-091636_df.pkl", 26 | "min_evaluate_interval": 5000, 27 | "save_images_interval": 5000, 28 | "save_nets_interval": 20000, 29 | "losses": ["latent_rec", "latent_adv_", "attribute_rec_"], 30 | "attribute_rec_w": 0.01 31 | }, 32 | "data_config": { 33 | "data_set_name": "ffhq", 34 | "path": "/mnt/md4/orville/Alon/res/research_gan/ffhq-dataset/images1024x1024", 35 | "workers": 32 36 | }, 37 | "evaluation_config": { 38 | "sample_batch": 16 39 | }, 40 | "tensorboard_config": { 41 | "enabled": true 42 | }, 43 | "monitor_config": { 44 | "enabled": false 45 | }, 46 | "ckpt_config": { 47 | "enabled": false, 48 | "ckpt": "no_ckpt" 49 | } 50 | } -------------------------------------------------------------------------------- /src/gan_control/configs/controller_configs/ffhq/default_w_latent_controller.json: -------------------------------------------------------------------------------- 1 | { 2 | "save_name": "age015id025exp02hai04ori02gam15_normal_copy", 3 | "results_dir": "/mnt/md4/orville/Alon/research_desk/control_models/latent_rec_l1", 4 | "model_config": { 5 | "latent_size": 512, 6 | "size": 512, 7 | "lr_mlp": 0.01, 8 | "n_mlp": 4, 9 | "in_dim": 3, 10 | "mid_dim": 512, 11 | "loss": "hair_loss" 12 | }, 13 | "training_config": { 14 | "debug": false, 15 | "rec_loss": "l1", 16 | "generator_dir": "/mnt/md4/orville/Alon/research_desk/gan_models/ffhq512/age015id025exp02hai04ori02gam15_normal_20200913-121433", 17 | "iter": 800000, 18 | "start_iter": 0, 19 | "batch": 128, 20 | "reg_every": 4, 21 | "lr": 0.002, 22 | "parallel": true, 23 | "generate_controls": "sampled_df", 24 | "controller_type": "latent_w", 25 | "sampled_df_path": "/mnt/md4/orville/Alon/research_desk/attributes_dfs/ffhq_samples200K_align3d_df.pkl", 26 | "min_evaluate_interval": 5000, 27 | "save_images_interval": 5000, 28 | "save_nets_interval": 20000, 29 | "losses": ["latent_rec", "latent_adv_", "attribute_rec_"], 30 | "attribute_rec_w": 0.01 31 | }, 32 | "data_config": { 33 | "data_set_name": "ffhq", 34 | "path": "/mnt/md4/orville/Alon/res/research_gan/ffhq-dataset/images1024x1024", 35 | "workers": 32 36 | }, 37 | "evaluation_config": { 38 | "sample_batch": 16 39 | }, 40 | "tensorboard_config": { 41 | "enabled": true 42 | }, 43 | "monitor_config": { 44 | "enabled": false 45 | }, 46 | "ckpt_config": { 47 | "enabled": false, 48 | "ckpt": "no_ckpt" 49 | } 50 | } -------------------------------------------------------------------------------- /src/gan_control/configs/controller_configs/ffhq/expression3d_w_latent_controller.json: -------------------------------------------------------------------------------- 1 | { 2 | "save_name": "id25_or02_ex01_gamma04_hair01_age02", 3 | "results_dir": "/mnt/md4/orville/Alon/research_desk/control_models/latent_rec_l1", 4 | "model_config": { 5 | "latent_size": 512, 6 | "size": 512, 7 | "lr_mlp": 0.01, 8 | "n_mlp": 4, 9 | "in_dim": 64, 10 | "mid_dim": 512, 11 | "loss": "expression_loss" 12 | }, 13 | "training_config": { 14 | "debug": false, 15 | "rec_loss": "l1", 16 | "generator_dir": "/mnt/md4/orville/Alon/res_gan/controllable_ffhq512/id25_or02_ex01_gamma04_hair01_age02_20200817-091636", 17 | "iter": 800000, 18 | "start_iter": 0, 19 | "batch": 128, 20 | "reg_every": 4, 21 | "lr": 0.002, 22 | "parallel": true, 23 | "generate_controls": "sampled_df", 24 | "controller_type": "latent_w", 25 | "sampled_df_path": "/mnt/md4/orville/Alon/research_desk/attributes_dfs/samples200K_id25_or02_ex01_gamma04_hair01_age02_20200817-091636_df.pkl", 26 | "min_evaluate_interval": 5000, 27 | "save_images_interval": 5000, 28 | "save_nets_interval": 20000, 29 | "losses": ["latent_rec", "latent_adv_", "attribute_rec_"], 30 | "attribute_rec_w": 0.01 31 | }, 32 | "data_config": { 33 | "data_set_name": "ffhq", 34 | "path": "/mnt/md4/orville/Alon/res/research_gan/ffhq-dataset/images1024x1024", 35 | "workers": 32 36 | }, 37 | "evaluation_config": { 38 | "sample_batch": 16 39 | }, 40 | "tensorboard_config": { 41 | "enabled": true 42 | }, 43 | "monitor_config": { 44 | "enabled": false 45 | }, 46 | "ckpt_config": { 47 | "enabled": false, 48 | "ckpt": "no_ckpt" 49 | } 50 | } -------------------------------------------------------------------------------- /src/gan_control/configs/controller_configs/ffhq/gamma_w_latent_controller.json: -------------------------------------------------------------------------------- 1 | { 2 | "save_name": "id25_or02_ex01_gamma04_hair01_age02", 3 | "results_dir": "/mnt/md4/orville/Alon/research_desk/control_models/latent_rec_l1", 4 | "model_config": { 5 | "latent_size": 512, 6 | "size": 512, 7 | "lr_mlp": 0.01, 8 | "n_mlp": 4, 9 | "in_dim": 27, 10 | "mid_dim": 512, 11 | "loss": "gamma_loss" 12 | }, 13 | "training_config": { 14 | "debug": false, 15 | "rec_loss": "l1", 16 | "generator_dir": "/mnt/md4/orville/Alon/res_gan/controllable_ffhq512/id25_or02_ex01_gamma04_hair01_age02_20200817-091636", 17 | "iter": 800000, 18 | "start_iter": 0, 19 | "batch": 128, 20 | "reg_every": 4, 21 | "lr": 0.002, 22 | "parallel": true, 23 | "generate_controls": "sampled_df", 24 | "controller_type": "latent_w", 25 | "sampled_df_path": "/mnt/md4/orville/Alon/research_desk/attributes_dfs/samples200K_id25_or02_ex01_gamma04_hair01_age02_20200817-091636_df.pkl", 26 | "min_evaluate_interval": 5000, 27 | "save_images_interval": 5000, 28 | "save_nets_interval": 20000, 29 | "losses": ["latent_rec", "latent_adv_", "attribute_rec_"], 30 | "attribute_rec_w": 0.01 31 | }, 32 | "data_config": { 33 | "data_set_name": "ffhq", 34 | "path": "/mnt/md4/orville/Alon/res/research_gan/ffhq-dataset/images1024x1024", 35 | "workers": 32 36 | }, 37 | "evaluation_config": { 38 | "sample_batch": 16 39 | }, 40 | "tensorboard_config": { 41 | "enabled": true 42 | }, 43 | "monitor_config": { 44 | "enabled": false 45 | }, 46 | "ckpt_config": { 47 | "enabled": false, 48 | "ckpt": "no_ckpt" 49 | } 50 | } -------------------------------------------------------------------------------- /src/gan_control/configs/controller_configs/ffhq/hair_w_latent_controller.json: -------------------------------------------------------------------------------- 1 | { 2 | "save_name": "id25_or02_ex01_gamma04_hair01_age02", 3 | "results_dir": "/mnt/md4/orville/Alon/research_desk/control_models/latent_rec_l1", 4 | "model_config": { 5 | "latent_size": 512, 6 | "size": 512, 7 | "lr_mlp": 0.01, 8 | "n_mlp": 4, 9 | "in_dim": 3, 10 | "mid_dim": 512, 11 | "loss": "hair_loss" 12 | }, 13 | "training_config": { 14 | "debug": false, 15 | "rec_loss": "l1", 16 | "generator_dir": "/mnt/md4/orville/Alon/res_gan/controllable_ffhq512/id25_or02_ex01_gamma04_hair01_age02_20200817-091636", 17 | "iter": 800000, 18 | "start_iter": 0, 19 | "batch": 128, 20 | "reg_every": 4, 21 | "lr": 0.002, 22 | "parallel": true, 23 | "generate_controls": "sampled_df", 24 | "controller_type": "latent_w", 25 | "sampled_df_path": "/mnt/md4/orville/Alon/research_desk/attributes_dfs/samples200K_id25_or02_ex01_gamma04_hair01_age02_20200817-091636_df.pkl", 26 | "min_evaluate_interval": 5000, 27 | "save_images_interval": 5000, 28 | "save_nets_interval": 20000, 29 | "losses": ["latent_rec", "latent_adv_", "attribute_rec_"], 30 | "attribute_rec_w": 0.01 31 | }, 32 | "data_config": { 33 | "data_set_name": "ffhq", 34 | "path": "/mnt/md4/orville/Alon/res/research_gan/ffhq-dataset/images1024x1024", 35 | "workers": 32 36 | }, 37 | "evaluation_config": { 38 | "sample_batch": 16 39 | }, 40 | "tensorboard_config": { 41 | "enabled": true 42 | }, 43 | "monitor_config": { 44 | "enabled": false 45 | }, 46 | "ckpt_config": { 47 | "enabled": false, 48 | "ckpt": "no_ckpt" 49 | } 50 | } -------------------------------------------------------------------------------- /src/gan_control/configs/controller_configs/ffhq/merged_attr_w_latent_controller.json: -------------------------------------------------------------------------------- 1 | { 2 | "save_name": "merged_attr_vanilla_controller_v0", 3 | "results_dir": "/mnt/md4/orville/Alon/research_desk/control_models/merged_attr", 4 | "model_config": { 5 | "latent_size": 512, 6 | "size": 512, 7 | "lr_mlp": 0.01, 8 | "n_mlp": 4, 9 | "in_dim": 1, 10 | "mid_dim": 512, 11 | "loss": "age_loss" 12 | }, 13 | "training_config": { 14 | "debug": false, 15 | "rec_loss": "l1", 16 | "generator_dir": "/mnt/md4/orville/Alon/research_desk/gan_models/ffhq512/vanilla_20201004-120748", 17 | "iter": 800000, 18 | "start_iter": 0, 19 | "batch": 128, 20 | "reg_every": 4, 21 | "lr": 0.002, 22 | "parallel": true, 23 | "generate_controls": "sampled_df", 24 | "controller_type": "latent_w", 25 | "sampled_df_path": "/mnt/md4/orville/Alon/research_desk/attributes_dfs/ffhq512_vanilla_gan_samples100K_align_3d_emb_df.pkl", 26 | "min_evaluate_interval": 5000, 27 | "save_images_interval": 5000, 28 | "save_nets_interval": 20000, 29 | "losses": ["latent_rec", "latent_adv_", "attribute_rec_"], 30 | "attribute_rec_w": 0.01 31 | }, 32 | "data_config": { 33 | "data_set_name": "ffhq", 34 | "path": "/mnt/md4/orville/Alon/res/research_gan/ffhq-dataset/images1024x1024", 35 | "workers": 32 36 | }, 37 | "evaluation_config": { 38 | "sample_batch": 16 39 | }, 40 | "tensorboard_config": { 41 | "enabled": true 42 | }, 43 | "monitor_config": { 44 | "enabled": false 45 | }, 46 | "ckpt_config": { 47 | "enabled": false, 48 | "ckpt": "no_ckpt" 49 | } 50 | } -------------------------------------------------------------------------------- /src/gan_control/configs/controller_configs/ffhq/orientation_w_latent_controller.json: -------------------------------------------------------------------------------- 1 | { 2 | "save_name": "id25_or02_ex01_gamma04_hair01_age02", 3 | "results_dir": "/mnt/md4/orville/Alon/research_desk/control_models/latent_rec_l1", 4 | "model_config": { 5 | "latent_size": 512, 6 | "size": 512, 7 | "lr_mlp": 0.01, 8 | "n_mlp": 4, 9 | "in_dim": 3, 10 | "mid_dim": 512, 11 | "loss": "orientation_loss" 12 | }, 13 | "training_config": { 14 | "debug": false, 15 | "rec_loss": "l1", 16 | "generator_dir": "/mnt/md4/orville/Alon/res_gan/controllable_ffhq512/id25_or02_ex01_gamma04_hair01_age02_20200817-091636", 17 | "iter": 800000, 18 | "start_iter": 0, 19 | "batch": 128, 20 | "reg_every": 4, 21 | "lr": 0.002, 22 | "parallel": true, 23 | "generate_controls": "sampled_df", 24 | "controller_type": "latent_w", 25 | "sampled_df_path": "/mnt/md4/orville/Alon/research_desk/attributes_dfs/samples200K_id25_or02_ex01_gamma04_hair01_age02_20200817-091636_df.pkl", 26 | "min_evaluate_interval": 5000, 27 | "save_images_interval": 5000, 28 | "save_nets_interval": 20000, 29 | "losses": ["latent_rec", "latent_adv_", "attribute_rec_"], 30 | "attribute_rec_w": 0.01 31 | }, 32 | "data_config": { 33 | "data_set_name": "ffhq", 34 | "path": "/mnt/md4/orville/Alon/res/research_gan/ffhq-dataset/images1024x1024", 35 | "workers": 32 36 | }, 37 | "evaluation_config": { 38 | "sample_batch": 16 39 | }, 40 | "tensorboard_config": { 41 | "enabled": true 42 | }, 43 | "monitor_config": { 44 | "enabled": false 45 | }, 46 | "ckpt_config": { 47 | "enabled": false, 48 | "ckpt": "no_ckpt" 49 | } 50 | } -------------------------------------------------------------------------------- /src/gan_control/configs/controller_configs/ffhq/w_latent_controller_no_negative.json: -------------------------------------------------------------------------------- 1 | { 2 | "save_name": "zero_not_same", 3 | "results_dir": "/mnt/md4/orville/Alon/research_desk/control_models/ablation_study", 4 | "model_config": { 5 | "latent_size": 512, 6 | "size": 512, 7 | "lr_mlp": 0.01, 8 | "n_mlp": 4, 9 | "in_dim": 1, 10 | "mid_dim": 512, 11 | "loss": "age_loss" 12 | }, 13 | "training_config": { 14 | "debug": false, 15 | "rec_loss": "l1", 16 | "generator_dir": "/mnt/md4/orville/Alon/research_desk/gan_models/ffhq512_ablation_study/age015id025exp02hai04ori02gam15_zero_not_same_20210309-081936_copy", 17 | "iter": 800000, 18 | "start_iter": 0, 19 | "batch": 128, 20 | "reg_every": 4, 21 | "lr": 0.002, 22 | "parallel": true, 23 | "generate_controls": "sampled_df", 24 | "controller_type": "latent_w", 25 | "sampled_df_path": "/mnt/md4/orville/Alon/research_desk/attributes_dfs/ffhq_no_negative_samples100K_align3d_df.pkl", 26 | "min_evaluate_interval": 5000, 27 | "save_images_interval": 5000, 28 | "save_nets_interval": 20000, 29 | "losses": ["latent_rec", "latent_adv_", "attribute_rec_"], 30 | "attribute_rec_w": 0.01 31 | }, 32 | "data_config": { 33 | "data_set_name": "ffhq", 34 | "path": "/mnt/md4/orville/Alon/res/research_gan/ffhq-dataset/images1024x1024", 35 | "workers": 32 36 | }, 37 | "evaluation_config": { 38 | "sample_batch": 16 39 | }, 40 | "tensorboard_config": { 41 | "enabled": true 42 | }, 43 | "monitor_config": { 44 | "enabled": false 45 | }, 46 | "ckpt_config": { 47 | "enabled": false, 48 | "ckpt": "no_ckpt" 49 | } 50 | } -------------------------------------------------------------------------------- /src/gan_control/configs/controller_configs/metfaces/default_w_latent_controller.json: -------------------------------------------------------------------------------- 1 | { 2 | "save_name": "metfaces_age015id025exp025ori02sty01_normal_20201029-100516", 3 | "results_dir": "/mnt/md4/orville/Alon/research_desk/control_models/latent_rec_l1", 4 | "model_config": { 5 | "latent_size": 512, 6 | "size": 512, 7 | "lr_mlp": 0.01, 8 | "n_mlp": 4, 9 | "in_dim": 8, 10 | "mid_dim": 512, 11 | "loss": "expression_loss" 12 | }, 13 | "training_config": { 14 | "debug": false, 15 | "rec_loss": "l1", 16 | "generator_dir": "/mnt/md4/orville/Alon/research_desk/gan_models/met-faces512v2/age015id025exp025ori02sty01_normal_20201029-100516", 17 | "iter": 800000, 18 | "start_iter": 0, 19 | "batch": 128, 20 | "reg_every": 4, 21 | "lr": 0.002, 22 | "parallel": true, 23 | "generate_controls": "sampled_df", 24 | "controller_type": "latent_w", 25 | "sampled_df_path": "/mnt/md4/orville/Alon/research_desk/attributes_dfs/metfaces_samples100K_align3d_df.pkl", 26 | "min_evaluate_interval": 5000, 27 | "save_images_interval": 5000, 28 | "save_nets_interval": 20000, 29 | "losses": ["latent_rec", "latent_adv_", "attribute_rec_"], 30 | "attribute_rec_w": 0.01 31 | }, 32 | "data_config": { 33 | "data_set_name": "ffhq", 34 | "path": "/mnt/md4/orville/Alon/res/research_gan/ffhq-dataset/images1024x1024", 35 | "workers": 32 36 | }, 37 | "evaluation_config": { 38 | "sample_batch": 16 39 | }, 40 | "tensorboard_config": { 41 | "enabled": true 42 | }, 43 | "monitor_config": { 44 | "enabled": false 45 | }, 46 | "ckpt_config": { 47 | "enabled": false, 48 | "ckpt": "no_ckpt" 49 | } 50 | } -------------------------------------------------------------------------------- /src/gan_control/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/src/gan_control/datasets/__init__.py -------------------------------------------------------------------------------- /src/gan_control/datasets/afhq_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # this file borrows from https://github.com/clovaai/stargan-v2/blob/master/core/data_loader.py 4 | 5 | from pathlib import Path 6 | from itertools import chain 7 | import os 8 | import random 9 | 10 | from PIL import Image 11 | import numpy as np 12 | 13 | import torch 14 | from torch.utils import data 15 | from torch.utils.data.sampler import WeightedRandomSampler 16 | from torchvision import transforms 17 | from torchvision.datasets import ImageFolder 18 | 19 | from gan_control.datasets.ffhq_dataset import data_sampler, sample_data 20 | from gan_control.utils.logging_utils import get_logger 21 | 22 | _log = get_logger(__name__) 23 | 24 | 25 | def listdir(dname): 26 | fnames = list(chain(*[list(Path(dname).rglob('*.' + ext)) 27 | for ext in ['png', 'jpg', 'jpeg', 'JPG']])) 28 | return fnames 29 | 30 | 31 | class AfhqDataset(data.Dataset): 32 | def __init__(self, root, transform=None): 33 | self.samples = listdir(os.path.join(root, 'train', 'dog')) 34 | self.samples = self.samples + listdir(os.path.join(root, 'val', 'dog')) 35 | self.samples.sort() 36 | self.transform = transform 37 | self.targets = None 38 | 39 | def __getitem__(self, index): 40 | fname = self.samples[index] 41 | img = Image.open(fname).convert('RGB') 42 | if self.transform is not None: 43 | img = self.transform(img) 44 | return img, (str(fname), str(fname)) 45 | 46 | def __len__(self): 47 | return len(self.samples) 48 | 49 | 50 | def get_afhq_data_loader(data_config, batch_size=4, size=512, training=True, prob=0.5): 51 | crop = transforms.RandomResizedCrop(size, scale=[0.8, 1.0], ratio=[0.9, 1.1]) 52 | rand_crop = transforms.Lambda(lambda x: crop(x) if random.random() < prob else x) 53 | compose_list = [] 54 | if training: compose_list.append(rand_crop) 55 | compose_list.append(transforms.Resize([size, size])) 56 | if training: compose_list.append(transforms.RandomHorizontalFlip()) 57 | compose_list.append(transforms.ToTensor()) 58 | compose_list.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)) 59 | transform = transforms.Compose(compose_list) 60 | dataset = AfhqDataset(data_config['path'], transform=transform) 61 | shuffle = True 62 | drop_last = True 63 | _log.info('init AFHQ data loader: image size:%s, batch size:%d, shuffle:%s, drop last:%s, num workers:%d' % (size, batch_size, str(shuffle), str(drop_last), data_config['workers'])) 64 | loader = data.DataLoader( 65 | dataset, 66 | batch_size=batch_size, 67 | sampler=data_sampler(dataset, shuffle=shuffle, distributed=False), 68 | drop_last=drop_last, 69 | num_workers=data_config['workers'] 70 | ) 71 | loader = sample_data(loader) 72 | return loader 73 | -------------------------------------------------------------------------------- /src/gan_control/datasets/dataframe_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from io import BytesIO 5 | import torch 6 | from PIL import Image 7 | from torch.utils.data import Dataset 8 | from torch.utils import data 9 | import torchvision 10 | from torchvision import transforms 11 | import pandas as pd 12 | 13 | from gan_control.utils.logging_utils import get_logger 14 | 15 | _log = get_logger(__name__) 16 | 17 | 18 | class DataFrameDataSet(Dataset): 19 | def __init__(self, daraframe_path, attribute=None, train=True): 20 | _log.info('Loading dataframe from: %s' % daraframe_path) 21 | self.train = train 22 | self.attributes_df = pd.read_pickle(daraframe_path) 23 | if train: 24 | self.attributes_df = self.attributes_df.iloc[:int(self.__len__() * 0.9)] 25 | else: 26 | self.attributes_df = self.attributes_df.iloc[int(self.__len__() * 0.9):] 27 | self.attribute = attribute 28 | if self.attribute is not None: 29 | self.attributes_df = self.attributes_df[['latents_w', attribute]] 30 | _log.info('Dataset length: %d' % self.__len__()) 31 | 32 | def __getitem__(self, index): 33 | attributes_series = self.attributes_df.iloc[index] 34 | if self.attribute in ['age']: 35 | attributes = torch.tensor(attributes_series[self.attribute]).unsqueeze(0) 36 | elif self.attribute in ['expression_q']: 37 | attributes = torch.nn.functional.one_hot(torch.tensor(attributes_series[self.attribute]), num_classes=8) 38 | else: 39 | attributes = torch.tensor(attributes_series[self.attribute]) 40 | return attributes, torch.tensor(attributes_series['latents_w']) 41 | 42 | def __len__(self): 43 | return len(self.attributes_df.latents_w) 44 | 45 | 46 | def get_dataframe_data_loader(daraframe_path, attribute, batch_size=32, shuffle=True, drop_last=True, workers=32, train=True): 47 | dataset = DataFrameDataSet(daraframe_path, attribute=attribute, train=train) 48 | _log.info('init dataframe data loader: batch size:%d, shuffle:%s, drop last:%s, num workers:%d' % (batch_size, str(shuffle), str(drop_last), workers)) 49 | loader = data.DataLoader( 50 | dataset, 51 | batch_size=batch_size, 52 | shuffle=shuffle, 53 | drop_last=drop_last, 54 | num_workers=workers 55 | ) 56 | return loader 57 | -------------------------------------------------------------------------------- /src/gan_control/datasets/ffhq_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from io import BytesIO 5 | 6 | from PIL import Image 7 | from torch.utils.data import Dataset 8 | from torch.utils import data 9 | import torchvision 10 | from torchvision import transforms 11 | 12 | from gan_control.utils.logging_utils import get_logger 13 | 14 | _log = get_logger(__name__) 15 | 16 | 17 | class FfhqData(torchvision.datasets.ImageFolder): 18 | def __init__(self, root, transform=None): 19 | super(FfhqData, self).__init__(root, transform=transform) 20 | 21 | def __getitem__(self, index): 22 | """ 23 | Args: 24 | index (int): Index 25 | 26 | Returns: 27 | tuple: (sample, target) where target is class_index of the target class. 28 | """ 29 | path, target = self.samples[index] 30 | sample = self.loader(path) 31 | if self.transform is not None: 32 | sample = self.transform(sample) 33 | if self.target_transform is not None: 34 | target = self.target_transform(target) 35 | 36 | return sample, (target, path) 37 | 38 | 39 | def data_sampler(dataset, shuffle, distributed): 40 | if distributed: 41 | return data.distributed.DistributedSampler(dataset, shuffle=shuffle) 42 | 43 | if shuffle: 44 | return data.RandomSampler(dataset) 45 | 46 | else: 47 | return data.SequentialSampler(dataset) 48 | 49 | 50 | def sample_data(loader): 51 | while True: 52 | for batch in loader: 53 | yield batch 54 | 55 | 56 | def get_ffhq_data_loader(data_config, batch_size=4, size=1024, training=True): 57 | compose_list = [] 58 | if size != 1024: 59 | compose_list.append(transforms.Resize(size)) 60 | if training: 61 | compose_list.append(transforms.RandomHorizontalFlip()) 62 | compose_list.append(transforms.ToTensor()) 63 | compose_list.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)) 64 | transform = transforms.Compose(compose_list) 65 | 66 | #dataset = torchvision.datasets.ImageFolder(data_config['path'], transform=transform) 67 | dataset = FfhqData(data_config['path'], transform=transform) 68 | shuffle = True 69 | drop_last = True 70 | _log.info('init FFHQ data loader: image size:%s, batch size:%d, shuffle:%s, drop last:%s, num workers:%d' % (size, batch_size, str(shuffle), str(drop_last), data_config['workers'])) 71 | loader = data.DataLoader( 72 | dataset, 73 | batch_size=batch_size, 74 | sampler=data_sampler(dataset, shuffle=shuffle, distributed=False), 75 | drop_last=drop_last, 76 | num_workers=data_config['workers'] 77 | ) 78 | loader = sample_data(loader) 79 | return loader 80 | -------------------------------------------------------------------------------- /src/gan_control/datasets/merged_dataframe_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from io import BytesIO 5 | import torch 6 | from PIL import Image 7 | from torch.utils.data import Dataset 8 | from torch.utils import data 9 | import torchvision 10 | from torchvision import transforms 11 | import pandas as pd 12 | 13 | from gan_control.utils.logging_utils import get_logger 14 | 15 | _log = get_logger(__name__) 16 | 17 | 18 | class MergedDataFrameDataSet(Dataset): 19 | def __init__(self, daraframe_path, train=True): 20 | _log.info('Loading dataframe from: %s' % daraframe_path) 21 | self.train = train 22 | self.attributes_df = pd.read_pickle(daraframe_path) 23 | if train: 24 | self.attributes_df = self.attributes_df.iloc[:int(self.__len__() * 0.9)] 25 | else: 26 | self.attributes_df = self.attributes_df.iloc[int(self.__len__() * 0.9):] 27 | _log.info('Dataset length: %d' % self.__len__()) 28 | 29 | def __getitem__(self, index): 30 | attributes_series = self.attributes_df.iloc[index] 31 | output_dict = { 32 | 'arcface_emb': torch.tensor(attributes_series['arcface_emb']), 33 | 'orientation': torch.tensor(attributes_series['orientation']), 34 | 'gamma': torch.tensor(attributes_series['gamma3d']), 35 | 'hair': torch.tensor(attributes_series['hair']), 36 | 'age': torch.tensor(attributes_series['age']).unsqueeze(0), 37 | 'expression': torch.tensor(attributes_series['expression3d']), 38 | } 39 | return output_dict, torch.tensor(attributes_series['latents_w']) 40 | 41 | def __len__(self): 42 | return len(self.attributes_df.latents_w) 43 | 44 | 45 | def get_dataframe_data_loader(dataframe_path, batch_size=32, shuffle=True, drop_last=True, workers=32, train=True): 46 | dataset = MergedDataFrameDataSet(dataframe_path, train=train) 47 | _log.info('init dataframe data loader: batch size:%d, shuffle:%s, drop last:%s, num workers:%d' % (batch_size, str(shuffle), str(drop_last), workers)) 48 | loader = data.DataLoader( 49 | dataset, 50 | batch_size=batch_size, 51 | shuffle=shuffle, 52 | drop_last=drop_last, 53 | num_workers=workers 54 | ) 55 | return loader 56 | 57 | -------------------------------------------------------------------------------- /src/gan_control/datasets/metfaces_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from pathlib import Path 5 | from itertools import chain 6 | import os 7 | import random 8 | 9 | from PIL import Image 10 | import numpy as np 11 | 12 | import torch 13 | from torch.utils import data 14 | from torch.utils.data.sampler import WeightedRandomSampler 15 | from torchvision import transforms 16 | from torchvision.datasets import ImageFolder 17 | 18 | from gan_control.datasets.ffhq_dataset import data_sampler, sample_data 19 | from gan_control.utils.logging_utils import get_logger 20 | 21 | _log = get_logger(__name__) 22 | 23 | 24 | def listdir(dname): 25 | fnames = list(chain(*[list(Path(dname).rglob('*.' + ext)) 26 | for ext in ['png', 'jpg', 'jpeg', 'JPG']])) 27 | return fnames 28 | 29 | 30 | class MetFacesDataset(data.Dataset): 31 | def __init__(self, root, transform=None): 32 | self.samples = listdir(root) 33 | self.samples.sort() 34 | self.transform = transform 35 | self.targets = None 36 | 37 | def __getitem__(self, index): 38 | fname = self.samples[index] 39 | img = Image.open(fname).convert('RGB') 40 | if self.transform is not None: 41 | img = self.transform(img) 42 | return img, img 43 | 44 | def __len__(self): 45 | return len(self.samples) 46 | 47 | 48 | def get_metfaces_data_loader(data_config, batch_size=4, size=512, training=True, prob=0.5): 49 | #crop = transforms.RandomResizedCrop(size, scale=[0.8, 1.0], ratio=[0.9, 1.1]) 50 | #rand_crop = transforms.Lambda(lambda x: crop(x) if random.random() < prob else x) 51 | compose_list = [] 52 | #if training: compose_list.append(rand_crop) 53 | compose_list.append(transforms.Resize([size, size])) 54 | if training: compose_list.append(transforms.RandomHorizontalFlip()) 55 | compose_list.append(transforms.ToTensor()) 56 | compose_list.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)) 57 | transform = transforms.Compose(compose_list) 58 | dataset = MetFacesDataset(data_config['path'], transform=transform) 59 | shuffle = True 60 | drop_last = True 61 | _log.info('init metfaces data loader: image size:%s, batch size:%d, shuffle:%s, drop last:%s, num workers:%d' % (size, batch_size, str(shuffle), str(drop_last), data_config['workers'])) 62 | loader = data.DataLoader( 63 | dataset, 64 | batch_size=batch_size, 65 | sampler=data_sampler(dataset, shuffle=shuffle, distributed=False), 66 | drop_last=drop_last, 67 | num_workers=data_config['workers'] 68 | ) 69 | loader = sample_data(loader) 70 | return loader 71 | -------------------------------------------------------------------------------- /src/gan_control/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/src/gan_control/evaluation/__init__.py -------------------------------------------------------------------------------- /src/gan_control/evaluation/age.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import torch 5 | import numpy as np 6 | from tqdm import tqdm 7 | from torchvision import transforms, utils 8 | from PIL import Image 9 | 10 | from gan_control.utils.ploting_utils import plot_bar, plot_hist 11 | from gan_control.utils.hopenet_utils import softmax_temperature, draw_axis 12 | from gan_control.utils.logging_utils import get_logger 13 | from gan_control.utils.pil_images_utils import create_image_grid_from_image_list, write_text_to_image 14 | 15 | _log = get_logger(__name__) 16 | 17 | 18 | def calc_age_from_tensor_images(age_loss_class, tensor_images): 19 | with torch.no_grad(): 20 | features_list = age_loss_class.calc_features(tensor_images) 21 | features = features_list[-1] 22 | ages = age_loss_class.last_layer_criterion.get_predict_age(features.cpu()) 23 | return ages 24 | 25 | 26 | def calc_and_write_age_to_image(age_loss_class, tensor_images): 27 | ages = calc_age_from_tensor_images(age_loss_class, tensor_images) 28 | tensor_images = tensor_images.mul(0.5).add(0.5).clamp(min=0., max=1.) 29 | images = [transforms.ToPILImage()(tensor_images[i]) for i in range(tensor_images.shape[0])] 30 | return write_age_to_image(images, ages) 31 | 32 | 33 | def write_age_to_image(images, ages): 34 | pil_images_with_ages = [] 35 | for image_num in range(len(images)): 36 | pil_image = write_text_to_image(images[image_num], 'age: %.2f' % (ages[image_num]), place=(10, 50)) 37 | pil_images_with_ages.append(pil_image) 38 | return pil_images_with_ages 39 | 40 | 41 | def make_age_hist(age_loss_class, loader=None, generator=None, number_os_samples=2000, batch_size=40, title=None, save_path=None): 42 | total_ages = np.array([]) 43 | for batch_num in tqdm(range(0, number_os_samples, batch_size)): 44 | if loader is not None: 45 | tensor_images, _ = next(loader) 46 | tensor_images = tensor_images.cuda() 47 | elif generator is not None: 48 | tensor_images = generator.gen_random() 49 | else: 50 | raise ValueError('loader and generator are None') 51 | ages = calc_age_from_tensor_images(age_loss_class, tensor_images) 52 | total_ages = np.concatenate([total_ages, ages], axis=0) 53 | plot_hist( 54 | [total_ages], 55 | title=title, 56 | labels=None, 57 | bins=151, 58 | plt_range=(0, 150), 59 | save_path=save_path 60 | ) 61 | return total_ages 62 | 63 | 64 | def make_ages_grid(ages_loss_class, tensor_images, nrow=6, save_path=None, downsample=None): 65 | pil_images_with_ages = calc_and_write_age_to_image(ages_loss_class, tensor_images) 66 | image_grid = create_image_grid_from_image_list(pil_images_with_ages, nrow=nrow) 67 | if downsample is not None: 68 | width, height = image_grid.size 69 | image_grid = transforms.Resize((width // downsample, height // downsample), interpolation=Image.BILINEAR)(image_grid) 70 | if save_path is not None: 71 | image_grid.save(save_path) 72 | return image_grid 73 | 74 | 75 | if __name__ == '__main__': 76 | import argparse 77 | from gan_control.datasets.ffhq_dataset import get_ffhq_data_loader 78 | from gan_control.utils.file_utils import read_json 79 | from gan_control.utils.ploting_utils import plot_hist 80 | from gan_control.losses.loss_model import LossModelClass 81 | 82 | parser = argparse.ArgumentParser() 83 | parser.add_argument('--config_path', type=str, required=True) 84 | parser.add_argument('--batch_size', type=int, default=40) 85 | parser.add_argument('--number_os_samples', type=int, default=5000) 86 | args = parser.parse_args() 87 | config = read_json(args.config_path, return_obj=True) 88 | loader = get_ffhq_data_loader(config.data_config, batch_size=args.batch_size, training=True, size=config.model_config['size']) 89 | age_loss_class = LossModelClass(config.training_config['age_loss'], loss_name='age_loss', mini_batch_size=args.batch_size, device="cuda") 90 | 91 | make_age_hist(age_loss_class, loader=loader, number_os_samples=args.number_os_samples, batch_size=args.batch_size, title=None, save_path='path to save image') # TODO: 92 | tensor_images, _ = next(loader) 93 | make_ages_grid(age_loss_class, tensor_images[:8], nrow=4, save_path='path to save image') # TODO: 94 | 95 | -------------------------------------------------------------------------------- /src/gan_control/evaluation/expression.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import torch 5 | import numpy as np 6 | from tqdm import tqdm 7 | from torchvision import transforms, utils 8 | from PIL import Image 9 | 10 | from gan_control.utils.ploting_utils import plot_bar 11 | from gan_control.utils.hopenet_utils import softmax_temperature, draw_axis 12 | from gan_control.utils.logging_utils import get_logger 13 | from gan_control.utils.pil_images_utils import create_image_grid_from_image_list, write_text_to_image 14 | 15 | _log = get_logger(__name__) 16 | 17 | 18 | def get_class(idx): 19 | classes = { 20 | 0: 'Neutral', 21 | 1: 'Happy', 22 | 2: 'Sad', 23 | 3: 'Surprise', 24 | 4: 'Fear', 25 | 5: 'Disgust', 26 | 6: 'Anger', 27 | 7: 'Contempt'} 28 | 29 | return classes[idx] 30 | 31 | 32 | def calc_expression_from_features(features): 33 | batch_size, ensemble_size, emotion_num = features.shape 34 | emotion_votes = np.zeros((batch_size, emotion_num)) 35 | emotion = [features[:, i, :].cpu().detach().numpy() for i in range(ensemble_size)] 36 | batches = range(batch_size) 37 | for e in emotion: 38 | e_idx = np.argmax(e, 1) 39 | emotion_votes[batches, e_idx] += 1 40 | return np.argmax(emotion_votes, 1) 41 | 42 | 43 | def calc_expression_from_tensor_images(expression_loss_class, tensor_images): 44 | with torch.no_grad(): 45 | features_list = expression_loss_class.calc_features(tensor_images) 46 | features = features_list[-1] 47 | expressions = calc_expression_from_features(features) 48 | return expressions 49 | 50 | 51 | def calc_and_write_expression_to_image(expression_loss_class, tensor_images): 52 | expressions = calc_expression_from_tensor_images(expression_loss_class, tensor_images) 53 | tensor_images = tensor_images.mul(0.5).add(0.5).clamp(min=0., max=1.) 54 | images = [transforms.ToPILImage()(tensor_images[i]) for i in range(tensor_images.shape[0])] 55 | return write_expression_to_image(images, expressions) 56 | 57 | 58 | def write_expression_to_image(images, expressions): 59 | pil_images_with_expressions = [] 60 | for image_num in range(len(images)): 61 | pil_image = write_text_to_image(images[image_num], get_class(expressions[image_num])) 62 | pil_images_with_expressions.append(pil_image) 63 | return pil_images_with_expressions 64 | 65 | 66 | def make_expression_bar(expression_loss_class, loader=None, generator=None, number_os_samples=2000, batch_size=40, title=None, save_path=None): 67 | total_expressions = np.array([]) 68 | for batch_num in tqdm(range(0, number_os_samples, batch_size)): 69 | if loader is not None: 70 | tensor_images, _ = next(loader) 71 | tensor_images = tensor_images.cuda() 72 | elif generator is not None: 73 | tensor_images = generator.gen_random() 74 | else: 75 | raise ValueError('loader and generator are None') 76 | expressions = calc_expression_from_tensor_images(expression_loss_class, tensor_images) 77 | total_expressions = np.concatenate([total_expressions, expressions], axis=0) 78 | plot_bar( 79 | [total_expressions], 80 | [get_class(i) for i in range(8)], 81 | title=title, 82 | labels=None, 83 | save_path=save_path 84 | ) 85 | return total_expressions 86 | 87 | 88 | def make_expression_grid(expression_loss_class, tensor_images, nrow=6, save_path=None, downsample=None): 89 | pil_images_with_expressions = calc_and_write_expression_to_image(expression_loss_class, tensor_images) 90 | image_grid = create_image_grid_from_image_list(pil_images_with_expressions, nrow=nrow) 91 | if downsample is not None: 92 | width, height = image_grid.size 93 | image_grid = transforms.Resize((width // downsample, height // downsample), interpolation=Image.BILINEAR)(image_grid) 94 | if save_path is not None: 95 | image_grid.save(save_path) 96 | return image_grid 97 | 98 | 99 | if __name__ == '__main__': 100 | import argparse 101 | from gan_control.datasets.ffhq_dataset import get_ffhq_data_loader 102 | from gan_control.utils.file_utils import read_json 103 | from gan_control.utils.ploting_utils import plot_hist 104 | from gan_control.losses.loss_model import LossModelClass 105 | 106 | parser = argparse.ArgumentParser() 107 | parser.add_argument('--config_path', type=str, required=True) 108 | parser.add_argument('--batch_size', type=int, default=40) 109 | parser.add_argument('--number_os_samples', type=int, default=70000) 110 | args = parser.parse_args() 111 | config = read_json(args.config_path, return_obj=True) 112 | loader = get_ffhq_data_loader(config.data_config, batch_size=args.batch_size, training=True, size=config.model_config['size']) 113 | expression_loss_class = None 114 | expression_loss_class = LossModelClass(config.training_config['expression_loss'], loss_name='expression_loss', mini_batch_size=args.batch_size, device="cuda") 115 | 116 | make_expression_bar(expression_loss_class, loader=loader, number_os_samples=args.number_os_samples, batch_size=args.batch_size, title=None, save_path='path to save') # TODO: 117 | tensor_images, _ = next(loader) 118 | make_expression_grid(expression_loss_class, tensor_images[:8], nrow=4, save_path='path to save') # TODO: 119 | 120 | -------------------------------------------------------------------------------- /src/gan_control/evaluation/extract_recon_3d/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/src/gan_control/evaluation/extract_recon_3d/__init__.py -------------------------------------------------------------------------------- /src/gan_control/evaluation/extract_recon_3d/disentanglement_dataloader.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from pathlib import Path 5 | from itertools import chain 6 | import os 7 | import random 8 | 9 | from PIL import Image 10 | import numpy as np 11 | 12 | import torch 13 | from torch.utils import data 14 | from torchvision import transforms 15 | 16 | 17 | from igt_res_gan.utils.logging_utils import get_logger 18 | 19 | _log = get_logger(__name__) 20 | 21 | 22 | def listdir(dname): 23 | fnames = list(chain(*[list(Path(dname).rglob('*.' + ext)) 24 | for ext in ['png', 'jpg', 'jpeg', 'JPG']])) 25 | return fnames 26 | 27 | 28 | class DisentanglementDataset(data.Dataset): 29 | def __init__(self, root, transform=None): 30 | self.samples = listdir(os.path.join(root)) 31 | self.samples.sort() 32 | self.transform = transform 33 | self.targets = None 34 | 35 | def __getitem__(self, index): 36 | fname = self.samples[index] 37 | img = Image.open(fname).convert('RGB') 38 | im_name = os.path.split(fname)[1] 39 | uj = int(im_name.split('_')[0]) 40 | ui = int(im_name.split('_')[1].split('.')[0]) 41 | if self.transform is not None: 42 | img = self.transform(img) 43 | return img, uj, ui 44 | 45 | def __len__(self): 46 | return len(self.samples) 47 | 48 | 49 | def get_disentanglement_data_loader(data_config, batch_size=4): 50 | compose_list = [] 51 | compose_list.append(transforms.ToTensor()) 52 | compose_list.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)) 53 | transform = transforms.Compose(compose_list) 54 | dataset = DisentanglementDataset(data_config['path'], transform=transform) 55 | shuffle = False 56 | _log.info('init Disentanglement data loader: batch size:%d, shuffle:%s, num workers:%d' % (batch_size, str(shuffle), data_config['workers'])) 57 | loader = data.DataLoader( 58 | dataset, 59 | batch_size=batch_size, 60 | num_workers=data_config['workers'] 61 | ) 62 | return loader 63 | -------------------------------------------------------------------------------- /src/gan_control/evaluation/extract_recon_3d/extract_recon_3d.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import os 5 | from tqdm import tqdm 6 | import pandas as pd 7 | import numpy as np 8 | import torch 9 | 10 | 11 | def calc_vectors_mean_and_std(vecs_list, all_vs_all=True): 12 | vecs = np.array(vecs_list) 13 | mean_vecs = vecs.mean(axis=0) 14 | if not all_vs_all: 15 | distances = np.sqrt(np.sum(np.power(vecs - mean_vecs, 2), axis=1)) 16 | else: 17 | distances_list = [] 18 | sig = torch.tensor(vecs).unsqueeze(0) 19 | gue_chunks = torch.tensor(vecs).split(100, dim=0) 20 | for gue_chunk in gue_chunks: 21 | gue_chunk = gue_chunk.unsqueeze(1) 22 | distances = torch.pow(gue_chunk - sig, 2) 23 | distances = torch.sqrt(torch.sum(distances, dim=-1)) 24 | distances_list.append(distances) 25 | distances = torch.cat(distances_list, dim=0) 26 | mask = np.tril(np.ones([len(vecs_list), len(vecs_list)]), -1) == 1 27 | distances = distances.numpy()[mask] 28 | 29 | return mean_vecs, distances.mean() 30 | -------------------------------------------------------------------------------- /src/gan_control/evaluation/face_alignment_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/src/gan_control/evaluation/face_alignment_utils/__init__.py -------------------------------------------------------------------------------- /src/gan_control/evaluation/face_alignment_utils/face_alignment_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import cv2 5 | import face_alignment 6 | from skimage import io 7 | import numpy as np 8 | from PIL import Image 9 | from scipy.io import loadmat 10 | import torch 11 | from torchvision import utils, transforms 12 | 13 | 14 | def make_68_ln_to_5_lm(Lm3D): 15 | lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1 16 | Lm3D = np.stack([Lm3D[lm_idx[0], :], np.mean(Lm3D[lm_idx[[1, 2]], :], 0), np.mean(Lm3D[lm_idx[[3, 4]], :], 0), 17 | Lm3D[lm_idx[5], :], Lm3D[lm_idx[6], :]], axis=0) 18 | Lm3D = Lm3D[[1, 2, 0, 3, 4], :] 19 | 20 | return Lm3D 21 | 22 | 23 | def load_lm3d(): 24 | path = 'path_to_similarity_Lm3D_all.mat' 25 | assert path != 'path_to_similarity_Lm3D_all.mat', 'download similarity_Lm3D_all.mat from https://github.com/microsoft/Deep3DFaceReconstruction/blob/master/BFM/similarity_Lm3D_all.mat' 26 | Lm3D = loadmat('path') 27 | Lm3D = Lm3D['lm'] 28 | 29 | # calculate 5 facial landmarks using 68 landmarks 30 | return make_68_ln_to_5_lm(Lm3D) 31 | 32 | 33 | #calculating least square problem 34 | def POS(xp,x): 35 | npts = xp.shape[1] 36 | 37 | A = np.zeros([2*npts,8]) 38 | 39 | A[0:2*npts-1:2,0:3] = x.transpose() 40 | A[0:2*npts-1:2,3] = 1 41 | 42 | A[1:2*npts:2,4:7] = x.transpose() 43 | A[1:2*npts:2,7] = 1; 44 | 45 | b = np.reshape(xp.transpose(),[2*npts,1]) 46 | 47 | k,_,_,_ = np.linalg.lstsq(A,b) 48 | 49 | R1 = k[0:3] 50 | R2 = k[4:7] 51 | sTx = k[3] 52 | sTy = k[7] 53 | s = (np.linalg.norm(R1) + np.linalg.norm(R2))/2 54 | t = np.stack([sTx,sTy],axis = 0) 55 | 56 | return t,s 57 | 58 | 59 | def process_img(img,lm,t,s,target_size = 224.): 60 | w0,h0 = img.size 61 | w = (w0/s*102).astype(np.int32) 62 | h = (h0/s*102).astype(np.int32) 63 | img = img.resize((w,h),resample = Image.BICUBIC) 64 | 65 | left = (w/2 - target_size/2 + float((t[0] - w0/2)*102/s)).astype(np.int32) 66 | right = left + target_size 67 | up = (h/2 - target_size/2 + float((h0/2 - t[1])*102/s)).astype(np.int32) 68 | below = up + target_size 69 | 70 | img = img.crop((left,up,right,below)) 71 | img = np.array(img) 72 | #img = img[:,:,::-1] #RGBtoBGR 73 | img = np.expand_dims(img,0) 74 | lm = np.stack([lm[:,0] - t[0] + w0/2,lm[:,1] - t[1] + h0/2],axis = 1)/s*102 75 | lm = lm - np.reshape(np.array([(w/2 - target_size/2),(h/2-target_size/2)]),[1,2]) 76 | 77 | return img,lm 78 | 79 | 80 | # resize and crop input images before sending to the R-Net 81 | def Preprocess(img,lm,lm3D,crop_size=224): 82 | 83 | w0,h0 = img.size 84 | 85 | # change from image plane coordinates to 3D sapce coordinates(X-Y plane) 86 | lm = np.stack([lm[:,0],h0 - 1 - lm[:,1]], axis = 1) 87 | 88 | # calculate translation and scale factors using 5 facial landmarks and standard landmarks of a 3D face 89 | t,s = POS(lm.transpose(),lm3D.transpose()) 90 | 91 | # processing the image 92 | img_new,lm_new = process_img(img,lm,t,s,target_size=crop_size) 93 | lm_new = np.stack([lm_new[:,0],223 - lm_new[:,1]], axis = 1) 94 | trans_params = np.array([w0,h0,102.0/s,t[0],t[1]]) 95 | 96 | return img_new,lm_new,trans_params 97 | 98 | 99 | def align_face_by_image_path(path, fa=None, size=None, lm3D=None): 100 | input = io.imread(path) 101 | preds, source_image, algin_image = align_face_by_image(input, fa=fa, size=size, lm3D=lm3D) 102 | return preds, source_image, algin_image 103 | 104 | 105 | def align_face_by_image(input, fa=None, size=None, lm3D=None, crop_size=224): 106 | input_is_tensor = False 107 | if isinstance(input, torch.Tensor): 108 | input_is_tensor = True 109 | if input.min() < 0: 110 | input = input.mul(0.5).add(0.5).clamp(min=0., max=1.) * 255 111 | input = input.numpy().astype('uint8') 112 | input = input.swapaxes(0, 1).swapaxes(1, 2) 113 | if fa is None: 114 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._3D, flip_input=False) 115 | if size is not None: 116 | input = cv2.resize(input, dsize=(size, size), interpolation=cv2.INTER_CUBIC) 117 | if lm3D is None: 118 | lm3D = load_lm3d() 119 | with torch.no_grad(): 120 | try: 121 | preds = fa.get_landmarks(input) 122 | except: 123 | preds = None 124 | if preds is not None: 125 | img_new, lm_new, trans_params = Preprocess(Image.fromarray(input), make_68_ln_to_5_lm(preds[0]), lm3D, crop_size=crop_size) 126 | img_new = img_new[0] 127 | else: 128 | img_new = cv2.resize(input, dsize=(224, 224), interpolation=cv2.INTER_CUBIC) 129 | if input_is_tensor: 130 | img_new = img_new.swapaxes(1, 2).swapaxes(0, 1) 131 | img_new = torch.tensor(img_new).float().div(255).add(-0.5).mul(2) 132 | img_new = img_new.unsqueeze(0) 133 | return preds, input, img_new 134 | 135 | 136 | def align_tensor_images(tensor, fa=None, lm3D=None, crop_size=224): 137 | is_cuda = tensor.device.type == 'cuda' 138 | tensor_list = [] 139 | for i in range(tensor.shape[0]): 140 | preds, input, align_image = align_face_by_image(tensor[i].cpu(), fa=fa, lm3D=lm3D, crop_size=crop_size) 141 | tensor_list.append(align_image) 142 | out_tensor = torch.cat(tensor_list, dim=0) 143 | if is_cuda: 144 | out_tensor = out_tensor.cuda() 145 | return out_tensor 146 | 147 | 148 | def paint_pred_on_face(image, pred): 149 | for i in range(pred.shape[0]): 150 | image[int(pred[i, 1]), int(pred[i, 0]), 0] = 255 151 | image[int(pred[i, 1]), int(pred[i, 0]), 1] = 0 152 | image[int(pred[i, 1]), int(pred[i, 0]), 2] = 0 153 | return image 154 | 155 | -------------------------------------------------------------------------------- /src/gan_control/evaluation/gan_evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/src/gan_control/evaluation/gan_evaluation/__init__.py -------------------------------------------------------------------------------- /src/gan_control/evaluation/gan_evaluation/error_bar_plot.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | def plot_error_bar(xs, ys, xlabel='layer #', ylabel='abs mean', legends=None, save_path=None, title=None, mean_abs=False): 9 | markers = ['^', 'o', 's', '+', 'x', '>', '<', '*'] 10 | for i, (x, y, marker) in enumerate(zip(xs, ys, markers)): 11 | x_means = [np.array(xi).mean() for xi in x] 12 | x_std = [np.array(xi).std() if len(xi) > 1 else 0 for xi in x] 13 | 14 | if mean_abs: 15 | y_means = [np.abs(np.array(yi)).mean() for yi in y] 16 | else: 17 | y_means = [np.array(yi).mean() for yi in y] 18 | y_std = [np.array(yi).std() if len(yi) > 1 else 0 for yi in y] 19 | 20 | plt.errorbar(x_means, y_means, xerr=x_std, yerr=y_std, linestyle='None', marker='^') 21 | 22 | 23 | plt.ylabel(ylabel) 24 | plt.xlabel(xlabel) 25 | if legends is not None: 26 | plt.legend(legends) 27 | if title is not None: 28 | plt.title(title) 29 | plt.show() 30 | if save_path is not None: 31 | plt.savefig(save_path) -------------------------------------------------------------------------------- /src/gan_control/evaluation/generation.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torchvision import transforms, utils 7 | from PIL import Image 8 | 9 | from gan_control.utils.logging_utils import get_logger 10 | 11 | _log = get_logger(__name__) 12 | 13 | 14 | def gen_grid(model, latent, injection_noise=None, nrow=4, downsample=None): 15 | with torch.no_grad(): 16 | output_tensor, _ = model([latent], noise=injection_noise) 17 | image_tensor = output_tensor.mul(0.5).add(0.5).clamp(min=0., max=1.).cpu() 18 | image = transforms.ToPILImage()(utils.make_grid(image_tensor, nrow=nrow)) 19 | if downsample is not None: 20 | width, height = image.size 21 | image = transforms.Resize((width // downsample, height // downsample), interpolation=Image.BILINEAR)(image) 22 | return image 23 | 24 | 25 | def make_noise_id_pose_matrix(model, ids_in_row=6, pose_in_col=6, device='cpu', id_chunk=(256, 512)): 26 | ids = [] 27 | poses = [] 28 | noises = [] 29 | z_sampels = [] 30 | start_list = list(range(id_chunk[0])) 31 | end_list = list(range(id_chunk[1], 512)) 32 | same_id_chunk = list(range(id_chunk[0], id_chunk[1])) 33 | same_pose_chunks = start_list + end_list 34 | for i in range(ids_in_row): 35 | sample_z = torch.randn(1, 512, device=device) 36 | ids.append(sample_z[0, same_id_chunk].unsqueeze(dim=0)) 37 | poses.append(sample_z[0, same_pose_chunks].unsqueeze(dim=0)) 38 | for row in range(pose_in_col): 39 | for col in range(ids_in_row): 40 | canvas = torch.zeros_like(torch.cat([poses[col], ids[row]], dim=1)) 41 | canvas[:, same_id_chunk] = ids[row] 42 | canvas[:, same_pose_chunks] = poses[col] 43 | z_sampels.append(canvas) 44 | if isinstance(model, nn.DataParallel): 45 | noises = [model.module.make_noise(device=device) for _ in range(ids_in_row)] 46 | else: 47 | noises = [model.make_noise(device=device) for _ in range(ids_in_row)] 48 | return z_sampels, noises 49 | 50 | 51 | @torch.no_grad() 52 | def gen_matrix( 53 | model, 54 | ids_in_row=6, 55 | pose_in_col=6, 56 | latents=None, 57 | injection_noises=None, 58 | device='cuda', 59 | same_noise_per_id=False, 60 | downsample=None, 61 | return_list=False, 62 | same_chunk=(256, 512), 63 | same_noise_for_all=False 64 | ): 65 | if same_noise_per_id and same_noise_for_all: 66 | _log.warning('same_noise_for_all and same_noise_for_all is True -> same_noise_for_all') 67 | injection_noise = None 68 | injection_num = 0 69 | if latents is None or injection_noises is None: 70 | temp_latents, temp_injection_noises = make_noise_id_pose_matrix(model, ids_in_row=ids_in_row, pose_in_col=pose_in_col, device='cpu', id_chunk=same_chunk) 71 | if latents is None: 72 | latents = temp_latents 73 | if injection_noises is None: 74 | injection_noises = temp_injection_noises 75 | if same_noise_per_id or same_noise_for_all: 76 | injection_noise = injection_noises[injection_num] 77 | injection_noise = [injection_noise[n].cuda() for n in range(len(injection_noise))] 78 | injection_num += 1 79 | total_sample, _ = model([latents[0].cuda()], noise=injection_noise) 80 | for pic_num in range(1, ids_in_row * pose_in_col): 81 | if same_noise_per_id and (pic_num % ids_in_row == 0): 82 | injection_noise = injection_noises[pic_num] 83 | injection_noise = [injection_noise[n].cuda() for n in range(len(injection_noise))] 84 | injection_num += 1 85 | sample, _ = model([latents[pic_num].cuda()], noise=injection_noise) 86 | total_sample = torch.cat([total_sample, sample], dim=0) 87 | if return_list: 88 | return total_sample.cpu() 89 | total_sample = utils.make_grid(total_sample.mul(0.5).add(0.5).clamp(min=0., max=1.).cpu(), nrow=ids_in_row) 90 | image = transforms.ToPILImage()(total_sample) 91 | if downsample is not None: 92 | width, height = image.size 93 | image = transforms.Resize((width // downsample, height // downsample), interpolation=Image.BILINEAR)(image) 94 | return image 95 | 96 | 97 | class IterableModel(): 98 | def __init__(self, model, same_noise_for_same_id=False, batch_size=20): 99 | self.model = model 100 | self.same_noise_for_same_id=same_noise_for_same_id 101 | self.batch_size = batch_size 102 | 103 | def gen_random(self): 104 | random_latent = torch.randn(self.batch_size, 512, device='cuda') 105 | output, _ = self.model([random_latent.cuda()], noise=None) 106 | return output 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | -------------------------------------------------------------------------------- /src/gan_control/evaluation/hair.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn import functional as F 7 | import numpy as np 8 | from tqdm import tqdm 9 | from torchvision import transforms, utils 10 | from PIL import Image 11 | 12 | from gan_control.utils.ploting_utils import plot_bar, plot_hist 13 | from gan_control.utils.hopenet_utils import softmax_temperature, draw_axis 14 | from gan_control.utils.logging_utils import get_logger 15 | from gan_control.utils.pil_images_utils import create_image_grid_from_image_list, write_text_to_image 16 | 17 | _log = get_logger(__name__) 18 | 19 | 20 | def calc_hair_color_from_images(hair_loss_class, tensor_images): 21 | features = hair_loss_class.calc_features(tensor_images)[-1] 22 | return hair_loss_class.last_layer_criterion.predict(features) 23 | 24 | def calc_hair_mask_from_images(hair_loss_class, tensor_images): 25 | with torch.no_grad(): 26 | features_list = hair_loss_class.calc_features(tensor_images) 27 | features = features_list[-1] 28 | return features[:,3:,:,:] 29 | 30 | 31 | def calc_and_add_hair_to_image(hair_loss_class, tensor_images): 32 | mask = calc_hair_mask_from_images(hair_loss_class, tensor_images) 33 | b, c, h, w = mask.shape 34 | tensor_images = F.interpolate(tensor_images, size=(h, w), mode='bilinear', align_corners=True) 35 | tensor_images = tensor_images.cpu() 36 | tensor_images[:,2:3,:,:] = tensor_images[:,2:3,:,:].cpu() + mask.cpu() 37 | tensor_images = tensor_images.mul(0.5).add(0.5).clamp(min=0., max=1.) 38 | images = [transforms.ToPILImage()(tensor_images[i]) for i in range(tensor_images.shape[0])] 39 | return images 40 | 41 | 42 | def make_hair_seg_grid(hair_loss_class, tensor_images, nrow=6, save_path=None, downsample=None): 43 | pil_images_with_hair = calc_and_add_hair_to_image(hair_loss_class, tensor_images) 44 | image_grid = create_image_grid_from_image_list(pil_images_with_hair, nrow=nrow) 45 | if downsample is not None: 46 | width, height = image_grid.size 47 | image_grid = transforms.Resize((width // downsample, height // downsample), interpolation=Image.BILINEAR)(image_grid) 48 | if save_path is not None: 49 | image_grid.save(save_path) 50 | return image_grid 51 | 52 | 53 | def add_colors_to_images(preds, tensor_images, target_pred=None): 54 | tensor_images = tensor_images.mul(0.5).add(0.5).clamp(min=0., max=1.) 55 | tensor_images[:, :, :40, :40] = preds.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 40, 40).clone() 56 | if target_pred is not None: 57 | tensor_images[:, :, :40, 40:45] = torch.zeros_like(tensor_images)[:, :, :40, 40:45] 58 | tensor_images[:, :, :40, 45:85] = target_pred.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 40, 40).clone() 59 | tensor_grid = utils.make_grid(tensor_images, nrow=6) 60 | return transforms.ToPILImage()(tensor_grid.cpu()) 61 | 62 | 63 | def make_hair_color_grid(hair_loss_class, tensor_images, nrow=6, save_path=None, downsample=None, target_pred=None): 64 | preds = calc_hair_color_from_images(hair_loss_class, tensor_images) 65 | image_grid = add_colors_to_images(preds, tensor_images, target_pred=target_pred) 66 | if downsample is not None: 67 | width, height = image_grid.size 68 | image_grid = transforms.Resize((width // downsample, height // downsample), interpolation=Image.BILINEAR)(image_grid) 69 | if save_path is not None: 70 | image_grid.save(save_path) 71 | return image_grid 72 | 73 | 74 | if __name__ == '__main__': 75 | import argparse 76 | from gan_control.datasets.ffhq_dataset import get_ffhq_data_loader 77 | from gan_control.utils.file_utils import read_json 78 | from gan_control.losses.loss_model import LossModelClass 79 | 80 | parser = argparse.ArgumentParser() 81 | parser.add_argument('--config_path', type=str, required=True) 82 | parser.add_argument('--batch_size', type=int, default=40) 83 | parser.add_argument('--number_os_samples', type=int, default=5000) 84 | args = parser.parse_args() 85 | config = read_json(args.config_path, return_obj=True) 86 | loader = get_ffhq_data_loader(config.data_config, batch_size=args.batch_size, training=True, size=config.model_config['size']) 87 | age_loss_class = LossModelClass(config.training_config['hair_loss'], loss_name='hair_loss', mini_batch_size=args.batch_size, device="cuda") 88 | 89 | tensor_images, _ = next(loader) 90 | make_hair_seg_grid(age_loss_class, tensor_images[:8], nrow=4, save_path='path to save image') # TODO: 91 | 92 | -------------------------------------------------------------------------------- /src/gan_control/evaluation/recon_3d.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import torch 5 | import numpy as np 6 | from tqdm import tqdm 7 | from torchvision import transforms, utils 8 | from PIL import Image 9 | 10 | from gan_control.utils.ploting_utils import plot_hist 11 | from gan_control.utils.hopenet_utils import softmax_temperature, draw_axis 12 | from gan_control.utils.logging_utils import get_logger 13 | from gan_control.utils.pil_images_utils import create_image_grid_from_image_list 14 | 15 | _log = get_logger(__name__) 16 | 17 | 18 | def calc_recon_3d_from_tensor_images(recon_3d_loss_class, tensor_images): 19 | with torch.no_grad(): 20 | features_list = recon_3d_loss_class.calc_features(tensor_images) 21 | id_futures, ex_futures, tex_futures, angles_futures, gamma_futures, xy_futures, z_futures = recon_3d_loss_class.skeleton_model.module.extract_futures_from_vec(features_list) 22 | return id_futures[0].detach().cpu(), ex_futures[0].detach().cpu(), tex_futures[0].detach().cpu(), angles_futures[0].detach().cpu(), gamma_futures[0].detach().cpu(), xy_futures[0].detach().cpu(), z_futures[0].detach().cpu() 23 | 24 | 25 | def evaluate_recon_3d(recon_3d_loss_class, tensor): 26 | id_futures, ex_futures, tex_futures, angles_futures, gamma_futures, xy_futures, z_futures = calc_recon_3d_from_tensor_images(recon_3d_loss_class, tensor) 27 | tensor_images = tensor.mul(0.5).add(0.5).clamp(min=0., max=1.) 28 | angles_futures = angles_futures.detach().numpy() 29 | orientation_images = draw_orientation_to_tensor_images(tensor_images, angles_futures[:,1], -angles_futures[:,0], angles_futures[:,2]) 30 | return orientation_images 31 | 32 | 33 | def draw_orientation_to_tensor_images(tensor_images, yaw_predicted, pitch_predicted, roll_predicted): 34 | pil_images_with_orientation = [] 35 | for tensor_num in range(tensor_images.shape[0]): 36 | pil_image = transforms.ToPILImage()(tensor_images[tensor_num]) 37 | draw_axis(pil_image, yaw_predicted[tensor_num], pitch_predicted[tensor_num], roll_predicted[tensor_num], radians=True) 38 | pil_images_with_orientation.append(pil_image) 39 | return pil_images_with_orientation 40 | 41 | 42 | def make_orientation_hist(recon_3d_loss_class, loader=None, generator=None, number_os_samples=2000, batch_size=40, title=None, save_path=None): 43 | total_yaw_predicted, total_pitch_predicted, total_roll_predicted = torch.tensor([]), torch.tensor([]), torch.tensor([]) 44 | for batch_num in tqdm(range(0, number_os_samples, batch_size)): 45 | if loader is not None: 46 | tensor_images, _ = next(loader) 47 | tensor_images = tensor_images.cuda() 48 | elif generator is not None: 49 | tensor_images = generator.gen_random() 50 | else: 51 | raise ValueError('loader and generator are None') 52 | id_futures, ex_futures, tex_futures, angles_futures, gamma_futures, xy_futures, z_futures = calc_recon_3d_from_tensor_images(recon_3d_loss_class, tensor_images) 53 | yaw_predicted = angles_futures[:, 1] 54 | pitch_predicted = -angles_futures[:, 0] 55 | roll_predicted = angles_futures[:, 2] 56 | total_yaw_predicted = torch.cat([total_yaw_predicted, yaw_predicted], dim=0) 57 | total_pitch_predicted = torch.cat([total_pitch_predicted, pitch_predicted], dim=0) 58 | total_roll_predicted = torch.cat([total_roll_predicted, roll_predicted], dim=0) 59 | yaw = total_yaw_predicted.squeeze().numpy() 60 | pitch = total_pitch_predicted.squeeze().numpy() 61 | roll = total_roll_predicted.squeeze().numpy() 62 | arrays = [yaw, pitch, roll] 63 | plot_hist( 64 | arrays, 65 | title=title, 66 | labels=['yaw', 'pitch', 'roll'], 67 | xlabel='Angles [radians]', 68 | bins=100, 69 | ncol=3, 70 | percentiles=(0.2, 0.5, 0.8), 71 | min_lim=-1000, 72 | max_lim=1000, 73 | save_path=save_path 74 | ) 75 | return yaw, pitch, roll 76 | 77 | 78 | def make_orientation_grid(recon_3d_loss_class, tensor_images, nrow=6, save_path=None, downsample=None): 79 | pil_images_with_orientation = evaluate_recon_3d(recon_3d_loss_class, tensor_images) 80 | image_grid = create_image_grid_from_image_list(pil_images_with_orientation, nrow=nrow) 81 | if downsample is not None: 82 | width, height = image_grid.size 83 | image_grid = transforms.Resize((width // downsample, height // downsample), interpolation=Image.BILINEAR)(image_grid) 84 | if save_path is not None: 85 | image_grid.save(save_path) 86 | return image_grid 87 | 88 | 89 | if __name__ == '__main__': 90 | import argparse 91 | from gan_control.datasets.ffhq_dataset import get_ffhq_data_loader 92 | from gan_control.utils.file_utils import read_json 93 | from gan_control.utils.ploting_utils import plot_hist 94 | from gan_control.losses.loss_model import LossModelClass 95 | 96 | parser = argparse.ArgumentParser() 97 | parser.add_argument('--config_path', type=str, default='../configs/ffhq.json') 98 | parser.add_argument('--batch_size', type=int, default=40) 99 | parser.add_argument('--number_os_samples', type=int, default=2000) 100 | args = parser.parse_args() 101 | config = read_json(args.config_path, return_obj=True) 102 | loader = get_ffhq_data_loader(config.data_config, batch_size=args.batch_size, training=True, size=config.model_config['size']) 103 | orientation_loss_model = LossModelClass(config.training_config['recon_3d_loss'], loss_name='recon_3d_loss', mini_batch_size=args.batch_size, device="cuda") 104 | 105 | make_orientation_hist(orientation_loss_model, loader=loader, number_os_samples=args.number_os_samples, batch_size=args.batch_size, title=None, save_path='path to save image') # TODO: 106 | tensor_images, _ = next(loader) 107 | make_orientation_grid(orientation_loss_model, tensor_images[:8], nrow=4, save_path='path to save image') # TODO: 108 | 109 | -------------------------------------------------------------------------------- /src/gan_control/evaluation/separability.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import torch 5 | import numpy as np 6 | from tqdm import tqdm 7 | from torchvision import transforms, utils 8 | 9 | from gan_control.utils.ploting_utils import plot_hist 10 | 11 | from gan_control.utils.logging_utils import get_logger 12 | 13 | _log = get_logger(__name__) 14 | 15 | 16 | def re_arrange_inject_noise(noises): 17 | for i in range(0, noises[0].shape[0], 2): 18 | for j in range(len(noises)): 19 | noises[j][i + 1, :, :, :] = noises[j][i, :, :, :].detach() 20 | return noises 21 | 22 | 23 | def compute_half_same_ids_embeddings_from_generator( 24 | generator, 25 | loss_model_class, 26 | num_of_samples, 27 | same_noise_for_same_id=False, 28 | return_images=False, 29 | same_chunk=(256, 512) 30 | ): 31 | images_list = [] 32 | inject_noise = None 33 | latents = torch.randn([num_of_samples, 512]) 34 | #start_list = list(range(same_chunk[0])) 35 | #end_list = list(range(same_chunk[1], 512)) 36 | same_id_chunk = slice(same_chunk[0], same_chunk[1], 1) 37 | #same_pose_chunks = start_list + end_list 38 | even_half = slice(0, num_of_samples, 2) 39 | odd_half = slice(1, num_of_samples, 2) 40 | latents[odd_half, same_id_chunk] = latents[even_half, same_id_chunk] 41 | 42 | #if same_latent_side == 'id': 43 | # latents[odd_half, 256:] = latents[even_half, 256:] 44 | #elif same_latent_side == 'pose': 45 | # latents[odd_half, :256] = latents[even_half, :256] 46 | #else: 47 | # raise ValueError('same_latent_side is %s (not in [id, pose])' % same_latent_side) 48 | out_latents = latents.numpy() 49 | latents = latents.chunk(num_of_samples // 20, dim=0) 50 | with torch.no_grad(): 51 | for i in tqdm(range(num_of_samples // len(latents[0]))): 52 | if same_noise_for_same_id: 53 | if isinstance(generator, torch.nn.DataParallel): 54 | inject_noise = generator.module.make_noise(batch_size=latents[i].shape[0]) 55 | else: 56 | inject_noise = generator.make_noise(batch_size=latents[i].shape[0]) 57 | inject_noise = re_arrange_inject_noise(inject_noise) 58 | fake_images, _ = generator([latents[i].cuda()], noise=inject_noise) 59 | if return_images: 60 | images_list += [transforms.ToPILImage()(fake_images[i].mul(0.5).add(0.5).clamp(min=0.,max=1.).cpu()) for i in range(fake_images.shape[0])] 61 | if i == 0: 62 | embeddings = loss_model_class.calc_features(fake_images) 63 | embeddings = [embeddings[n].cpu() for n in range(len(embeddings))] 64 | else: 65 | feture_list = loss_model_class.calc_features(fake_images) 66 | feture_list = [feture_list[n].cpu() for n in range(len(feture_list))] 67 | embeddings = [torch.cat([embeddings[j], feture_list[j]], dim=0) for j in range(len(embeddings))] 68 | _log.info('extracted %d embeddings' % embeddings[0].shape[0]) 69 | embeddings = [torch.cat([embeddings[n][even_half], embeddings[n][odd_half]], dim=0) for n in range(len(embeddings))] 70 | if return_images: 71 | images_list = images_list[even_half] + images_list[odd_half] 72 | return embeddings, images_list 73 | 74 | 75 | def calc_separability( 76 | generator, 77 | loss_model_class, 78 | same_noise_for_same_id=False, 79 | num_of_samples=2000, 80 | save_path=None, 81 | title=None, 82 | return_images=False, 83 | same_chunk=(256, 512), 84 | last_layer_separability_only=False 85 | ): 86 | embeddings, images_list = compute_half_same_ids_embeddings_from_generator( 87 | generator, 88 | loss_model_class, 89 | num_of_samples, 90 | same_noise_for_same_id=same_noise_for_same_id, 91 | return_images=return_images, 92 | same_chunk=same_chunk 93 | ) 94 | signatures_embs = [embeddings[i][:embeddings[i].shape[0] // 2] for i in range(len(embeddings))] 95 | queries_embs = [embeddings[i][embeddings[i].shape[0] // 2:] for i in range(len(embeddings))] 96 | signature_pids = np.array(range(signatures_embs[0].shape[0])) 97 | queries_pids = np.array(range(queries_embs[0].shape[0])) 98 | same_not_same_list = loss_model_class.calc_same_not_same_list(signatures_embs, queries_embs, signature_pids, queries_pids, last_layer_separability_only=last_layer_separability_only) 99 | for i in range(len(same_not_same_list)): 100 | arrays = [ 101 | same_not_same_list[i]['same'], 102 | same_not_same_list[i]['not_same'], 103 | same_not_same_list[i]['all_not_same'] 104 | ] 105 | if title is not None: 106 | plot_title = '%s layer %d' % (title, i) 107 | plot_hist( 108 | arrays, 109 | title=plot_title, 110 | labels=['same', 'not_same_2nd_best', 'all_not_same'], 111 | xlabel='Distance', 112 | bins=100, 113 | ncol=3, 114 | percentiles=(0.2, 0.5, 0.8), 115 | min_lim=0, 116 | max_lim=1000, 117 | save_path='%s_layer_%d.jpg' % (save_path, i) 118 | ) 119 | return same_not_same_list, images_list 120 | -------------------------------------------------------------------------------- /src/gan_control/fid_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/src/gan_control/fid_utils/__init__.py -------------------------------------------------------------------------------- /src/gan_control/fid_utils/calc_inception.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | import os 4 | import cv2 5 | 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | from torch.utils.data import DataLoader 10 | from torchvision import transforms 11 | from torchvision.models import inception_v3, Inception3 12 | import numpy as np 13 | from tqdm import tqdm 14 | from torch.utils import data 15 | 16 | from gan_control.fid_utils.inception import InceptionV3 17 | 18 | 19 | class Inception3Feature(Inception3): 20 | def forward(self, x): 21 | if x.shape[2] != 299 or x.shape[3] != 299: 22 | x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=True) 23 | 24 | x = self.Conv2d_1a_3x3(x) # 299 x 299 x 3 25 | x = self.Conv2d_2a_3x3(x) # 149 x 149 x 32 26 | x = self.Conv2d_2b_3x3(x) # 147 x 147 x 32 27 | x = F.max_pool2d(x, kernel_size=3, stride=2) # 147 x 147 x 64 28 | 29 | x = self.Conv2d_3b_1x1(x) # 73 x 73 x 64 30 | x = self.Conv2d_4a_3x3(x) # 73 x 73 x 80 31 | x = F.max_pool2d(x, kernel_size=3, stride=2) # 71 x 71 x 192 32 | 33 | x = self.Mixed_5b(x) # 35 x 35 x 192 34 | x = self.Mixed_5c(x) # 35 x 35 x 256 35 | x = self.Mixed_5d(x) # 35 x 35 x 288 36 | 37 | x = self.Mixed_6a(x) # 35 x 35 x 288 38 | x = self.Mixed_6b(x) # 17 x 17 x 768 39 | x = self.Mixed_6c(x) # 17 x 17 x 768 40 | x = self.Mixed_6d(x) # 17 x 17 x 768 41 | x = self.Mixed_6e(x) # 17 x 17 x 768 42 | 43 | x = self.Mixed_7a(x) # 17 x 17 x 768 44 | x = self.Mixed_7b(x) # 8 x 8 x 1280 45 | x = self.Mixed_7c(x) # 8 x 8 x 2048 46 | 47 | x = F.avg_pool2d(x, kernel_size=8) # 8 x 8 x 2048 48 | 49 | return x.view(x.shape[0], x.shape[1]) # 1 x 1 x 2048 50 | 51 | 52 | def load_patched_inception_v3(): 53 | # inception = inception_v3(pretrained=True) 54 | # inception_feat = Inception3Feature() 55 | # inception_feat.load_state_dict(inception.state_dict()) 56 | inception_feat = InceptionV3([3], normalize_input=False) 57 | 58 | return inception_feat 59 | 60 | 61 | @torch.no_grad() 62 | def extract_features(loader, inception, device, num_of_input_channels=3, batch_size=36): 63 | pbar = tqdm(loader) 64 | 65 | feature_list = [] 66 | 67 | for iter, (img, _) in enumerate(pbar): 68 | img = img.to(device) 69 | if num_of_input_channels == 1: # TODO: check 70 | img[:, 1, :, :] = img[:, 0, :, :] 71 | img[:, 2, :, :] = img[:, 0, :, :] 72 | feature = inception(img)[0].view(img.shape[0], -1) 73 | feature_list.append(feature.to('cpu')) 74 | 75 | if (iter + 1) * batch_size > 50000: 76 | break 77 | 78 | features = torch.cat(feature_list, 0) 79 | 80 | return features 81 | 82 | -------------------------------------------------------------------------------- /src/gan_control/fid_utils/evaluate_fid.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import torch 5 | from gan_control.fid_utils.fid import extract_feature_from_samples, calc_fid 6 | import numpy as np 7 | import pickle 8 | import time 9 | 10 | 11 | def evaluate_fid(generator, inception, batch, n_sample, device, inception_stat_path, training=False): 12 | start_time = time.time() 13 | generator.eval() 14 | inception.eval() 15 | 16 | features = extract_feature_from_samples( 17 | generator, 18 | inception, 19 | batch, 20 | n_sample, 21 | device, 22 | training=training 23 | ).numpy() 24 | print(f'extracted {features.shape[0]} features') 25 | 26 | sample_mean = np.mean(features, 0) 27 | sample_cov = np.cov(features, rowvar=False) 28 | 29 | with open(inception_stat_path, 'rb') as f: 30 | embeds = pickle.load(f) 31 | real_mean = embeds['mean'] 32 | real_cov = embeds['cov'] 33 | 34 | fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov) 35 | print('fid: %.3f, time: %.3f (min)' % (fid, (time.time() - start_time) / 60)) 36 | return fid -------------------------------------------------------------------------------- /src/gan_control/fid_utils/fid.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | 4 | import torch 5 | from torch import nn 6 | import numpy as np 7 | from scipy import linalg 8 | from tqdm import tqdm 9 | 10 | from gan_control.models.gan_model import Generator 11 | from gan_control.fid_utils.calc_inception import load_patched_inception_v3 12 | 13 | 14 | @torch.no_grad() 15 | def extract_feature_from_samples( 16 | generator, 17 | inception, 18 | batch_size, 19 | n_sample, 20 | device="cuda", 21 | training=True): 22 | with torch.no_grad(): 23 | n_batch = n_sample // batch_size 24 | resid = n_sample - (n_batch * batch_size) 25 | batch_sizes = [batch_size] * n_batch + [resid] 26 | if resid == 0: 27 | batch_sizes = [batch_size] * n_batch 28 | features = [] 29 | 30 | for batch in tqdm(batch_sizes, disable=training): 31 | latent = torch.randn(batch, 512, device=device) 32 | img, _ = generator([latent]) 33 | if img.shape[1] == 1: 34 | img = torch.cat([img[:,0,:,:].unsqueeze(dim=1), img[:,0,:,:].unsqueeze(dim=1), img[:,0,:,:].unsqueeze(dim=1)], dim=1) 35 | feat = inception(img)[0].view(img.shape[0], -1) 36 | features.append(feat.to('cpu')) 37 | 38 | features = torch.cat(features, 0) 39 | 40 | return features 41 | 42 | 43 | def calc_fid(sample_mean, sample_cov, real_mean, real_cov, eps=1e-6): 44 | cov_sqrt, _ = linalg.sqrtm(sample_cov @ real_cov, disp=False) 45 | 46 | if not np.isfinite(cov_sqrt).all(): 47 | print('product of cov matrices is singular') 48 | offset = np.eye(sample_cov.shape[0]) * eps 49 | cov_sqrt = linalg.sqrtm((sample_cov + offset) @ (real_cov + offset)) 50 | 51 | if np.iscomplexobj(cov_sqrt): 52 | if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3): 53 | m = np.max(np.abs(cov_sqrt.imag)) 54 | 55 | raise ValueError(f'Imaginary component {m}') 56 | 57 | cov_sqrt = cov_sqrt.real 58 | 59 | mean_diff = sample_mean - real_mean 60 | mean_norm = mean_diff @ mean_diff 61 | 62 | trace = np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(cov_sqrt) 63 | 64 | fid = mean_norm + trace 65 | 66 | return fid 67 | 68 | -------------------------------------------------------------------------------- /src/gan_control/inception_stats/README.md: -------------------------------------------------------------------------------- 1 | Put inception statistic here (for FID calculations). 2 | 3 | -------------------------------------------------------------------------------- /src/gan_control/inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/src/gan_control/inference/__init__.py -------------------------------------------------------------------------------- /src/gan_control/losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/src/gan_control/losses/__init__.py -------------------------------------------------------------------------------- /src/gan_control/losses/arc_face/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/src/gan_control/losses/arc_face/__init__.py -------------------------------------------------------------------------------- /src/gan_control/losses/arc_face/arc_face_criterion.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import torch 5 | 6 | 7 | class ArcFaceCriterion: 8 | """ 9 | ArcFaceCriterion 10 | Implements an embedding distance 11 | """ 12 | 13 | def __init__(self): 14 | super(ArcFaceCriterion, self).__init__() 15 | 16 | def __call__(self, signatures: torch.Tensor, queries: torch.Tensor): 17 | signatures = signatures.unsqueeze(dim=1) 18 | queries = queries.unsqueeze(dim=0) 19 | diff = signatures - queries 20 | distances = torch.sum(torch.pow(diff, 2), dim=-1) 21 | #print(distances.shape) 22 | return distances -------------------------------------------------------------------------------- /src/gan_control/losses/arc_face/arc_face_skeleton.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import os 5 | import torch 6 | from collections import namedtuple 7 | import yaml 8 | import importlib 9 | from torch.nn import functional as F 10 | 11 | from gan_control.losses.arc_face.arc_face_model import Backbone, l2_norm 12 | from gan_control.utils.tensor_transforms import center_crop_tensor 13 | 14 | 15 | class ArcFaceSkeleton(torch.nn.Module): 16 | def __init__(self, config): 17 | super(ArcFaceSkeleton, self).__init__() 18 | self.config = config 19 | self.net = self.get_arc_face_model(config) 20 | self.layer1 = self.net.body[:3] 21 | self.layer2 = self.net.body[3:7] 22 | self.layer3 = self.net.body[7:21] 23 | self.layer4 = self.net.body[21:] 24 | for param in self.parameters(): 25 | param.requires_grad = False 26 | 27 | def forward(self, x): 28 | if x.shape[-1] != 112: 29 | if self.config['center_crop'] is not None: 30 | x = center_crop_tensor(x, self.config['center_crop']) 31 | x = F.interpolate(x, size=(112, 112), mode='bilinear', align_corners=True) 32 | x = self.net.input_layer(x) 33 | layer1 = self.layer1(x) 34 | layer2 = self.layer2(layer1) 35 | layer3 = self.layer3(layer2) 36 | layer4 = self.layer4(layer3) 37 | output = l2_norm(self.net.output_layer(layer4)) 38 | out = [layer1, layer2, layer3, layer4, output] 39 | return out 40 | 41 | @staticmethod 42 | def get_arc_face_model(config): 43 | model = Backbone(config['num_layers'], config['drop_ratio'], mode=config['mode']) 44 | model.load_state_dict(torch.load(config['model_path'])) 45 | model.eval() 46 | return model 47 | 48 | @staticmethod 49 | def normelize_to_model_input(batch): 50 | return batch 51 | 52 | 53 | -------------------------------------------------------------------------------- /src/gan_control/losses/deep_expectation_age/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/src/gan_control/losses/deep_expectation_age/__init__.py -------------------------------------------------------------------------------- /src/gan_control/losses/deep_expectation_age/deep_age_criterion.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import torch 5 | from torch.nn import functional as F 6 | 7 | 8 | class DeepAgeCriterion: 9 | """ 10 | HopenetCriterion 11 | Implements an embedding distance 12 | """ 13 | 14 | def __init__(self): 15 | super(DeepAgeCriterion, self).__init__() 16 | self.mse = torch.nn.MSELoss() 17 | 18 | def __call__(self, signatures: torch.Tensor, queries: torch.Tensor): 19 | signatures = signatures.unsqueeze(dim=1) 20 | queries = queries.unsqueeze(dim=0) 21 | diff = signatures - queries 22 | distances = torch.mean(torch.abs(diff), dim=(-1)) 23 | return distances 24 | 25 | @staticmethod 26 | def get_predict_age(age_pb): 27 | predict_age_pb = F.softmax(age_pb, dim=-1) 28 | predict_age = torch.zeros(age_pb.size(0)).type_as(predict_age_pb) 29 | for i in range(age_pb.size(0)): 30 | for j in range(age_pb.size(1)): 31 | predict_age[i] += j * predict_age_pb[i][j] 32 | return predict_age 33 | 34 | def predict(self, age_pb): 35 | return self.get_predict_age(age_pb) 36 | 37 | def controller_criterion(self, pred, target): 38 | return self.mse(pred, target) -------------------------------------------------------------------------------- /src/gan_control/losses/deep_expectation_age/deep_age_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import math 5 | import torch.nn.functional as F 6 | 7 | 8 | class VGG(nn.Module): 9 | def __init__(self, pool='max'): 10 | super(VGG, self).__init__() 11 | # vgg modules 12 | self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1) 13 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 14 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 15 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 16 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1) 17 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 18 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 19 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1) 20 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 21 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 22 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 23 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 24 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 25 | self.fc6 = nn.Linear(25088, 4096, bias=True) 26 | self.fc7 = nn.Linear(4096, 4096, bias=True) 27 | self.fc8_101 = nn.Linear(4096, 101, bias=True) 28 | if pool == 'max': 29 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) 30 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) 31 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) 32 | self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) 33 | self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2) 34 | elif pool == 'avg': 35 | self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2) 36 | self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2) 37 | self.pool3 = nn.AvgPool2d(kernel_size=2, stride=2) 38 | self.pool4 = nn.AvgPool2d(kernel_size=2, stride=2) 39 | self.pool5 = nn.AvgPool2d(kernel_size=2, stride=2) 40 | 41 | def forward(self, x): 42 | out = {} 43 | out['r11'] = F.relu(self.conv1_1(x)) 44 | out['r12'] = F.relu(self.conv1_2(out['r11'])) 45 | out['p1'] = self.pool1(out['r12']) 46 | out['r21'] = F.relu(self.conv2_1(out['p1'])) 47 | out['r22'] = F.relu(self.conv2_2(out['r21'])) 48 | out['p2'] = self.pool2(out['r22']) 49 | out['r31'] = F.relu(self.conv3_1(out['p2'])) 50 | out['r32'] = F.relu(self.conv3_2(out['r31'])) 51 | out['r33'] = F.relu(self.conv3_3(out['r32'])) 52 | out['p3'] = self.pool3(out['r33']) 53 | out['r41'] = F.relu(self.conv4_1(out['p3'])) 54 | out['r42'] = F.relu(self.conv4_2(out['r41'])) 55 | out['r43'] = F.relu(self.conv4_3(out['r42'])) 56 | out['p4'] = self.pool4(out['r43']) 57 | out['r51'] = F.relu(self.conv5_1(out['p4'])) 58 | out['r52'] = F.relu(self.conv5_2(out['r51'])) 59 | out['r53'] = F.relu(self.conv5_3(out['r52'])) 60 | out['p5'] = self.pool5(out['r53']) 61 | out['p5'] = out['p5'].view(out['p5'].size(0), -1) 62 | out['fc6'] = F.relu(self.fc6(out['p5'])) 63 | out['fc7'] = F.relu(self.fc7(out['fc6'])) 64 | out['fc8'] = self.fc8_101(out['fc7']) 65 | return out 66 | -------------------------------------------------------------------------------- /src/gan_control/losses/deep_expectation_age/deep_age_skeleton.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import os 5 | import torch 6 | from collections import namedtuple 7 | import yaml 8 | import importlib 9 | import torchvision 10 | from torch.nn import functional as F 11 | 12 | from gan_control.losses.deep_expectation_age.deep_age_model import VGG 13 | from gan_control.utils.tensor_transforms import center_crop_tensor 14 | 15 | 16 | class DeepAgeSkeleton(torch.nn.Module): 17 | def __init__(self, config): 18 | super(DeepAgeSkeleton, self).__init__() 19 | self.config = config 20 | self.net = self.get_vgg(config) 21 | for param in self.parameters(): 22 | param.requires_grad = False 23 | 24 | @staticmethod 25 | def vgg_transform(x): 26 | x = x.mul(0.5).add(0.5) 27 | x[:,0,:,:] = x[:,0,:,:] - 0.48501961 28 | x[:,1,:,:] = x[:,1,:,:] - 0.45795686 29 | x[:,2,:,:] = x[:,2,:,:] - 0.40760392 30 | """Adapt image for vgg network, x: image of range(0,1) subtracting ImageNet mean""" 31 | r, g, b = torch.split(x, 1, 1) 32 | out = torch.cat((b, g, r), dim=1) 33 | out = F.interpolate(out, size=(224, 224), mode='bilinear', align_corners=False) 34 | out = out * 255. 35 | return out 36 | 37 | @staticmethod 38 | def get_predict_age(age_pb): 39 | predict_age_pb = F.softmax(age_pb) 40 | predict_age = torch.zeros(age_pb.size(0)).type_as(predict_age_pb) 41 | for i in range(age_pb.size(0)): 42 | for j in range(age_pb.size(1)): 43 | predict_age[i] += j * predict_age_pb[i][j] 44 | return predict_age 45 | 46 | def forward(self, x): 47 | if self.config['center_crop'] is not None: 48 | x = center_crop_tensor(x, self.config['center_crop']) 49 | x = self.vgg_transform(x) 50 | x = F.relu(self.net.conv1_1(x)) 51 | x = F.relu(self.net.conv1_2(x)) 52 | x = self.net.pool1(x) 53 | x = F.relu(self.net.conv2_1(x)) 54 | x = F.relu(self.net.conv2_2(x)) 55 | x = self.net.pool2(x) 56 | x = F.relu(self.net.conv3_1(x)) 57 | x = F.relu(self.net.conv3_2(x)) 58 | x = F.relu(self.net.conv3_3(x)) 59 | x = self.net.pool3(x) 60 | x = F.relu(self.net.conv4_1(x)) 61 | x = F.relu(self.net.conv4_2(x)) 62 | x = F.relu(self.net.conv4_3(x)) 63 | x = self.net.pool4(x) 64 | x = F.relu(self.net.conv5_1(x)) 65 | x = F.relu(self.net.conv5_2(x)) 66 | x = F.relu(self.net.conv5_3(x)) 67 | x = self.net.pool5(x) 68 | x = x.view(x.size(0), -1) 69 | x = F.relu(self.net.fc6(x)) 70 | out0 = F.relu(self.net.fc7(x)) 71 | out1 = self.net.fc8_101(out0) 72 | 73 | 74 | return [out1] 75 | 76 | @staticmethod 77 | def get_vgg(config): 78 | model = VGG() 79 | vgg_state_dict = torch.load(config['model_path']) 80 | vgg_state_dict = {k.replace('-', '_'): v for k, v in vgg_state_dict.items()} 81 | model.load_state_dict(vgg_state_dict) 82 | model.eval() 83 | return model 84 | 85 | -------------------------------------------------------------------------------- /src/gan_control/losses/deep_head_pose/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/src/gan_control/losses/deep_head_pose/__init__.py -------------------------------------------------------------------------------- /src/gan_control/losses/deep_head_pose/hopenet_criterion.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import torch 5 | from gan_control.evaluation.orientation import softmax_temperature 6 | 7 | def calc_orientation_from_features(features): 8 | idx_tensor = [idx for idx in range(66)] 9 | idx_tensor = torch.FloatTensor(idx_tensor).cuda() 10 | 11 | _, yaw_bpred = torch.max(features[:, 0, :], 1) 12 | _, pitch_bpred = torch.max(features[:, 1, :], 1) 13 | _, roll_bpred = torch.max(features[:, 2, :], 1) 14 | 15 | yaw_predicted = softmax_temperature(features[:, 0, :], 1) 16 | pitch_predicted = softmax_temperature(features[:, 1, :], 1) 17 | roll_predicted = softmax_temperature(features[:, 2, :], 1) 18 | 19 | yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1) * 3 - 99 20 | pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1) * 3 - 99 21 | roll_predicted = torch.sum(roll_predicted * idx_tensor, 1) * 3 - 99 22 | 23 | return yaw_predicted, pitch_predicted, roll_predicted 24 | 25 | class HopenetCriterion: 26 | """ 27 | HopenetCriterion 28 | Implements an embedding distance 29 | """ 30 | 31 | def __init__(self): 32 | super(HopenetCriterion, self).__init__() 33 | 34 | def __call__(self, signatures: torch.Tensor, queries: torch.Tensor): 35 | signatures = signatures.unsqueeze(dim=1) 36 | queries = queries.unsqueeze(dim=0) 37 | diff = signatures - queries 38 | distances = torch.mean(torch.abs(diff), dim=(-2, -1)) 39 | return distances 40 | 41 | def predict(self, features): 42 | yaw_predicted, pitch_predicted, roll_predicted = calc_orientation_from_features(features) 43 | return torch.cat([yaw_predicted.unsqueeze(1), pitch_predicted.unsqueeze(1), roll_predicted.unsqueeze(1)], dim=1) 44 | 45 | def controller_criterion(self, pred, target): 46 | return torch.abs(pred - target).mean() 47 | 48 | -------------------------------------------------------------------------------- /src/gan_control/losses/deep_head_pose/hopenet_skeleton.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import os 5 | import torch 6 | from collections import namedtuple 7 | import yaml 8 | import importlib 9 | import torchvision 10 | from torch.nn import functional as F 11 | 12 | from gan_control.losses.deep_head_pose.hopenet_model import Hopenet 13 | 14 | 15 | class HopenetSkeleton(torch.nn.Module): 16 | def __init__(self, config): 17 | super(HopenetSkeleton, self).__init__() 18 | self.net = self.get_hopenet_model(config) 19 | self.mean = [0.485, 0.456, 0.406] 20 | self.std = [0.229, 0.224, 0.225] 21 | for param in self.parameters(): 22 | param.requires_grad = False 23 | 24 | def forward(self, x): 25 | # transformations = transforms.Compose([transforms.Scale(224), 26 | # transforms.CenterCrop(224), transforms.ToTensor(), 27 | # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 28 | if x.shape[-1] != 224: 29 | x = F.interpolate(x, size=(224, 224), mode='bilinear', align_corners=True) 30 | # normelize 31 | x = x.mul(0.5).add(0.5) 32 | x[:, 0, :, :] = (x[:, 0, :, :] - self.mean[0]) / self.std[0] 33 | x[:, 1, :, :] = (x[:, 1, :, :] - self.mean[1]) / self.std[1] 34 | x[:, 2, :, :] = (x[:, 2, :, :] - self.mean[2]) / self.std[2] 35 | 36 | 37 | x = self.net.conv1(x) 38 | x = self.net.bn1(x) 39 | x = self.net.relu(x) 40 | x = self.net.maxpool(x) 41 | 42 | layer1 = self.net.layer1(x) 43 | layer2 = self.net.layer2(layer1) 44 | layer3 = self.net.layer3(layer2) 45 | layer4 = self.net.layer4(layer3) 46 | 47 | x = self.net.avgpool(layer4) 48 | x = x.view(x.size(0), -1) 49 | pre_yaw = self.net.fc_yaw(x) 50 | pre_pitch = self.net.fc_pitch(x) 51 | pre_roll = self.net.fc_roll(x) 52 | 53 | output = torch.cat([pre_yaw.unsqueeze(1), pre_pitch.unsqueeze(1), pre_roll.unsqueeze(1)], dim=1) 54 | out = [layer1, layer2, layer3, layer4, output] 55 | 56 | return out 57 | 58 | @staticmethod 59 | def get_hopenet_model(config): 60 | model = Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 66) 61 | model.load_state_dict(torch.load(config['model_path'])) 62 | model.eval() 63 | return model 64 | 65 | -------------------------------------------------------------------------------- /src/gan_control/losses/dogfacenet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/src/gan_control/losses/dogfacenet/__init__.py -------------------------------------------------------------------------------- /src/gan_control/losses/dogfacenet/dogfacenet_criterion.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import torch 5 | 6 | 7 | class DogFaceCriterion: 8 | def __init__(self, loss_name='defualt'): 9 | super(DogFaceCriterion, self).__init__() 10 | self.loss_name = loss_name 11 | 12 | def __call__(self, signatures: torch.Tensor, queries: torch.Tensor): 13 | signatures = signatures.unsqueeze(dim=1) 14 | queries = queries.unsqueeze(dim=0) 15 | diff = signatures - queries 16 | distances = torch.sum(torch.pow(diff, 2), dim=-1) 17 | 18 | return distances 19 | -------------------------------------------------------------------------------- /src/gan_control/losses/dogfacenet/dogfacenet_skeleton.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import os 5 | import torch 6 | from collections import namedtuple 7 | import yaml 8 | import importlib 9 | from torch.nn import functional as F 10 | import pickle 11 | 12 | from gan_control.losses.dogfacenet.models.pytorch_dogfacenet_model import DogFaceNet 13 | from gan_control.utils.tensor_transforms import center_crop_tensor 14 | 15 | 16 | class DogFaceNetSkeleton(torch.nn.Module): 17 | def __init__(self, config): 18 | super(DogFaceNetSkeleton, self).__init__() 19 | self.config = config 20 | self.net = self.get_dogfacenet_model(config) 21 | for param in self.parameters(): 22 | param.requires_grad = False 23 | 24 | def forward(self, x): 25 | x = x.mul(0.5).add(0.5) 26 | if x.shape[-1] != 224: 27 | if self.config['center_crop'] is not None: 28 | x = center_crop_tensor(x, self.config['center_crop']) 29 | x = F.interpolate(x, size=(224, 224), mode='bicubic', align_corners=True) 30 | out = self.net(x) 31 | return [out] 32 | 33 | @staticmethod 34 | def get_dogfacenet_model(config): 35 | model = DogFaceNet() 36 | model.load_state_dict(torch.load(config['model_path'])) 37 | model.eval() 38 | return model 39 | 40 | @staticmethod 41 | def normelize_to_model_input(batch): 42 | return batch 43 | 44 | 45 | -------------------------------------------------------------------------------- /src/gan_control/losses/dogfacenet/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/src/gan_control/losses/dogfacenet/models/__init__.py -------------------------------------------------------------------------------- /src/gan_control/losses/dogfacenet/models/pytorch_dogfacenet_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | #import tensorflow as tf 5 | 6 | 7 | #def print_diff(name, py_x, tf_m, index, tf_input): 8 | # py_x = py_x.detach().numpy() 9 | # print('Layer: %s' % name) 10 | # graph0 = tf.keras.Model(tf_m.input, tf_m.get_layer(index=index).output) 11 | # tf_x = graph0.predict(tf_input) 12 | # print('shapes: py %s, tf %s' % (str(py_x.shape), str(tf_x.shape))) 13 | # if len(py_x.shape) == 4: 14 | # print('2x4: \npy %s, \ntf %s' % (str(py_x[0, 0, :4, :4]), str(tf_x[0, :4, :4, 0]))) 15 | # print('diff: %s' % str(np.mean(np.abs(py_x[0, 0, :, :] - tf_x[0, :, :, 0])))) 16 | # else: 17 | # print('1: \npy %s, \ntf %s' % (str(py_x[0, :1]), str(tf_x[0, :1]))) 18 | # print('diff: %s' % str(np.mean(np.abs(py_x[0, :] - tf_x[0, :])))) 19 | 20 | 21 | def l2_norm(input, axis=1): 22 | norm = torch.norm(input, 2, axis, True) 23 | output = torch.div(input, norm) 24 | return output 25 | 26 | 27 | class ResBlock(nn.Module): 28 | def __init__(self, in_s, out_s, pad='reg'): 29 | super(ResBlock, self).__init__() 30 | if pad=='reg': 31 | self.pad0 = nn.ZeroPad2d((1, 1, 1, 1)) 32 | else: 33 | self.pad0 = nn.ZeroPad2d((0, 1, 0, 1)) 34 | self.conv0 = nn.Conv2d(in_s, out_s, kernel_size=3, stride=(2,2), padding=0, bias=False) 35 | self.relu = nn.ReLU() 36 | self.bn0 = nn.BatchNorm2d(out_s) 37 | self.conv1 = nn.Conv2d(out_s, out_s, kernel_size=3, padding=1, bias=False) 38 | self.bn1 = nn.BatchNorm2d(out_s) 39 | self.conv2 = nn.Conv2d(out_s, out_s, kernel_size=3, padding=1, bias=False) 40 | self.bn2 = nn.BatchNorm2d(out_s) 41 | 42 | #def forward2(self, x, tf_model, tf_in): 43 | # x = self.pad0(x) 44 | # x = self.conv0(x) 45 | # x = self.relu(x) 46 | # print_diff('block3_conv0', x, tf_model, 20, tf_in) 47 | # r = self.bn0(x) 48 | # print_diff('block3_bn0', r, tf_model, 21, tf_in) 49 | # 50 | # x = self.conv1(r) 51 | # x = self.relu(x) 52 | # print_diff('block3_conv1', x, tf_model, 22, tf_in) 53 | # x = self.bn1(x) 54 | # 55 | # r = r + x 56 | # print_diff('block3_add0', r, tf_model, 24, tf_in) 57 | # x = self.conv2(r) 58 | # x = self.relu(x) 59 | # x = self.bn2(x) 60 | # 61 | # r = r + x 62 | # return r 63 | 64 | def forward(self, x): 65 | x = self.pad0(x) 66 | x = self.conv0(x) 67 | x = self.relu(x) 68 | r = self.bn0(x) 69 | 70 | x = self.conv1(r) 71 | x = self.relu(x) 72 | x = self.bn1(x) 73 | 74 | r = r + x 75 | 76 | x = self.conv2(r) 77 | x = self.relu(x) 78 | x = self.bn2(x) 79 | 80 | r = r + x 81 | return r 82 | 83 | 84 | class DogFaceNet(nn.Module): 85 | def __init__(self): 86 | super(DogFaceNet, self).__init__() 87 | self.pad0 = nn.ZeroPad2d((2, 4, 2, 4)) 88 | self.conv0 = nn.Conv2d(3, 16, kernel_size=7, stride=2, padding=0, bias=False) 89 | self.relu = nn.ReLU() 90 | self.bn0 = nn.BatchNorm2d(16) 91 | self.maxpooling3 = nn.MaxPool2d(kernel_size=3) 92 | 93 | self.res_block1 = ResBlock(16,16) 94 | self.res_block2 = ResBlock(16, 32) 95 | self.res_block3 = ResBlock(32, 64, pad='b3') 96 | self.res_block4 = ResBlock(64, 128) 97 | self.res_block5 = ResBlock(128, 512) 98 | 99 | self.global_avg_pooling = nn.AdaptiveAvgPool2d(1) 100 | self.fc = nn.Linear(512, 32, bias=False) 101 | 102 | for param in self.parameters(): 103 | param.requires_grad = False 104 | 105 | def forward(self, x): 106 | x = self.pad0(x) 107 | x = self.conv0(x) 108 | x = self.relu(x) 109 | x = self.bn0(x) 110 | x = self.maxpooling3(x) 111 | 112 | x = self.res_block1(x) 113 | x = self.res_block2(x) 114 | x = self.res_block3(x) 115 | x = self.res_block4(x) 116 | x = self.res_block5(x) 117 | 118 | x = self.global_avg_pooling(x) 119 | 120 | x = x.squeeze(3).squeeze(2) 121 | x = self.fc(x) 122 | x = l2_norm(x) 123 | return x 124 | 125 | #def forward2(self, x, tf_model, tf_in): 126 | # x = self.pad0(x) 127 | # x = self.conv0(x) 128 | # x = self.relu(x) 129 | # print_diff('conv0', x, tf_model, 1, tf_in) 130 | # x = self.bn0(x) 131 | # print_diff('bn0', x, tf_model, 2, tf_in) 132 | # x = self.maxpooling3(x) 133 | # print_diff('maxpool3', x, tf_model, 3, tf_in) 134 | # 135 | # x = self.res_block1(x) 136 | # print_diff('block1', x, tf_model, 11, tf_in) 137 | # x = self.res_block2(x) 138 | # print_diff('block2', x, tf_model, 19, tf_in) 139 | # x = self.res_block3.forward2(x, tf_model, tf_in) 140 | # print_diff('block3', x, tf_model, 27, tf_in) 141 | # x = self.res_block4(x) 142 | # print_diff('block4', x, tf_model, 35, tf_in) 143 | # x = self.res_block5(x) 144 | # print_diff('block5', x, tf_model, 43, tf_in) 145 | # 146 | # x = self.global_avg_pooling(x) 147 | # 148 | # x = x.squeeze(3).squeeze(2) 149 | # x = self.fc(x) 150 | # x = l2_norm(x) 151 | # print_diff('l2', x, tf_model, 48, tf_in) 152 | # return x 153 | # 154 | 155 | 156 | -------------------------------------------------------------------------------- /src/gan_control/losses/face3dmm_recon/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/src/gan_control/losses/face3dmm_recon/__init__.py -------------------------------------------------------------------------------- /src/gan_control/losses/face3dmm_recon/face3dmm_criterion.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import torch 5 | 6 | 7 | class Face3dmmCriterion: 8 | 9 | def __init__(self, loss_name='defualt'): 10 | super(Face3dmmCriterion, self).__init__() 11 | self.loss_name = loss_name 12 | self.mse = torch.nn.MSELoss() 13 | 14 | def __call__(self, signatures: torch.Tensor, queries: torch.Tensor): 15 | #if self.loss_name in ['recon_3d_loss', 'gamma_loss', 'id_loss', 'ex_loss', 'tex_loss', 'angles_loss', 'xy_loss', 'z_loss']: 16 | signatures = signatures.unsqueeze(dim=1) 17 | queries = queries.unsqueeze(dim=0) 18 | diff = signatures - queries 19 | distances = torch.mean(torch.abs(diff), dim=-1) 20 | #distances = torch.sqrt(torch.pow(diff, 2).sum(dim=-1)) used in some evaluations 21 | return distances 22 | 23 | def controller_criterion(self, pred, target): 24 | return torch.abs(pred - target).mean() 25 | -------------------------------------------------------------------------------- /src/gan_control/losses/face3dmm_recon/face3dmm_skeleton.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import os 5 | import torch 6 | from collections import namedtuple 7 | import yaml 8 | import importlib 9 | from torch.nn import functional as F 10 | 11 | from gan_control.losses.face3dmm_recon.models.pytorch_3d_recon_model import Recon3D 12 | from gan_control.utils.tensor_transforms import center_crop_tensor 13 | 14 | 15 | class Face3dmmSkeleton(torch.nn.Module): 16 | def __init__(self, config): 17 | super(Face3dmmSkeleton, self).__init__() 18 | self.config = config 19 | self.net = self.get_face_3dmm_model(config) 20 | for param in self.parameters(): 21 | param.requires_grad = False 22 | 23 | def forward(self, x): 24 | r = x[:, :1, :, :] 25 | g = x[:, 1:2, :, :] 26 | b = x[:, 2:3, :, :] 27 | x = torch.cat([b, g, r], dim=1).mul(0.5).add(0.5).mul(255) 28 | if x.shape[-1] != 224: 29 | if self.config['center_crop'] is not None: 30 | x = center_crop_tensor(x, self.config['center_crop']) 31 | x = F.interpolate(x, size=(224, 224), mode='bicubic', align_corners=True) 32 | out = self.net(x) 33 | return [out] 34 | 35 | @staticmethod 36 | def extract_futures_from_vec(out_vec): 37 | out_vec = out_vec[-1] 38 | return [out_vec[:,:80]], [out_vec[:,80:144]], [out_vec[:,144:224]], [out_vec[:,224:227]], [out_vec[:,227:254]], [out_vec[:,254:256]], [out_vec[:,256:257]] 39 | 40 | @staticmethod 41 | def get_face_3dmm_model(config): 42 | model = Recon3D() 43 | model.load_state_dict(torch.load(config['model_path'])) 44 | model.eval() 45 | return model 46 | 47 | @staticmethod 48 | def normelize_to_model_input(batch): 49 | return batch 50 | 51 | 52 | -------------------------------------------------------------------------------- /src/gan_control/losses/face3dmm_recon/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/src/gan_control/losses/face3dmm_recon/models/__init__.py -------------------------------------------------------------------------------- /src/gan_control/losses/face3dmm_recon/models/tf_model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def load_graph(graph_filename): 5 | with tf.gfile.GFile(graph_filename,'rb') as f: 6 | graph_def = tf.GraphDef() 7 | graph_def.ParseFromString(f.read()) 8 | 9 | return graph_def 10 | 11 | 12 | if __name__ == '__main__': 13 | graph_def = load_graph('path to FaceReconModel.pb') # TODO: path to FaceReconModel.pb 14 | print(graph_def) -------------------------------------------------------------------------------- /src/gan_control/losses/facial_features_esr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/src/gan_control/losses/facial_features_esr/__init__.py -------------------------------------------------------------------------------- /src/gan_control/losses/facial_features_esr/esr9_criterion.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import torch 5 | 6 | 7 | class ESR9Criterion: 8 | """ 9 | ESR9Criterion 10 | Implements an embedding distance 11 | """ 12 | 13 | def __init__(self): 14 | super(ESR9Criterion, self).__init__() 15 | 16 | def __call__(self, signatures: torch.Tensor, queries: torch.Tensor): 17 | signatures = signatures.unsqueeze(dim=1) 18 | queries = queries.unsqueeze(dim=0) 19 | diff = signatures - queries 20 | distances = torch.mean(torch.abs(diff), dim=(-2, -1)) 21 | return distances -------------------------------------------------------------------------------- /src/gan_control/losses/facial_features_esr/esr9_skeleton.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import os 5 | import torch 6 | from collections import namedtuple 7 | import yaml 8 | import importlib 9 | import torchvision 10 | from torch.nn import functional as F 11 | 12 | from gan_control.losses.facial_features_esr.esr9_model import ESR 13 | from gan_control.utils.tensor_transforms import center_crop_tensor 14 | 15 | 16 | class ESR9Skeleton(torch.nn.Module): 17 | def __init__(self, config): 18 | super(ESR9Skeleton, self).__init__() 19 | self.config = config 20 | self.net = self.get_esr9_model(config) 21 | for param in self.parameters(): 22 | param.requires_grad = False 23 | 24 | def forward(self, x): 25 | if self.config['center_crop'] is not None: 26 | x = center_crop_tensor(x, self.config['center_crop']) 27 | if x.shape[-1] != ESR.INPUT_IMAGE_SIZE[0]: 28 | x = F.interpolate(x, size=ESR.INPUT_IMAGE_SIZE, mode='bilinear', align_corners=True) 29 | # normelize 30 | x = x.mul(0.5).add(0.5) 31 | 32 | emotions = [] 33 | affect_values = [] 34 | 35 | # Get shared representations 36 | x_shared_representations = self.net.base(x) 37 | for branch in self.net.convolutional_branches: 38 | output_emotion, output_affect = branch(x_shared_representations) 39 | emotions.append(output_emotion.unsqueeze(1)) 40 | affect_values.append(output_affect) 41 | 42 | out = [x_shared_representations, torch.cat(emotions, dim=1)] 43 | return out 44 | 45 | @staticmethod 46 | def get_esr9_model(config): 47 | model = ESR(config['model_path']) 48 | model.eval() 49 | return model 50 | 51 | 52 | if __name__ == '__main__': 53 | model = ESR() 54 | print(model) 55 | 56 | -------------------------------------------------------------------------------- /src/gan_control/losses/hair_loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/src/gan_control/losses/hair_loss/__init__.py -------------------------------------------------------------------------------- /src/gan_control/losses/hair_loss/hair_criterion.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import torch 5 | 6 | from gan_control.evaluation.hair import make_hair_color_grid 7 | 8 | 9 | class HairCriterion: 10 | def __init__(self): 11 | super(HairCriterion, self).__init__() 12 | self.mean = torch.tensor([0.485, 0.456, 0.406]).cuda() 13 | self.std = torch.tensor([0.229, 0.224, 0.225]).cuda() 14 | self.mse = torch.nn.MSELoss().cuda() 15 | 16 | def __call__(self, signatures: torch.Tensor, queries: torch.Tensor): 17 | b, c, h, w = signatures.shape 18 | thres = 0.01 * w*h 19 | signatures_masked_image = signatures[:, :3, :, :] 20 | signatures_mask = signatures[:, 3:, :, :] 21 | queries_masked_image = queries[:, :3, :, :] 22 | queries_mask = queries[:, 3:, :, :] 23 | 24 | signatures_mask_sum = torch.sum(signatures_mask.detach(), dim=[-2, -1]) 25 | queries_mask_sum = torch.sum(queries_mask.detach(), dim=[-2, -1]) 26 | signatures_mask_u = (signatures_mask_sum > thres).unsqueeze(dim=1) 27 | queries_mask_v = (queries_mask_sum > thres).unsqueeze(dim=0) 28 | valid_uv_mask = signatures_mask_u * queries_mask_v 29 | #valid_mask = (signatures_mask_sum > thres) * (queries_mask_sum > thres) 30 | 31 | signatures = torch.div(torch.sum(signatures_masked_image, dim=[-2, -1]), signatures_mask_sum + (signatures_mask_sum < 0.5).float()) 32 | queries = torch.div(torch.sum(queries_masked_image, dim=[-2, -1]), queries_mask_sum + (queries_mask_sum < 0.5).float()) 33 | # signatures = (signatures * self.std + self.mean)[valid_mask] 34 | # queries = (queries * self.std + self.mean)[valid_mask] 35 | signatures = signatures.mul(0.5).add(0.5) #* valid_mask # signatures[valid_mask.squeeze()] 36 | queries = queries.mul(0.5).add(0.5) #* valid_mask # queries[valid_mask.squeeze()] 37 | 38 | signatures = signatures.unsqueeze(dim=1) 39 | queries = queries.unsqueeze(dim=0) 40 | diff = signatures - queries 41 | diff = diff * valid_uv_mask 42 | distances = torch.mean(torch.abs(diff), dim=-1) 43 | 44 | return distances 45 | 46 | @staticmethod 47 | def predict(features): 48 | masked_image = features[:, :3, :, :] 49 | mask = features[:, 3:, :, :] 50 | mask_sum = torch.sum(mask.detach(), dim=[-2, -1]) 51 | valid_mask = (mask_sum > 0.5) 52 | preds = torch.div(torch.sum(masked_image, dim=[-2, -1]), mask_sum + (mask_sum < 0.5).float()) 53 | preds = preds.mul(0.5).add(0.5) * valid_mask 54 | return preds 55 | 56 | def controller_criterion(self, pred, target): 57 | return self.mse(pred, target) 58 | 59 | @staticmethod 60 | def visual(hair_loss_class, tensor_images, save_path=None, target_pred=None, nrow=4): 61 | return make_hair_color_grid(hair_loss_class, tensor_images, nrow=nrow, save_path=save_path, downsample=None, target_pred=target_pred) -------------------------------------------------------------------------------- /src/gan_control/losses/hair_loss/hair_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torchvision.models import squeezenet1_1, resnet101 5 | from torch.nn.init import xavier_normal_ 6 | 7 | """ 8 | Referenced from https://github.com/Lextal/pspnet-pytorch/blob/master/pspnet.py 9 | """ 10 | 11 | 12 | class ResNet101Extractor(nn.Module): 13 | def __init__(self): 14 | super(ResNet101Extractor, self).__init__() 15 | model = resnet101(pretrained=True) 16 | self.features = nn.Sequential(*list(model.children())[:7]) 17 | def forward(self, x): 18 | return self.features(x) 19 | 20 | class SqueezeNetExtractor(nn.Module): 21 | def __init__(self): 22 | super(SqueezeNetExtractor, self).__init__() 23 | model = squeezenet1_1(pretrained=True) 24 | features = model.features 25 | self.feature1 = features[:2] 26 | self.feature2 = features[2:5] 27 | self.feature3 = features[5:8] 28 | self.feature4 = features[8:] 29 | 30 | def forward(self, x): 31 | f1 = self.feature1(x) 32 | f2 = self.feature2(f1) 33 | f3 = self.feature3(f2) 34 | f4 = self.feature4(f3) 35 | return f4 36 | 37 | 38 | class PyramidPoolingModule(nn.Module): 39 | def __init__(self, in_channels, sizes=(1, 2, 3, 6)): 40 | super(PyramidPoolingModule, self).__init__() 41 | pyramid_levels = len(sizes) 42 | out_channels = in_channels // pyramid_levels 43 | 44 | pooling_layers = nn.ModuleList() 45 | for size in sizes: 46 | layers = [nn.AdaptiveAvgPool2d(size), nn.Conv2d(in_channels, out_channels, kernel_size=1)] 47 | pyramid_layer = nn.Sequential(*layers) 48 | pooling_layers.append(pyramid_layer) 49 | 50 | self.pooling_layers = pooling_layers 51 | 52 | def forward(self, x): 53 | h, w = x.size(2), x.size(3) 54 | features = [x] 55 | for pooling_layer in self.pooling_layers: 56 | # pool with different sizes 57 | pooled = pooling_layer(x) 58 | 59 | # upsample to original size 60 | upsampled = F.upsample(pooled, size=(h, w), mode='bilinear') 61 | 62 | features.append(upsampled) 63 | 64 | return torch.cat(features, dim=1) 65 | 66 | 67 | class UpsampleLayer(nn.Module): 68 | def __init__(self, in_channels, out_channels, upsample_size=None): 69 | super().__init__() 70 | self.upsample_size = upsample_size 71 | 72 | self.conv = nn.Sequential( 73 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), 74 | nn.BatchNorm2d(out_channels), 75 | nn.ReLU() 76 | ) 77 | 78 | def forward(self, x): 79 | size = 2 * x.size(2), 2 * x.size(3) 80 | f = F.upsample(x, size=size, mode='bilinear') 81 | return self.conv(f) 82 | 83 | 84 | class PSPNet(nn.Module): 85 | def __init__(self, num_class=1, sizes=(1, 2, 3, 6), base_network='resnet101'): 86 | super(PSPNet, self).__init__() 87 | base_network = base_network.lower() 88 | if base_network == 'resnet101': 89 | self.base_network = ResNet101Extractor() 90 | feature_dim = 1024 91 | elif base_network == 'squeezenet': 92 | self.base_network = SqueezeNetExtractor() 93 | feature_dim = 512 94 | else: 95 | raise ValueError 96 | self.psp = PyramidPoolingModule(in_channels=feature_dim, sizes=sizes) 97 | self.drop_1 = nn.Dropout2d(p=0.3) 98 | 99 | self.up_1 = UpsampleLayer(2*feature_dim, 256) 100 | self.up_2 = UpsampleLayer(256, 64) 101 | self.up_3 = UpsampleLayer(64, 64) 102 | 103 | self.drop_2 = nn.Dropout2d(p=0.15) 104 | self.final = nn.Sequential( 105 | nn.Conv2d(64, num_class, kernel_size=1) 106 | ) 107 | 108 | self._init_weight() 109 | 110 | def forward(self, x): 111 | h, w = x.size(2), x.size(3) 112 | f = self.base_network(x) 113 | p = self.psp(f) 114 | p = self.drop_1(p) 115 | p = self.up_1(p) 116 | p = self.drop_2(p) 117 | 118 | p = self.up_2(p) 119 | p = self.drop_2(p) 120 | 121 | p = self.up_3(p) 122 | 123 | if (p.size(2) != h) or (p.size(3) != w): 124 | p = F.interpolate(p, size=(h, w), mode='bilinear') 125 | 126 | p = self.drop_2(p) 127 | 128 | return self.final(p) 129 | 130 | def _init_weight(self): 131 | layers = [self.up_1, self.up_2, self.up_3, self.final] 132 | for layer in layers: 133 | if isinstance(layer, nn.Conv2d): 134 | xavier_normal_(layer.weight.data) 135 | 136 | elif isinstance(layer, nn.BatchNorm2d): 137 | layer.weight.data.normal_(1.0, 0.02) 138 | layer.bias.data.fill_(0) 139 | 140 | -------------------------------------------------------------------------------- /src/gan_control/losses/hair_loss/hair_skeleton.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import os 5 | import torch 6 | from collections import namedtuple 7 | import yaml 8 | import importlib 9 | from torch.nn import functional as F 10 | import warnings 11 | 12 | from gan_control.losses.hair_loss.hair_model import PSPNet 13 | 14 | 15 | class HairSkeleton(torch.nn.Module): 16 | def __init__(self, config): 17 | super(HairSkeleton, self).__init__() 18 | self.net = self.get_hair_model(config) 19 | for param in self.parameters(): 20 | param.requires_grad = False 21 | 22 | def forward(self, x): 23 | with warnings.catch_warnings(): 24 | warnings.simplefilter("ignore") 25 | if x.shape[-1] != 256: 26 | x = F.interpolate(x, size=(256, 256), mode='bilinear', align_corners=True) 27 | mask = x.detach() 28 | with torch.no_grad(): 29 | mask = self.normelize_to_model_input(mask) 30 | mask = self.net(mask) 31 | mask = self.make_mask_from_pred(mask) 32 | 33 | out = [torch.cat([x*mask, mask.float()], dim=1)] 34 | return out 35 | 36 | @staticmethod 37 | def make_mask_from_pred(pred): 38 | pred = torch.sigmoid(pred).detach() 39 | mask = pred >= 0.5 40 | return mask 41 | 42 | 43 | @staticmethod 44 | def get_hair_model(config): 45 | model = PSPNet(num_class=1, base_network='resnet101') 46 | model.load_state_dict(torch.load(config['model_path'])['weight']) 47 | model.eval() 48 | return model 49 | 50 | @staticmethod 51 | def normelize_to_model_input(batch): 52 | batch = batch.mul(0.5).add(0.5) 53 | mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1) 54 | std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1) 55 | return (batch - mean) / std 56 | 57 | 58 | -------------------------------------------------------------------------------- /src/gan_control/losses/imagenet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/src/gan_control/losses/imagenet/__init__.py -------------------------------------------------------------------------------- /src/gan_control/losses/imagenet/imagenet_criterion.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import torch 5 | 6 | 7 | class ImageNetCriterion: 8 | def __init__(self): 9 | super(ImageNetCriterion, self).__init__() 10 | 11 | def __call__(self, signatures: torch.Tensor, queries: torch.Tensor): 12 | signatures = signatures.unsqueeze(dim=1) 13 | queries = queries.unsqueeze(dim=0) 14 | diff = signatures - queries 15 | distances = torch.mean(torch.abs(diff), dim=(-1)) 16 | #print(distances.shape) 17 | return distances -------------------------------------------------------------------------------- /src/gan_control/losses/imagenet/imagenet_skeleton.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import os 5 | import torch 6 | from collections import namedtuple 7 | import yaml 8 | import importlib 9 | from torch.nn import functional as F 10 | 11 | from torchvision.models import resnet18 12 | from gan_control.utils.tensor_transforms import center_crop_tensor 13 | 14 | 15 | class ImageNetSkeleton(torch.nn.Module): 16 | def __init__(self, config): 17 | super(ImageNetSkeleton, self).__init__() 18 | self.config = config 19 | self.net = resnet18(pretrained=True) 20 | for param in self.parameters(): 21 | param.requires_grad = False 22 | 23 | def forward(self, x): 24 | if x.shape[-1] != 224: 25 | if self.config['center_crop'] is not None: 26 | x = center_crop_tensor(x, self.config['center_crop']) 27 | x = F.interpolate(x, size=(224, 224), mode='bilinear', align_corners=True) 28 | 29 | x = self.net.conv1(x) 30 | x = self.net.bn1(x) 31 | x = self.net.relu(x) 32 | x = self.net.maxpool(x) 33 | 34 | x = self.net.layer1(x) 35 | x = self.net.layer2(x) 36 | x = self.net.layer3(x) 37 | x = self.net.layer4(x) 38 | 39 | x = self.net.avgpool(x) 40 | b_last = torch.flatten(x, 1) 41 | last = self.net.fc(b_last) 42 | 43 | return [last, b_last] 44 | 45 | @staticmethod 46 | def normelize_to_model_input(batch): 47 | return batch 48 | 49 | 50 | if __name__ == '__main__': 51 | model = ImageNetSkeleton(None) 52 | 53 | -------------------------------------------------------------------------------- /src/gan_control/losses/stayle/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/src/gan_control/losses/stayle/__init__.py -------------------------------------------------------------------------------- /src/gan_control/losses/stayle/style_criterion.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import torch 5 | 6 | 7 | class StyleCriterion: 8 | def __init__(self): 9 | super(StyleCriterion, self).__init__() 10 | 11 | def __call__(self, signatures: torch.Tensor, queries: torch.Tensor): 12 | signatures = signatures.unsqueeze(dim=1).unsqueeze(dim=1) 13 | queries = queries.unsqueeze(dim=1).unsqueeze(dim=0) 14 | diff = signatures - queries 15 | distances = torch.mean(torch.pow(diff, 2), dim=(-2, -1)) 16 | return distances.squeeze(dim=2) * 1e5 17 | -------------------------------------------------------------------------------- /src/gan_control/losses/stayle/style_skeleton.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import os 5 | import torch 6 | from collections import namedtuple 7 | import yaml 8 | import importlib 9 | from torchvision import models 10 | import torchvision 11 | from torch.nn import functional as F 12 | 13 | from gan_control.utils.tensor_transforms import center_crop_tensor 14 | 15 | 16 | class StyleSkeleton(torch.nn.Module): 17 | def __init__(self, config): 18 | super(StyleSkeleton, self).__init__() 19 | self.net = models.vgg16(pretrained=True).features 20 | self.config = config 21 | self.resize_to = config['resize_to'] 22 | self.mean = [0.485, 0.456, 0.406] 23 | self.std = [0.229, 0.224, 0.225] 24 | self.slice1 = torch.nn.Sequential() 25 | self.slice2 = torch.nn.Sequential() 26 | self.slice3 = torch.nn.Sequential() 27 | self.slice4 = torch.nn.Sequential() 28 | for x in range(4): 29 | self.slice1.add_module(str(x), self.net[x]) 30 | for x in range(4, 9): 31 | self.slice2.add_module(str(x), self.net[x]) 32 | for x in range(9, 16): 33 | self.slice3.add_module(str(x), self.net[x]) 34 | for x in range(16, 23): 35 | self.slice4.add_module(str(x), self.net[x]) 36 | for param in self.parameters(): 37 | param.requires_grad = False 38 | 39 | def forward(self, x): 40 | # transformations = transforms.Compose([transforms.Scale(224), 41 | # transforms.CenterCrop(224), transforms.ToTensor(), 42 | # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 43 | if x.shape[-1] != self.resize_to: 44 | x = F.interpolate(x, size=(self.resize_to, self.resize_to), mode='bilinear', align_corners=True) 45 | if self.config['center_crop'] is not None: 46 | x = center_crop_tensor(x, self.config['center_crop']) 47 | # normelize 48 | x = x.mul(0.5).add(0.5) 49 | x[:, 0, :, :] = (x[:, 0, :, :] - self.mean[0]) / self.std[0] 50 | x[:, 1, :, :] = (x[:, 1, :, :] - self.mean[1]) / self.std[1] 51 | x[:, 2, :, :] = (x[:, 2, :, :] - self.mean[2]) / self.std[2] 52 | 53 | h = self.slice1(x) 54 | h_relu1_2 = h 55 | h = self.slice2(h) 56 | h_relu2_2 = h 57 | h = self.slice3(h) 58 | h_relu3_3 = h 59 | h = self.slice4(h) 60 | h_relu4_3 = h 61 | 62 | features_style = [h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3] 63 | features_style = [self.gram_matrix(y) for y in features_style] 64 | 65 | return features_style 66 | 67 | @staticmethod 68 | def gram_matrix(y): 69 | (b, ch, h, w) = y.shape 70 | features = y.view(b, ch, w * h) 71 | features_t = features.transpose(1, 2) 72 | gram = features.bmm(features_t) / (ch * h * w) 73 | return gram 74 | 75 | 76 | -------------------------------------------------------------------------------- /src/gan_control/make_attributes_df.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import os 5 | import argparse 6 | import numpy as np 7 | import pandas as pd 8 | import torch 9 | from tqdm import tqdm 10 | 11 | import sys 12 | from pathlib import Path 13 | _PWD = Path(__file__).absolute().parent 14 | 15 | sys.path.append(str(_PWD.parent)) 16 | sys.path.append(os.path.join(str(_PWD.parent.parent), 'face-alignment')) 17 | 18 | from face_alignment import FaceAlignment, LandmarksType 19 | 20 | from gan_control.trainers.generator_trainer import GeneratorTrainer 21 | from gan_control.evaluation.inference_class import Inference 22 | from gan_control.evaluation.age import calc_age_from_tensor_images 23 | from gan_control.evaluation.orientation import calc_orientation_from_tensor_images 24 | from gan_control.evaluation.expression import calc_expression_from_tensor_images, get_class 25 | from gan_control.evaluation.hair import calc_hair_color_from_images 26 | from gan_control.evaluation.face_alignment_utils.face_alignment_utils import align_tensor_images, load_lm3d 27 | 28 | 29 | @torch.no_grad() 30 | def make_attributes_df(model, trainer, attributes_df_save_path, batch_size=40, number_of_samples=10000, align_3d=True): 31 | attributes_df = pd.DataFrame(columns=['latents', 'latents_w', 'emb', 'age', 'orientation', 'expression_q', 'hair', 'gamma3d', 'expression3d', 'orientation3d']) 32 | lm3D = load_lm3d() 33 | fa = FaceAlignment(LandmarksType._3D, flip_input=False) 34 | if align_3d: 35 | trainer.recon_3d_loss_class.skeleton_model.module.config['center_crop'] = None 36 | trainer.id_embedding_class.skeleton_model.module.config['center_crop'] = None 37 | for batch_num in tqdm(range(number_of_samples // batch_size)): 38 | out, latent, latent_w = model.gen_batch(batch_size=batch_size, normalize=False) 39 | age = calc_age_from_tensor_images(trainer.age_class, out) 40 | orientation = calc_orientation_from_tensor_images(trainer.pose_orientation_class, out) 41 | expression_q = calc_expression_from_tensor_images(trainer.pose_expression_class, out) 42 | hair_color = calc_hair_color_from_images(trainer.hair_loss_class, out) 43 | if align_3d: 44 | out = align_tensor_images(out, fa=fa, lm3D=lm3D) 45 | recon_3d_features = trainer.recon_3d_loss_class.calc_features(out) 46 | else: 47 | recon_3d_features = trainer.recon_3d_loss_class.calc_features(out) 48 | id_futures, ex_futures, tex_futures, angles_futures, gamma_futures, xy_futures, z_futures = trainer.recon_3d_loss_class.skeleton_model.module.extract_futures_from_vec(recon_3d_features) 49 | gamma3d = gamma_futures 50 | expression3d = ex_futures 51 | orientation3d = angles_futures 52 | arcface_emb = trainer.id_embedding_class.calc_features(out)[-1] 53 | 54 | for latent_i, latent_w_i, age_i, yaw, pitch, roll, expression_q_i, hair_i, \ 55 | gamma3d_i, expression3d_i, orientation3d_i, arcface_emb_i in zip( 56 | latent.cpu().split(1), 57 | latent_w.cpu().split(1), 58 | age.cpu().split(1), 59 | orientation[0].cpu().split(1), 60 | orientation[1].cpu().split(1), 61 | orientation[2].cpu().split(1), 62 | expression_q, 63 | hair_color.cpu().split(1), 64 | gamma3d[0].cpu().split(1), 65 | expression3d[0].cpu().split(1), 66 | orientation3d[0].cpu().split(1), 67 | arcface_emb.cpu().split(1) 68 | ): 69 | df_entry = { 70 | 'latents': latent_i[0].numpy(), 71 | 'latents_w': latent_w_i[0][0].numpy(), 72 | 'age': age_i[0].item(), 73 | 'orientation': np.array([yaw[0].item(), pitch[0].item(), roll[0].item()]), 74 | 'expression_q': expression_q_i, 75 | 'hair': hair_i[0].numpy(), 76 | 'gamma3d': gamma3d_i[0].numpy(), 77 | 'expression3d': expression3d_i[0].numpy(), 78 | 'orientation3d': orientation3d_i[0].numpy(), 79 | 'arcface_emb': arcface_emb_i[0].numpy() 80 | } 81 | attributes_df = attributes_df.append(df_entry, ignore_index=True) 82 | 83 | if len(attributes_df.latents) % 50000 == 0: 84 | os.makedirs(os.path.split(attributes_df_save_path)[0], exist_ok=True) 85 | attributes_df.to_pickle(attributes_df_save_path) 86 | 87 | os.makedirs(os.path.split(attributes_df_save_path)[0], exist_ok=True) 88 | attributes_df.to_pickle(attributes_df_save_path) 89 | return attributes_df 90 | 91 | 92 | if __name__ == '__main__': 93 | parser = argparse.ArgumentParser() 94 | parser.add_argument('--model_dir', type=str, default='path to gan model dir') 95 | parser.add_argument('--trainer_model', type=str, default='same_as_model_dir') 96 | parser.add_argument('--batch_size', type=int, default=40) 97 | parser.add_argument('--number_of_samples', type=int, default=100000) 98 | parser.add_argument('--save_path', type=str, default='path to save dir') 99 | args = parser.parse_args() 100 | 101 | model = Inference(args.model_dir) 102 | if args.trainer_model == 'same_as_model_dir': 103 | args.trainer_model = args.model_dir 104 | config_path = os.path.join(args.trainer_model, 'args.json') 105 | trainer = GeneratorTrainer(config_path, init_dirs=False) 106 | make_attributes_df(model, trainer, args.save_path, batch_size=args.batch_size, number_of_samples=args.number_of_samples) -------------------------------------------------------------------------------- /src/gan_control/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/src/gan_control/models/__init__.py -------------------------------------------------------------------------------- /src/gan_control/models/controller_model.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import torch 5 | from torch import nn 6 | 7 | from gan_control.models.gan_model import PixelNorm, EqualLinear 8 | from gan_control.utils.logging_utils import get_logger 9 | 10 | _log = get_logger(__name__) 11 | 12 | 13 | class FcStack(nn.Module): 14 | def __init__(self, lr_mlp, n_mlp, in_dim, mid_dim, out_dim): 15 | super(FcStack, self).__init__() 16 | self.lr_mlp = lr_mlp 17 | self.n_mlp = n_mlp 18 | self.in_dim = in_dim 19 | self.mid_dim = mid_dim 20 | self.out_dim = out_dim 21 | self.fc_stack = self.create_input_middle_output_fc_stack(lr_mlp, n_mlp, in_dim, mid_dim, out_dim) 22 | 23 | @staticmethod 24 | def create_input_middle_output_fc_stack(lr_mlp, n_mlp, in_dim, mid_dim, out_dim): 25 | mid_dim = mid_dim if mid_dim is not None else mid_dim 26 | layers = [] 27 | for i in range(n_mlp): 28 | s_dim0 = mid_dim 29 | s_dim1 = mid_dim 30 | if i == 0: 31 | s_dim0 = in_dim 32 | elif i == n_mlp - 1: 33 | s_dim1 = out_dim 34 | elif i < n_mlp - 1: 35 | pass 36 | else: 37 | raise ValueError('debug') 38 | layers.append( 39 | EqualLinear( 40 | s_dim0, s_dim1, lr_mul=lr_mlp, activation='fused_lrelu' 41 | ) 42 | ) 43 | return nn.Sequential(*layers) 44 | 45 | def print(self): 46 | text = 'FcStack:\n' 47 | text += 'input dim: %d, middle dim:%d, output dim: %d\n' % (self.in_dim, self.mid_dim, self.out_dim) 48 | text += 'num of layers: %d\n' % (self.n_mlp) 49 | text += 'lr_mlp: %d' % (self.lr_mlp) 50 | _log.info(text) 51 | 52 | def forward(self, x): 53 | return self.fc_stack(x) -------------------------------------------------------------------------------- /src/gan_control/models/pytorch_upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load 7 | 8 | 9 | def upfirdn2d_native(input, kernel, up, down, pad): 10 | up_x, up_y = up[0], up[1] 11 | down_x, down_y = down[0], down[1] 12 | pad_x0, pad_x1, pad_y0, pad_y1 = pad[0], pad[1], pad[2], pad[3] 13 | _, channel, in_h, in_w = input.shape 14 | input = input.reshape(-1, in_h, in_w, 1) 15 | 16 | _, in_h, in_w, minor = input.shape 17 | kernel_h, kernel_w = kernel.shape 18 | 19 | out = input.view(-1, in_h, 1, in_w, 1, minor) 20 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 21 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 22 | 23 | out = F.pad( 24 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 25 | ) 26 | out = out[ 27 | :, 28 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 29 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 30 | :, 31 | ] 32 | 33 | out = out.permute(0, 3, 1, 2) 34 | out = out.reshape( 35 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 36 | ) 37 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 38 | out = F.conv2d(out, w) 39 | out = out.reshape( 40 | -1, 41 | minor, 42 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 43 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 44 | ) 45 | out = out.permute(0, 2, 3, 1) 46 | out = out[:, ::down_y, ::down_x, :] 47 | 48 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 49 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 50 | 51 | return out.view(-1, channel, out_h, out_w) -------------------------------------------------------------------------------- /src/gan_control/pretrained_models/README.md: -------------------------------------------------------------------------------- 1 | Put pretrained models here. 2 | 3 | -------------------------------------------------------------------------------- /src/gan_control/projection/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/src/gan_control/projection/__init__.py -------------------------------------------------------------------------------- /src/gan_control/projection/lpips/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.autograd import Variable 4 | from pdb import set_trace as st 5 | from IPython import embed 6 | 7 | class BaseModel(): 8 | def __init__(self): 9 | pass; 10 | 11 | def name(self): 12 | return 'BaseModel' 13 | 14 | def initialize(self, use_gpu=True, gpu_ids=[0]): 15 | self.use_gpu = use_gpu 16 | self.gpu_ids = gpu_ids 17 | 18 | def forward(self): 19 | pass 20 | 21 | def get_image_paths(self): 22 | pass 23 | 24 | def optimize_parameters(self): 25 | pass 26 | 27 | def get_current_visuals(self): 28 | return self.input 29 | 30 | def get_current_errors(self): 31 | return {} 32 | 33 | def save(self, label): 34 | pass 35 | 36 | # helper saving function that can be used by subclasses 37 | def save_network(self, network, path, network_label, epoch_label): 38 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 39 | save_path = os.path.join(path, save_filename) 40 | torch.save(network.state_dict(), save_path) 41 | 42 | # helper loading function that can be used by subclasses 43 | def load_network(self, network, network_label, epoch_label): 44 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 45 | save_path = os.path.join(self.save_dir, save_filename) 46 | print('Loading network from %s'%save_path) 47 | network.load_state_dict(torch.load(save_path)) 48 | 49 | def update_learning_rate(): 50 | pass 51 | 52 | def get_image_paths(self): 53 | return self.image_paths 54 | 55 | def save_done(self, flag=False): 56 | np.save(os.path.join(self.save_dir, 'done_flag'),flag) 57 | np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i') 58 | 59 | -------------------------------------------------------------------------------- /src/gan_control/projection/lpips/weights/v0.0/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/src/gan_control/projection/lpips/weights/v0.0/alex.pth -------------------------------------------------------------------------------- /src/gan_control/projection/lpips/weights/v0.0/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/src/gan_control/projection/lpips/weights/v0.0/squeeze.pth -------------------------------------------------------------------------------- /src/gan_control/projection/lpips/weights/v0.0/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/src/gan_control/projection/lpips/weights/v0.0/vgg.pth -------------------------------------------------------------------------------- /src/gan_control/projection/lpips/weights/v0.1/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/src/gan_control/projection/lpips/weights/v0.1/alex.pth -------------------------------------------------------------------------------- /src/gan_control/projection/lpips/weights/v0.1/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/src/gan_control/projection/lpips/weights/v0.1/squeeze.pth -------------------------------------------------------------------------------- /src/gan_control/projection/lpips/weights/v0.1/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/src/gan_control/projection/lpips/weights/v0.1/vgg.pth -------------------------------------------------------------------------------- /src/gan_control/projection/projection.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import argparse 5 | import math 6 | import os 7 | import pathlib 8 | import sys 9 | import time 10 | from typing import List 11 | 12 | import matplotlib.pyplot as plt 13 | import streamlit as st 14 | import torch 15 | from torch import nn 16 | from torch import optim 17 | from torch.nn import functional as F 18 | from PIL import Image 19 | from tqdm import tqdm 20 | from torchvision import transforms 21 | from torchvision import utils 22 | import numpy as np 23 | import subprocess as sp, shlex 24 | from sklearn.decomposition import PCA 25 | 26 | # from orville_conditional_gan.image_generation.image_generator import save_image 27 | 28 | from gan_control.inference.controller import Controller 29 | from gan_control.projection.lpips.lpips import PerceptualLoss 30 | from gan_control.utils.logging_utils import get_logger 31 | 32 | _log = get_logger(__name__) 33 | 34 | 35 | def tensor_to_numpy_img(tensor: torch.Tensor) -> np.array: 36 | tensor_tmp = tensor.mul(0.5).add(0.5).clamp(min=0., max=1.).cpu().detach() 37 | if tensor.ndimension == 3: 38 | tensor_tmp = tensor_tmp.premute(0,1,2) 39 | elif tensor.ndimension == 4: 40 | tensor_tmp = tensor_tmp.premute(0, 2, 3, 1) 41 | return tensor_tmp.numpy() 42 | 43 | 44 | @torch.no_grad() 45 | def get_pca_groups(controller, latent_mean, n_mean_latent, device): 46 | if isinstance(controller.model, nn.DataParallel): 47 | model = controller.model.module 48 | else: 49 | model = controller.model 50 | with torch.no_grad(): 51 | noise_sample = torch.randn(n_mean_latent, 512, device=device) 52 | latent_out = model.style(noise_sample) 53 | latent_out = latent_out.detach().cpu().numpy() 54 | latent_out = latent_out - latent_mean.cpu().numpy() 55 | variance_percent = 0.5 56 | pca_weights = {} 57 | 58 | for group in controller.fc_controls.keys(): 59 | if group == 'expression_q': 60 | continue 61 | pca = PCA() 62 | group_latent = controller.get_group_w_latent(latent_out, group) 63 | pca.fit(group_latent) 64 | idx_variance_percent = np.argmax(np.cumsum(pca.explained_variance_) / np.sum(pca.explained_variance_) > variance_percent) 65 | _log.info('%s PCA components: %s' % (group, str(idx_variance_percent))) 66 | pca_weight = pca.components_[:(idx_variance_percent+1), :] 67 | pca_weight = torch.from_numpy(pca_weight).cuda() 68 | pca_weights[group] = pca_weight 69 | return pca_weights 70 | 71 | 72 | def plot_figures(lrs, mse_losses, n_losses, p_losses, axes=None): 73 | if axes is None: 74 | fig, axes = plt.subplots(2, 2) 75 | fig.tight_layout() 76 | axes[0, 0].plot(np.arange(len(p_losses)), p_losses) 77 | axes[0, 0].set_title('Perceptual Loss') 78 | axes[0, 0].set_yscale('log') 79 | axes[0, 1].plot(np.arange(len(n_losses)), n_losses) 80 | axes[0, 1].set_title('Noise Loss') 81 | axes[0, 1].set_yscale('log') 82 | axes[1, 0].plot(np.arange(len(mse_losses)), mse_losses) 83 | axes[1, 0].set_title('MSE Loss') 84 | axes[1, 0].set_yscale('log') 85 | axes[1, 1].plot(np.arange(len(lrs)), lrs) 86 | axes[1, 1].set_title('Learning Rate') 87 | 88 | 89 | def load_source_images(images: List[str], res=256): 90 | transform = transforms.Compose( 91 | [ 92 | transforms.Resize(res), 93 | transforms.CenterCrop(res), 94 | transforms.ToTensor(), 95 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), 96 | ] 97 | ) 98 | source_tensors = [] 99 | for image_name in images: 100 | loaded_image = Image.open(image_name).convert('RGB') 101 | source_tensors.append(transform(loaded_image).unsqueeze(dim=0)) 102 | source_tensors = torch.cat(source_tensors, dim=0) 103 | return source_tensors 104 | 105 | 106 | def merge_group_latents(controller, group, latent_in): 107 | group_latent = controller.get_group_w_latent(latent_in.data, group) 108 | controller.insert_group_w_latent(latent_in.data, group_latent.data.mean(dim=1).unsqueeze(1), group) 109 | 110 | 111 | def get_avg_latent(model, n_mean_latent, device): 112 | if isinstance(model, nn.DataParallel): 113 | model = model.module 114 | 115 | with torch.no_grad(): 116 | noise_sample = torch.randn(n_mean_latent, 512, device=device) 117 | latent_out = model.style(noise_sample) 118 | 119 | latent_mean = latent_out.mean(0) 120 | latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5 121 | latent_mean = latent_mean.cpu().numpy() 122 | latent_std = latent_std.cpu().numpy() 123 | return latent_mean, latent_std 124 | 125 | 126 | def noise_regularize(noises): 127 | loss = 0 128 | 129 | for noise in noises: 130 | size = noise.shape[2] 131 | 132 | while True: 133 | loss = ( 134 | loss 135 | + (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2) 136 | + (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2) 137 | ) 138 | 139 | if size <= 8: 140 | break 141 | 142 | noise = noise.reshape([-1, 1, size // 2, 2, size // 2, 2]) 143 | noise = noise.mean([3, 5]) 144 | size //= 2 145 | 146 | return loss 147 | 148 | 149 | def noise_normalize_(noises): 150 | for noise in noises: 151 | mean = noise.mean() 152 | std = noise.std() 153 | 154 | noise.data.add_(-mean).div_(std) 155 | 156 | 157 | def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05): 158 | lr_ramp = min(1, (1 - t) / rampdown) 159 | lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi) 160 | lr_ramp = lr_ramp * min(1, t / rampup) 161 | 162 | return initial_lr * lr_ramp 163 | 164 | 165 | def latent_noise(latent, strength): 166 | noise = torch.randn_like(latent) * strength 167 | 168 | return latent + noise 169 | 170 | 171 | def make_image(tensor): 172 | return ( 173 | tensor.detach() 174 | .clamp_(min=-1, max=1) 175 | .add(1) 176 | .div_(2) 177 | .mul(255) 178 | .type(torch.uint8) 179 | .permute(0, 2, 3, 1) 180 | .to("cpu") 181 | .numpy() 182 | ) 183 | -------------------------------------------------------------------------------- /src/gan_control/train_controller.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import sys 5 | import os 6 | from pathlib import Path 7 | _PWD = Path(__file__).absolute().parent 8 | sys.path.append(str(_PWD.parent)) 9 | sys.path.append(os.path.join(str(_PWD.parent.parent), 'face-alignment')) 10 | 11 | import argparse 12 | from gan_control.trainers.controller_trainer import ControllerTrainer 13 | 14 | if __name__ == '__main__': 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--config_path', type=str, required=True) 17 | 18 | config_path = parser.parse_args().config_path 19 | trainer = ControllerTrainer(config_path) 20 | trainer.train() 21 | 22 | # python -m train_controller --config_path configs512/id_orientation_expression_512.json -------------------------------------------------------------------------------- /src/gan_control/train_generator.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import sys 5 | from pathlib import Path 6 | _PWD = Path(__file__).absolute().parent 7 | sys.path.append(str(_PWD.parent)) 8 | 9 | import argparse 10 | from gan_control.trainers.generator_trainer import GeneratorTrainer 11 | 12 | if __name__ == '__main__': 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--config_path', type=str, required=True) 15 | 16 | config_path = parser.parse_args().config_path 17 | trainer = GeneratorTrainer(config_path) 18 | trainer.dry_run() 19 | trainer.train() 20 | 21 | -------------------------------------------------------------------------------- /src/gan_control/trainers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/src/gan_control/trainers/__init__.py -------------------------------------------------------------------------------- /src/gan_control/trainers/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import torch 5 | import random 6 | 7 | 8 | def accumulate(model1, model2, decay=0.999): 9 | par1 = dict(model1.named_parameters()) 10 | par2 = dict(model2.named_parameters()) 11 | for k in par1.keys(): 12 | par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay) 13 | 14 | def requires_grad(model, flag=True): 15 | for p in model.parameters(): 16 | p.requires_grad = flag 17 | 18 | 19 | def mixing_noise(batch, latent_dim, prob, device): 20 | if prob > 0 and random.random() < prob: 21 | return make_noise(batch, latent_dim, 2, device) 22 | else: 23 | return [make_noise(batch, latent_dim, 1, device)] 24 | 25 | 26 | def make_noise(batch, latent_dim, n_noise, device): 27 | if n_noise == 1: 28 | return torch.randn(batch, latent_dim, device=device) 29 | noises = torch.randn(n_noise, batch, latent_dim, device=device).unbind(0) 30 | return noises 31 | 32 | 33 | def make_mini_batch_from_noise(noise, batch, mini_batch): 34 | noise_mini_batch = [] 35 | chunks = [] 36 | for i in range(len(noise)): 37 | chunks.append(noise[i].chunk(batch // mini_batch)) 38 | for i in range(len(chunks[0])): 39 | noise_mini_batch.append([]) 40 | for j in range(len(noise)): 41 | noise_mini_batch[i].append(chunks[j][i]) 42 | return noise_mini_batch 43 | 44 | 45 | def set_grad_none(model, targets): 46 | for n, p in model.named_parameters(): 47 | if n in targets: 48 | p.grad = None 49 | -------------------------------------------------------------------------------- /src/gan_control/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gan-control/bf56937c497146ce594567b570c07c0eaec259bf/src/gan_control/utils/__init__.py -------------------------------------------------------------------------------- /src/gan_control/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import os 5 | from datetime import datetime 6 | import json 7 | 8 | 9 | class DefaultObj(object): 10 | def __init__(self, dict): 11 | self.__dict__ = dict 12 | 13 | 14 | def read_json(path, return_obj=False): 15 | with open(path) as json_file: 16 | data = json.load(json_file) 17 | if return_obj: 18 | data = DefaultObj(data) 19 | return data 20 | 21 | 22 | def write_json(data_dict, path): 23 | with open(path, 'w') as outfile: 24 | json.dump(data_dict, outfile) 25 | 26 | 27 | def setup_logging_from_args(args): 28 | """ 29 | Calls setup_logging, exports args and creates a ResultsLog class. 30 | Can resume training/logging if args.resume is set 31 | """ 32 | def set_args_default(field_name, value): 33 | if hasattr(args, field_name): 34 | return eval('args.' + field_name) 35 | else: 36 | return value 37 | 38 | # Set default args in case they don't exist in args 39 | # resume = set_args_default('resume', False) 40 | args.save_name = f"{args.save_name}_{datetime.now().strftime('%Y%m%d-%H%M%S')}" 41 | results_dir = set_args_default('results_dir', './results') 42 | 43 | save_path = os.path.join(results_dir, args.save_name) 44 | os.makedirs(save_path, exist_ok=True) 45 | # log_file = os.path.join(save_path, 'log.txt') 46 | 47 | export_args(args, save_path) 48 | return save_path 49 | 50 | 51 | def export_args(args, save_path): 52 | """ 53 | args: argparse.Namespace 54 | arguments to save 55 | save_path: string 56 | path to directory to save at 57 | """ 58 | os.makedirs(save_path, exist_ok=True) 59 | json_file_name = os.path.join(save_path, 'args.json') 60 | with open(json_file_name, 'w') as fp: 61 | json.dump(args.__dict__, fp, sort_keys=True, indent=4) 62 | 63 | -------------------------------------------------------------------------------- /src/gan_control/utils/hopenet_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import numpy as np 5 | import torch 6 | import os 7 | import scipy.io as sio 8 | import cv2 9 | import math 10 | from math import cos, sin 11 | from PIL import Image, ImageDraw 12 | 13 | from gan_control.utils.logging_utils import get_logger 14 | 15 | _log = get_logger(__name__) 16 | 17 | 18 | def softmax_temperature(tensor, temperature): 19 | result = torch.exp(tensor / temperature) 20 | result = torch.div(result, torch.sum(result, 1).unsqueeze(1).expand_as(result)) 21 | return result 22 | 23 | 24 | def get_pose_params_from_mat(mat_path): 25 | # This functions gets the pose parameters from the .mat 26 | # Annotations that come with the Pose_300W_LP dataset. 27 | mat = sio.loadmat(mat_path) 28 | # [pitch yaw roll tdx tdy tdz scale_factor] 29 | pre_pose_params = mat['Pose_Para'][0] 30 | # Get [pitch, yaw, roll, tdx, tdy] 31 | pose_params = pre_pose_params[:5] 32 | return pose_params 33 | 34 | 35 | def get_ypr_from_mat(mat_path): 36 | # Get yaw, pitch, roll from .mat annotation. 37 | # They are in radians 38 | mat = sio.loadmat(mat_path) 39 | # [pitch yaw roll tdx tdy tdz scale_factor] 40 | pre_pose_params = mat['Pose_Para'][0] 41 | # Get [pitch, yaw, roll] 42 | pose_params = pre_pose_params[:3] 43 | return pose_params 44 | 45 | 46 | def get_pt2d_from_mat(mat_path): 47 | # Get 2D landmarks 48 | mat = sio.loadmat(mat_path) 49 | pt2d = mat['pt2d'] 50 | return pt2d 51 | 52 | 53 | def mse_loss(input, target): 54 | return torch.sum(torch.abs(input.data - target.data) ** 2) 55 | 56 | 57 | def plot_pose_cube(img, yaw, pitch, roll, tdx=None, tdy=None, size=150.): 58 | # Input is a cv2 image 59 | # pose_params: (pitch, yaw, roll, tdx, tdy) 60 | # Where (tdx, tdy) is the translation of the face. 61 | # For pose we have [pitch yaw roll tdx tdy tdz scale_factor] 62 | 63 | p = pitch * np.pi / 180 64 | y = -(yaw * np.pi / 180) 65 | r = roll * np.pi / 180 66 | if tdx != None and tdy != None: 67 | face_x = tdx - 0.50 * size 68 | face_y = tdy - 0.50 * size 69 | else: 70 | height, width = img.shape[:2] 71 | face_x = width / 2 - 0.5 * size 72 | face_y = height / 2 - 0.5 * size 73 | 74 | x1 = size * (cos(y) * cos(r)) + face_x 75 | y1 = size * (cos(p) * sin(r) + cos(r) * sin(p) * sin(y)) + face_y 76 | x2 = size * (-cos(y) * sin(r)) + face_x 77 | y2 = size * (cos(p) * cos(r) - sin(p) * sin(y) * sin(r)) + face_y 78 | x3 = size * (sin(y)) + face_x 79 | y3 = size * (-cos(y) * sin(p)) + face_y 80 | 81 | # Draw base in red 82 | cv2.line(img, (int(face_x), int(face_y)), (int(x1),int(y1)),(0,0,255),3) 83 | cv2.line(img, (int(face_x), int(face_y)), (int(x2),int(y2)),(0,0,255),3) 84 | cv2.line(img, (int(x2), int(y2)), (int(x2+x1-face_x),int(y2+y1-face_y)),(0,0,255),3) 85 | cv2.line(img, (int(x1), int(y1)), (int(x1+x2-face_x),int(y1+y2-face_y)),(0,0,255),3) 86 | # Draw pillars in blue 87 | cv2.line(img, (int(face_x), int(face_y)), (int(x3),int(y3)),(255,0,0),2) 88 | cv2.line(img, (int(x1), int(y1)), (int(x1+x3-face_x),int(y1+y3-face_y)),(255,0,0),2) 89 | cv2.line(img, (int(x2), int(y2)), (int(x2+x3-face_x),int(y2+y3-face_y)),(255,0,0),2) 90 | cv2.line(img, (int(x2+x1-face_x),int(y2+y1-face_y)), (int(x3+x1+x2-2*face_x),int(y3+y2+y1-2*face_y)),(255,0,0),2) 91 | # Draw top in green 92 | cv2.line(img, (int(x3+x1-face_x),int(y3+y1-face_y)), (int(x3+x1+x2-2*face_x),int(y3+y2+y1-2*face_y)),(0,255,0),2) 93 | cv2.line(img, (int(x2+x3-face_x),int(y2+y3-face_y)), (int(x3+x1+x2-2*face_x),int(y3+y2+y1-2*face_y)),(0,255,0),2) 94 | cv2.line(img, (int(x3), int(y3)), (int(x3+x1-face_x),int(y3+y1-face_y)),(0,255,0),2) 95 | cv2.line(img, (int(x3), int(y3)), (int(x3+x2-face_x),int(y3+y2-face_y)),(0,255,0),2) 96 | 97 | return img 98 | 99 | 100 | def draw_axis(img, yaw, pitch, roll, tdx=None, tdy=None, size=200, radians=False): 101 | 102 | if not radians: 103 | pitch = pitch * np.pi / 180 104 | yaw = -(yaw * np.pi / 180) 105 | roll = roll * np.pi / 180 106 | 107 | if tdx != None and tdy != None: 108 | tdx = tdx 109 | tdy = tdy 110 | else: 111 | width, height = img.size 112 | tdx = width / 2 113 | tdy = height / 2 114 | 115 | # X-Axis pointing to right. drawn in red 116 | x1 = size * (cos(yaw) * cos(roll)) + tdx 117 | y1 = size * (cos(pitch) * sin(roll) + cos(roll) * sin(pitch) * sin(yaw)) + tdy 118 | 119 | # Y-Axis | drawn in green 120 | # v 121 | x2 = size * (-cos(yaw) * sin(roll)) + tdx 122 | y2 = size * (cos(pitch) * cos(roll) - sin(pitch) * sin(yaw) * sin(roll)) + tdy 123 | 124 | # Z-Axis (out of the screen) drawn in blue 125 | x3 = size * (sin(yaw)) + tdx 126 | y3 = size * (-cos(yaw) * sin(pitch)) + tdy 127 | 128 | tdx = 0 if tdx is None else tdx 129 | tdy = 0 if tdy is None else tdy 130 | x1 = 0 if x1 is None else x1 131 | x2 = 0 if x2 is None else x2 132 | x3 = 0 if x3 is None else x3 133 | y1 = 0 if y1 is None else y1 134 | y2 = 0 if y2 is None else y2 135 | y3 = 0 if y3 is None else y3 136 | try: 137 | draw = ImageDraw.Draw(img) 138 | draw.line(((int(tdx), int(tdy)), (int(x1),int(y1))), fill=(0, 0, 255), width=5) 139 | draw.line(((int(tdx), int(tdy)), (int(x2),int(y2))), fill=(0, 255, 0), width=5) 140 | draw.line(((int(tdx), int(tdy)), (int(x3),int(y3))), fill=(255, 0, 0), width=5) 141 | except: 142 | _log.warning('there was a problem drawing a line') 143 | 144 | return img 145 | -------------------------------------------------------------------------------- /src/gan_control/utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | def get_logger(name): 5 | import logging 6 | log = logging.getLogger(name) 7 | log.setLevel(logging.INFO) 8 | ch = logging.StreamHandler() 9 | ch.setLevel(logging.INFO) 10 | formatter = logging.Formatter('%(levelname)s:%(name)s: %(message)s') 11 | ch.setFormatter(formatter) 12 | log.addHandler(ch) 13 | return log -------------------------------------------------------------------------------- /src/gan_control/utils/pandas_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import pandas as pd 5 | 6 | 7 | def get_kmin(main_df, column_name='distance', k=5): 8 | main_df = main_df.copy(deep=True) 9 | kmin_df = pd.DataFrame(columns=list(main_df.columns)) 10 | for i in range(k): 11 | min_index = main_df[column_name].idxmin() 12 | kmin_df = kmin_df.append(main_df.iloc[min_index].copy(deep=True), ignore_index=True) 13 | main_df.drop(min_index, inplace=True) 14 | main_df.reset_index(inplace=True, drop=True) 15 | return kmin_df -------------------------------------------------------------------------------- /src/gan_control/utils/pil_images_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import torch 5 | import numpy as np 6 | import subprocess as sp, shlex 7 | from torchvision import transforms, utils 8 | from PIL import Image, ImageDraw, ImageFont 9 | 10 | 11 | def create_gif(save_dir, delay=50, name='sample'): 12 | sp.run(shlex.split(f'convert -delay {delay} -loop 0 -resize 3079x1027 *.jpg /mnt/md4/orville/Alon/gifs_8_6/{name}.gif'), cwd=save_dir) 13 | # sp.run('cd %s;cp sample.gif /mnt/md4/orville/Alon/gifs/%s.gif' % (save_dir, name)) 14 | # sp.run(shlex.split(f'convert -delay {delay} -loop 0 *.jpg sample.gif'), cwd=save_dir) 15 | 16 | 17 | def get_concat_h(im1, im2): 18 | dst = Image.new('RGB', (im1.width + im2.width, im1.height)) 19 | dst.paste(im1, (0, 0)) 20 | dst.paste(im2, (im1.width, 0)) 21 | return dst 22 | 23 | 24 | def get_concat_v(im1, im2): 25 | dst = Image.new('RGB', (im1.width, im1.height + im2.height)) 26 | dst.paste(im1, (0, 0)) 27 | dst.paste(im2, (0, im1.height)) 28 | return dst 29 | 30 | 31 | def fig2data(fig): 32 | """ 33 | @brief Convert a Matplotlib figure to a 4D numpy array with RGBA channels and return it 34 | @param fig a matplotlib figure 35 | @return a numpy 3D array of RGBA values 36 | """ 37 | # draw the renderer 38 | fig.canvas.draw() 39 | 40 | # Get the RGBA buffer from the figure 41 | w, h = fig.canvas.get_width_height() 42 | buf = np.fromstring(fig.canvas.tostring_argb(), dtype=np.uint8) 43 | buf.shape = (w, h, 4) 44 | 45 | # canvas.tostring_argb give pixmap in ARGB mode. Roll the ALPHA channel to have it in RGBA mode 46 | buf = np.roll(buf, 3, axis=2) 47 | return buf 48 | 49 | 50 | def fig2img ( fig ): 51 | """ 52 | @brief Convert a Matplotlib figure to a PIL Image in RGBA format and return it 53 | @param fig a matplotlib figure 54 | @return a Python Imaging Library ( PIL ) image 55 | """ 56 | # put the figure pixmap into a numpy array 57 | buf = fig2data ( fig ) 58 | w, h, d = buf.shape 59 | return Image.frombytes( "RGBA", ( w ,h ), buf.tostring( ) ) 60 | 61 | 62 | def create_image_grid_from_image_list(images, nrow=6): 63 | to_tensor = transforms.ToTensor() 64 | tensors = [to_tensor(images[i]).unsqueeze(0) for i in range(len(images))] 65 | tensors = torch.cat(tensors, dim=0) 66 | tensor_grid = utils.make_grid(tensors, nrow=nrow) 67 | return transforms.ToPILImage()(tensor_grid) 68 | 69 | 70 | def write_text_to_image(image, text, size=36, place=(10, 10)): 71 | d = ImageDraw.Draw(image) 72 | d.text(place, text, fill=(255, 255, 0), font=ImageFont.truetype('Pillow/Tests/fonts/FreeMono.ttf', size)) 73 | return image 74 | -------------------------------------------------------------------------------- /src/gan_control/utils/spherical_harmonics_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def sh_eval_basis_1( 5 | x, y, z, 6 | p_0_0=0.282094791773878140, 7 | p_1_0=0.488602511902919920, 8 | p_1_1=-0.488602511902919920 9 | ): 10 | b = np.zeros(27) 11 | b[0::9] = p_0_0 # l=0,m=0 12 | b[2::9] = p_1_0 * z # l=1,m=0 13 | b[1::9] = p_1_1 * y # l=1,m=-1 14 | b[3::9] = p_1_1 * x # l=1,m=+1 15 | return b 16 | 17 | 18 | def sh_eval_basis_2( 19 | x, y, z, 20 | p_0_0=0.282094791773878140, 21 | p_1_0=0.488602511902919920, 22 | pp_2_0=0.946174695757560080, 23 | mp_2_0=-0.315391565252520050, 24 | p_1_1=-0.488602511902919920, 25 | p_2_1=-1.092548430592079200, 26 | p_2_2=0.546274215296039590 27 | ): 28 | b = np.zeros(27) 29 | b[0::9] = p_0_0 # l=0,m=0 30 | 31 | b[2::9] = p_1_0 # l=1,m=0 32 | b[6::9] = (pp_2_0 * z * z) + mp_2_0 33 | 34 | b[1::9] = p_1_1 * y # l=1,m=-1 35 | b[3::9] = p_1_1 * x # l=1,m=+1 36 | 37 | b[5] = p_2_1 * z * y # l=2,m=-1 38 | b[7] = p_2_1 * z * x # l=2,m=+1 39 | 40 | b[4] = p_2_2 * (x * y + y * x) # l=2,m=-2 41 | b[8] = p_2_2 * (y * y + x * x) # l=2,m=+2 42 | return b -------------------------------------------------------------------------------- /src/gan_control/utils/tensor_transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | def center_crop_tensor(tensor, crop_size): 5 | b, c, h, w = tensor.shape 6 | up = (h - crop_size) // 2 7 | left = (w - crop_size) // 2 8 | tensor = tensor[:, :, up:up+crop_size, left:left+crop_size] 9 | return tensor 10 | --------------------------------------------------------------------------------