├── 3dswap_c_vision_experiments.ipynb
├── README.md
├── configs
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-36.pyc
│ ├── __init__.cpython-37.pyc
│ ├── __init__.cpython-38.pyc
│ ├── __init__.cpython-39.pyc
│ ├── data_configs.cpython-37.pyc
│ ├── data_configs.cpython-38.pyc
│ ├── data_configs.cpython-39.pyc
│ ├── global_config.cpython-36.pyc
│ ├── global_config.cpython-38.pyc
│ ├── global_config.cpython-39.pyc
│ ├── hyperparameters.cpython-36.pyc
│ ├── hyperparameters.cpython-38.pyc
│ ├── hyperparameters.cpython-39.pyc
│ ├── paths_config.cpython-36.pyc
│ ├── paths_config.cpython-37.pyc
│ ├── paths_config.cpython-38.pyc
│ ├── paths_config.cpython-39.pyc
│ ├── transforms_config.cpython-37.pyc
│ ├── transforms_config.cpython-38.pyc
│ └── transforms_config.cpython-39.pyc
├── data_configs.py
├── evaluation_config.py
├── global_config.py
├── hyperparameters.py
├── paths_config.py
└── transforms_config.py
├── datasets
├── CelebA-HD
│ ├── camera_pose
│ │ ├── 0.npy
│ │ ├── 1.npy
│ │ ├── 2.npy
│ │ ├── 3.npy
│ │ ├── 4.npy
│ │ ├── 5.npy
│ │ ├── 6.npy
│ │ ├── 7.npy
│ │ ├── 8.npy
│ │ └── 9.npy
│ └── final_crops
│ │ ├── 0.jpg
│ │ ├── 1.jpg
│ │ ├── 2.jpg
│ │ ├── 3.jpg
│ │ ├── 4.jpg
│ │ ├── 5.jpg
│ │ ├── 6.jpg
│ │ ├── 7.jpg
│ │ ├── 8.jpg
│ │ └── 9.jpg
├── __pycache__
│ ├── augmentations.cpython-38.pyc
│ └── images_dataset.cpython-38.pyc
├── augmentations.py
└── images_dataset.py
├── dnnlib
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-37.pyc
│ ├── __init__.cpython-38.pyc
│ ├── __init__.cpython-39.pyc
│ ├── util.cpython-37.pyc
│ ├── util.cpython-38.pyc
│ └── util.cpython-39.pyc
└── util.py
├── eg3d_c_vision_experiments.ipynb
├── images
└── teaser.png
├── models
├── __pycache__
│ ├── discriminator.cpython-38.pyc
│ ├── faceswap_coach.cpython-38.pyc
│ ├── id_loss.cpython-38.pyc
│ ├── inversion_coach.cpython-38.pyc
│ ├── networks.cpython-38.pyc
│ ├── psp.cpython-38.pyc
│ ├── psp.cpython-39.pyc
│ └── w_norm.cpython-38.pyc
├── discriminator.py
├── encoders
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-37.pyc
│ │ ├── __init__.cpython-38.pyc
│ │ ├── __init__.cpython-39.pyc
│ │ ├── encoder128.cpython-37.pyc
│ │ ├── helpers.cpython-37.pyc
│ │ ├── helpers.cpython-38.pyc
│ │ ├── helpers.cpython-39.pyc
│ │ ├── model_irse.cpython-37.pyc
│ │ ├── model_irse.cpython-38.pyc
│ │ ├── model_irse.cpython-39.pyc
│ │ ├── psp_encoders.cpython-37.pyc
│ │ ├── psp_encoders.cpython-38.pyc
│ │ └── psp_encoders.cpython-39.pyc
│ ├── encoder128.py
│ ├── helpers.py
│ ├── model_irse.py
│ └── psp_encoders.py
├── faceswap_coach.py
├── id_loss.py
├── inversion_coach.py
├── networks.py
├── psp.py
├── stylegan2
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-37.pyc
│ │ ├── __init__.cpython-38.pyc
│ │ ├── __init__.cpython-39.pyc
│ │ ├── model.cpython-37.pyc
│ │ ├── model.cpython-38.pyc
│ │ └── model.cpython-39.pyc
│ ├── model.py
│ └── op
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ ├── __init__.cpython-37.pyc
│ │ │ ├── __init__.cpython-38.pyc
│ │ │ ├── __init__.cpython-39.pyc
│ │ │ ├── fused_act.cpython-37.pyc
│ │ │ ├── fused_act.cpython-38.pyc
│ │ │ ├── fused_act.cpython-39.pyc
│ │ │ ├── upfirdn2d.cpython-37.pyc
│ │ │ ├── upfirdn2d.cpython-38.pyc
│ │ │ └── upfirdn2d.cpython-39.pyc
│ │ ├── fused_act.py
│ │ ├── fused_bias_act.cpp
│ │ ├── fused_bias_act_kernel.cu
│ │ ├── upfirdn2d.cpp
│ │ ├── upfirdn2d.py
│ │ └── upfirdn2d_kernel.cu
└── w_norm.py
├── options
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-37.pyc
│ ├── __init__.cpython-38.pyc
│ ├── __init__.cpython-39.pyc
│ ├── train_options.cpython-37.pyc
│ ├── train_options.cpython-38.pyc
│ └── train_options.cpython-39.pyc
├── test_options.py
└── train_options.py
├── requirements.txt
├── run_3dSwap.py
├── run_inversion.py
├── torch_utils
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-37.pyc
│ ├── __init__.cpython-38.pyc
│ ├── __init__.cpython-39.pyc
│ ├── custom_ops.cpython-38.pyc
│ ├── custom_ops.cpython-39.pyc
│ ├── misc.cpython-37.pyc
│ ├── misc.cpython-38.pyc
│ ├── misc.cpython-39.pyc
│ ├── persistence.cpython-38.pyc
│ ├── persistence.cpython-39.pyc
│ └── training_stats.cpython-39.pyc
├── custom_ops.py
├── misc.py
├── ops
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-38.pyc
│ │ ├── __init__.cpython-39.pyc
│ │ ├── bias_act.cpython-38.pyc
│ │ ├── bias_act.cpython-39.pyc
│ │ ├── conv2d_gradfix.cpython-38.pyc
│ │ ├── conv2d_gradfix.cpython-39.pyc
│ │ ├── conv2d_resample.cpython-38.pyc
│ │ ├── conv2d_resample.cpython-39.pyc
│ │ ├── filtered_lrelu.cpython-38.pyc
│ │ ├── filtered_lrelu.cpython-39.pyc
│ │ ├── fma.cpython-38.pyc
│ │ ├── fma.cpython-39.pyc
│ │ ├── upfirdn2d.cpython-38.pyc
│ │ └── upfirdn2d.cpython-39.pyc
│ ├── bias_act.cpp
│ ├── bias_act.cu
│ ├── bias_act.h
│ ├── bias_act.py
│ ├── conv2d_gradfix.py
│ ├── conv2d_resample.py
│ ├── filtered_lrelu.cpp
│ ├── filtered_lrelu.cu
│ ├── filtered_lrelu.h
│ ├── filtered_lrelu.py
│ ├── filtered_lrelu_ns.cu
│ ├── filtered_lrelu_rd.cu
│ ├── filtered_lrelu_wr.cu
│ ├── fma.py
│ ├── grid_sample_gradfix.py
│ ├── upfirdn2d.cpp
│ ├── upfirdn2d.cu
│ ├── upfirdn2d.h
│ └── upfirdn2d.py
├── persistence.py
└── training_stats.py
├── train_faceswap.py
├── train_inversion.py
├── training
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-38.pyc
│ ├── networks_stylegan2.cpython-38.pyc
│ ├── networks_stylegan3.cpython-38.pyc
│ ├── ranger.cpython-38.pyc
│ ├── superresolution.cpython-38.pyc
│ └── triplane.cpython-38.pyc
├── augment.py
├── crosssection_utils.py
├── dataset.py
├── dual_discriminator.py
├── loss.py
├── networks_stylegan2.py
├── networks_stylegan3.py
├── projectors
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-36.pyc
│ │ ├── __init__.cpython-38.pyc
│ │ ├── __init__.cpython-39.pyc
│ │ ├── w_plus_projector.cpython-38.pyc
│ │ ├── w_plus_projector.cpython-39.pyc
│ │ ├── w_projector.cpython-36.pyc
│ │ ├── w_projector.cpython-38.pyc
│ │ └── w_projector.cpython-39.pyc
│ ├── w_plus_projector.py
│ └── w_projector.py
├── ranger.py
├── superresolution.py
├── training_loop.py
├── triplane.py
└── volumetric_rendering
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-38.pyc
│ │ ├── __init__.cpython-39.pyc
│ │ ├── math_utils.cpython-38.pyc
│ │ ├── math_utils.cpython-39.pyc
│ │ ├── ray_marcher.cpython-38.pyc
│ │ ├── ray_marcher.cpython-39.pyc
│ │ ├── ray_sampler.cpython-38.pyc
│ │ ├── ray_sampler.cpython-39.pyc
│ │ ├── renderer.cpython-38.pyc
│ │ └── renderer.cpython-39.pyc
│ ├── math_utils.py
│ ├── ray_marcher.py
│ ├── ray_sampler.py
│ └── renderer.py
└── utils
├── __pycache__
│ ├── camera_utils.cpython-38.pyc
│ ├── data_utils.cpython-38.pyc
│ └── legacy.cpython-38.pyc
├── camera_utils.py
├── data_utils.py
├── legacy.py
└── shape_utils.py
/3dswap_c_vision_experiments.ipynb:
--------------------------------------------------------------------------------
1 | {"cells":[{"cell_type":"markdown","source":["Colab notebook for 3dSwap by Phillip T. Chananda, Credit to the developer Yixuan Li https://github.com/lyx0208"],"metadata":{"id":"Ski0tLwGDf6f"}},{"cell_type":"markdown","metadata":{"id":"km3VVxGwBAF0"},"source":["DATASET PREPROCESSING STEPS\n","\n","1. First clone 3dSwap\n","2. Go to eg3d notebook clone the repo and place your dataset in the dataset path in eg3d and edit dataset_name\n","3. Then run the rest of the eg3d code.\n","4. Then return to 3dSwap and then edit the dataroot and finally run this notebook."]},{"cell_type":"markdown","source":[],"metadata":{"id":"4DP-1rGgFTpZ"}},{"cell_type":"code","execution_count":null,"metadata":{"id":"pvc3IQsdmKCz"},"outputs":[],"source":["from google.colab import drive\n","drive.mount(\"/content/drive/\")"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"6VIP0kEOveFa"},"outputs":[],"source":["%cd \"/content/drive/MyDrive/3dSwap/\"\n","!ls"]},{"cell_type":"code","source":["!git clone https://github.com/lyx0208/3dSwap.git\n","%cd 3dSwap"],"metadata":{"id":"_w4OmPkdIngC"},"execution_count":null,"outputs":[]},{"cell_type":"code","execution_count":null,"metadata":{"id":"HeZY2rjaI-nK"},"outputs":[],"source":["!pip install -q condacolab\n","import condacolab\n","condacolab.install()"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"1izhpsNjKQOm"},"outputs":[],"source":["!conda --version"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"XWcZ-IwPKUv0"},"outputs":[],"source":["!conda create -n 3dSwap python=3.8"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"y76uC8xsLQAk"},"outputs":[],"source":["import sys\n","_ = (sys.path\n"," .append(\"/usr/local/lib/python3.8/site-packages\"))"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"w0fn3d8pLM2Q"},"outputs":[],"source":["%%shell\n","cd /content/drive/MyDrive/3dSwap\n","source activate 3dSwap\n","pip install -r requirements.txt"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"ADy03NUhHbiG"},"outputs":[],"source":["%%shell\n","source activate 3dSwap\n","# pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118\n","# pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 --extra-index-url https://download.pytorch.org/whl/cu111\n","# pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113\n","pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu121"]},{"cell_type":"markdown","source":["INSTALL MISSING PACKAGES"],"metadata":{"id":"cq_qx1sCGrMT"}},{"cell_type":"code","execution_count":null,"metadata":{"id":"U8-uBjPNzsSj"},"outputs":[],"source":["# uncomment run if these packages are not installed\n","%%shell\n","source activate 3dSwap\n","pip install imgui\n","pip install Ninja numpy\n","pip install scikit-image\n","pip install plyfile\n","pip install trimesh\n","pip install ninja\n","pip install click\n","pip install mrcfile\n","pip install mediapipe\n","pip install opencv-python\n","pip install pillow\n","pip install ipykernel\n","# pip install dlib"]},{"cell_type":"markdown","source":["INFERENCE CODE"],"metadata":{"id":"HvuGk1vSGFrR"}},{"cell_type":"code","execution_count":null,"metadata":{"id":"e3pjfK2sOFsA"},"outputs":[],"source":["%%shell\n","cd /content/drive/MyDrive/3dSwap\n","source activate 3dSwap\n","python run_3dSwap.py --epoch 1000 --to_index \"A\" 
--from_index \"B\" --dataroot \"/content/drive/MyDrive/3dSwap/datasets/test_images\""]},{"cell_type":"markdown","source":["INFERENCE CODE FOR INVERSION"],"metadata":{"id":"jY05EMCxGkJt"}},{"cell_type":"code","execution_count":null,"metadata":{"id":"Z2NHU-L6J4RW"},"outputs":[],"source":["#!python run_inversion.py --from_index \"apex-cropped-repression-36-64\" --to_index \"apex-cropped-others-16-69\" --dataroot \"datasets/my_dataset\""]},{"cell_type":"markdown","source":["TRAINING CODE FOR FACESWAPPING"],"metadata":{"id":"ITSMbiiMGBKA"}},{"cell_type":"code","execution_count":null,"metadata":{"id":"5Bx3jzjr_L7r"},"outputs":[],"source":["# %%shell\n","# cd /content/drive/MyDrive/3dSwap/\n","# source activate 3dSwap\n","# python -m torch.distributed.run --nproc_per_node=1 --master_port=12345 fine_tune.py --workers=2 --exp_dir=faceswap"]}],"metadata":{"accelerator":"GPU","colab":{"provenance":[],"gpuType":"T4","mount_file_id":"14RpUvROgT830sVvgh4LJebJC44xoc_iv","authorship_tag":"ABX9TyPxvws88H5lPJVudGxUfcyO"},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"}},"nbformat":4,"nbformat_minor":0}
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## 3D-Aware Face Swapping
Official PyTorch implementation of the CVPR 2023 paper "3D-Aware Face Swapping"
2 |
3 | 
4 |
5 | **3D-Aware Face Swapping**
6 | Yixuan Li, Chao Ma, Yichao Yan, Wenhan Zhu, Xiaokang Yang
7 |
8 | Abstract: *Face swapping is an important research topic in computer vision with wide applications in entertainment and privacy protection. Existing methods directly learn to swap 2D facial images, taking no account of the geometric information of human faces. In the presence of large pose variance between the source and the target faces, there always exist undesirable artifacts on the swapped face. In this paper, we present a novel 3D-aware face swapping method that generates high-fidelity and multi-view-consistent swapped faces from single-view source and target images. To achieve this, we take advantage of the strong geometry and texture prior of 3D human faces, where the 2D faces are projected into the latent space of a 3D generative model. By disentangling the identity and attribute features in the latent space, we succeed in swapping faces in a 3D-aware manner, being robust to pose variations while transferring fine-grained facial details. Extensive experiments demonstrate the superiority of our 3D-aware face swapping framework in terms of visual quality, identity similarity, and multi-view consistency. Project page: https://lyx0208.github.io/3dSwap*
9 |
10 | ## Requirements
11 | * Create and activate the Python environment:
12 | - `conda create -n 3dSwap python=3.8`
13 | - `conda activate 3dSwap`
14 | - `pip install -r requirements.txt`
15 |
16 | ## Datasets preparation
17 | * We preprocess the images from the original FFHQ and CelebA-HD datasets with the data preprocessing code from **[EG3D](https://github.com/NVlabs/eg3d)**, including re-cropping the images and extracting the corresponding camera poses (see the short loading sketch after this list).
18 |
19 | - To test on the CelebA-HD dataset, please download our preprocessed data from [here](https://pan.baidu.com/s/1Qgru1Tyg3DkclnPny0gjhw?pwd=swap).
20 |
21 | - To test on your own images, please refer to the data preprocessing file of EG3D [here](https://github.com/NVlabs/eg3d/blob/main/dataset_preprocessing/ffhq/preprocess_in_the_wild.py).
22 |
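For reference, each preprocessed sample pairs a cropped image in `final_crops/` with a camera-pose `.npy` in `camera_pose/` (read as 25 values per pose by `load_parameter` in `datasets/images_dataset.py`). Below is a minimal loading sketch, not part of the repo, assuming the bundled `datasets/CelebA-HD` layout:

```.python
# Sketch: inspect one preprocessed sample from the bundled example data.
import numpy as np
from PIL import Image

root = "datasets/CelebA-HD"                  # adjust to your own dataset root
image = Image.open(f"{root}/final_crops/0.jpg").convert("RGB")
pose = np.load(f"{root}/camera_pose/0.npy")  # 25 camera parameters per image
print(image.size, pose.shape)
```
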
23 | ## Inference
24 | Download our pretrained models from [Baidu Disk](https://pan.baidu.com/s/1yEJ8-4SLUdDDs9SEE-1hpA?pwd=swap) or [Google Drive](https://drive.google.com/drive/folders/1rlZRO-pjKFedmx6-3QdSxxThN_jXA6Pb?usp=drive_link). Put `model_ir_se50.pth` under the "models" folder and the other files under the "checkpoints" folder.
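
A quick sanity check, as a sketch rather than part of the repo, that the weights end up where the `model_paths` entry in `configs/paths_config.py` and the run scripts expect them:

```.python
# Sketch: confirm the pretrained weights are in the expected locations.
import os

assert os.path.isfile("models/model_ir_se50.pth"), "identity backbone weights missing"
assert os.path.isdir("checkpoints") and os.listdir("checkpoints"), "checkpoints folder missing or empty"
```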
25 |
26 | Then run:
27 |
28 | ```.bash
29 | python run_3dSwap.py
30 | ```
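
The bundled Colab notebook (`3dswap_c_vision_experiments.ipynb`) invokes the same script with explicit arguments, e.g. `--epoch 1000 --to_index "A" --from_index "B" --dataroot <path to your preprocessed images>`; the index values shown there are placeholders for the source and target image names.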
31 |
32 | If you only want to perform the 3D GAN inversion without face swapping, run:
33 |
34 | ```.bash
35 | python run_inversion.py
36 | ```
37 | ## Training
38 |
39 | First, download the preprocessed FFHQ dataset from [here](https://pan.baidu.com/s/1Qgru1Tyg3DkclnPny0gjhw?pwd=swap) and put it under the "datasets" folder.
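
The training data roots are resolved through `configs/data_configs.py`. A minimal sketch (assuming it is run from the repo root with the environment installed) to check which paths the `ffhq_encode` config will look for, so you can place or relink the downloaded data accordingly:

```.python
# Sketch: print the dataset roots the 'ffhq_encode' config resolves to.
from configs.data_configs import DATASETS

cfg = DATASETS['ffhq_encode']
print(cfg['train_source_root'])  # datasets/EG3D_FFHQ/final_crops
print(cfg['test_source_root'])   # datasets/EG3D_CelebA/final_crops
```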
40 |
41 | To train the inversion module, run:
42 |
43 | ```.bash
44 | python -m torch.distributed.launch --nproc_per_node=4 --master_port=12345 train_inversion.py --exp_dir=inversion
45 | ```
46 |
47 | To train the face swapping module, run:
48 |
49 | ```.bash
50 | python -m torch.distributed.launch --nproc_per_node=4 --master_port=12345 train_faceswap.py --exp_dir=faceswap
51 | ```
52 |
53 | ## Citation
54 |
55 | ```
56 | @InProceedings{Li_2023_CVPR,
57 | author = {Li, Yixuan and Ma, Chao and Yan, Yichao and Zhu, Wenhan and Yang, Xiaokang},
58 | title = {3D-Aware Face Swapping},
59 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
60 | month = {June},
61 | year = {2023},
62 | pages = {12705-12714}
63 | }
64 | ```
65 |
66 | ## Acknowledgements
67 | * Our code is developed based on:
68 | - https://github.com/NVlabs/eg3d
69 | - https://github.com/eladrich/pixel2style2pixel
70 | * [2024.4.8] Thanks to [Phillip Chananda](https://github.com/takuphilchan) for contributing the Colab notebook (.ipynb) for 3dSwap!
71 |
--------------------------------------------------------------------------------
/configs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/configs/__init__.py
--------------------------------------------------------------------------------
/configs/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/configs/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/configs/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/configs/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/configs/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/configs/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/configs/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/configs/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/configs/__pycache__/data_configs.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/configs/__pycache__/data_configs.cpython-37.pyc
--------------------------------------------------------------------------------
/configs/__pycache__/data_configs.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/configs/__pycache__/data_configs.cpython-38.pyc
--------------------------------------------------------------------------------
/configs/__pycache__/data_configs.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/configs/__pycache__/data_configs.cpython-39.pyc
--------------------------------------------------------------------------------
/configs/__pycache__/global_config.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/configs/__pycache__/global_config.cpython-36.pyc
--------------------------------------------------------------------------------
/configs/__pycache__/global_config.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/configs/__pycache__/global_config.cpython-38.pyc
--------------------------------------------------------------------------------
/configs/__pycache__/global_config.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/configs/__pycache__/global_config.cpython-39.pyc
--------------------------------------------------------------------------------
/configs/__pycache__/hyperparameters.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/configs/__pycache__/hyperparameters.cpython-36.pyc
--------------------------------------------------------------------------------
/configs/__pycache__/hyperparameters.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/configs/__pycache__/hyperparameters.cpython-38.pyc
--------------------------------------------------------------------------------
/configs/__pycache__/hyperparameters.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/configs/__pycache__/hyperparameters.cpython-39.pyc
--------------------------------------------------------------------------------
/configs/__pycache__/paths_config.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/configs/__pycache__/paths_config.cpython-36.pyc
--------------------------------------------------------------------------------
/configs/__pycache__/paths_config.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/configs/__pycache__/paths_config.cpython-37.pyc
--------------------------------------------------------------------------------
/configs/__pycache__/paths_config.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/configs/__pycache__/paths_config.cpython-38.pyc
--------------------------------------------------------------------------------
/configs/__pycache__/paths_config.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/configs/__pycache__/paths_config.cpython-39.pyc
--------------------------------------------------------------------------------
/configs/__pycache__/transforms_config.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/configs/__pycache__/transforms_config.cpython-37.pyc
--------------------------------------------------------------------------------
/configs/__pycache__/transforms_config.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/configs/__pycache__/transforms_config.cpython-38.pyc
--------------------------------------------------------------------------------
/configs/__pycache__/transforms_config.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/configs/__pycache__/transforms_config.cpython-39.pyc
--------------------------------------------------------------------------------
/configs/data_configs.py:
--------------------------------------------------------------------------------
1 | from configs import transforms_config
2 |
3 |
4 | DATASETS = {
5 | 'ffhq_encode': {
6 | 'transforms': transforms_config.EncodeTransforms,
7 | 'train_source_root': 'datasets/EG3D_FFHQ/final_crops',
8 | 'train_target_root': 'datasets/EG3D/final_crops',
9 | 'test_source_root': 'datasets/EG3D_CelebA/final_crops',
10 | 'test_target_root':'datasets/EG3D_CelebA/final_crops',
11 | }
12 | }
13 |
--------------------------------------------------------------------------------
/configs/evaluation_config.py:
--------------------------------------------------------------------------------
1 | evaluated_methods = ['e4e', 'SG2', 'SG2Plus']
--------------------------------------------------------------------------------
/configs/global_config.py:
--------------------------------------------------------------------------------
1 | ## Device
2 | cuda_visible_devices = '0'
3 | device = 'cuda:0'
4 |
5 | ## Logs
6 | training_step = 1
7 | image_rec_result_log_snapshot = 100
8 | pivotal_training_steps = 0
9 | model_snapshot_interval = 400
10 |
11 | ## Run name to be updated during PTI
12 | run_name = ''
13 |
--------------------------------------------------------------------------------
/configs/hyperparameters.py:
--------------------------------------------------------------------------------
1 | ## Architecture
2 | lpips_type = 'alex'
3 | first_inv_type = 'w'
4 | optim_type = 'adam'
5 |
6 | ## Locality regularization
7 | latent_ball_num_of_samples = 1
8 | locality_regularization_interval = 1
9 | use_locality_regularization = False
10 | regulizer_l2_lambda = 0.1
11 | regulizer_lpips_lambda = 0.1
12 | regulizer_alpha = 30
13 |
14 | ## Loss
15 | pt_l2_lambda = 1
16 | pt_lpips_lambda = 1
17 |
18 | ## Steps
19 | LPIPS_value_threshold = 0.06
20 | max_pti_steps = 350
21 | first_inv_steps = 450
22 | max_images_to_invert = 30
23 |
24 | ## Optimization
25 | pti_learning_rate = 3e-4
26 | first_inv_lr = 5e-3
27 | train_batch_size = 1
28 | use_last_w_pivots = False
29 |
--------------------------------------------------------------------------------
/configs/paths_config.py:
--------------------------------------------------------------------------------
1 | ## Pretrained models paths
2 | e4e = './pretrained_models/e4e_ffhq_encode.pt'
3 | stylegan2_ada_ffhq = '../pretrained_models/ffhq.pkl'
4 | style_clip_pretrained_mappers = ''
5 | ir_se50 = './pretrained_models/model_ir_se50.pth'
6 | dlib = '../pretrained_models/align.dat'
7 |
8 | ## Dirs for output files
9 | checkpoints_dir = '../checkpoints'
10 | embedding_base_dir = '../embeddings'
11 | styleclip_output_dir = './StyleCLIP_results'
12 | experiments_output_dir = '../output'
13 |
14 | ## Input info
15 | ### Input dir, where the images reside
16 | input_data_path = '../dataset/aligned'
17 | ### Inversion identifier, used to keep track of the inversion results (both the latent code and the generator)
18 | input_data_id = 'barcelona'
19 |
20 | ## Keywords
21 | pti_results_keyword = 'PTI'
22 | e4e_results_keyword = 'e4e'
23 | sg2_results_keyword = 'SG2'
24 | sg2_plus_results_keyword = 'SG2_plus'
25 | multi_id_model_type = 'multi_id'
26 |
27 | ## Edit directions
28 | interfacegan_age = '../editings/interfacegan_directions/age.pt'
29 | interfacegan_smile = '../editings/interfacegan_directions/smile.pt'
30 | interfacegan_rotation = '../editings/interfacegan_directions/rotation.pt'
31 | ffhq_pca = '../editings/ganspace_pca/ffhq_pca.pt'
32 |
33 | model_paths = {
34 | 'stylegan_ffhq': 'pretrained_models/stylegan2-ffhq-config-f.pt',
35 | 'ir_se50': 'models/model_ir_se50.pth',
36 | 'circular_face': 'pretrained_models/CurricularFace_Backbone.pth',
37 | 'mtcnn_pnet': 'pretrained_models/mtcnn/pnet.npy',
38 | 'mtcnn_rnet': 'pretrained_models/mtcnn/rnet.npy',
39 | 'mtcnn_onet': 'pretrained_models/mtcnn/onet.npy',
40 | 'shape_predictor': 'shape_predictor_68_face_landmarks.dat',
41 | 'moco': 'pretrained_models/moco_v2_800ep_pretrain.pth.tar'
42 | }
43 |
--------------------------------------------------------------------------------
/configs/transforms_config.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | import torchvision.transforms as transforms
3 | from datasets import augmentations
4 |
5 |
6 | class TransformsConfig(object):
7 |
8 | def __init__(self, opts):
9 | self.opts = opts
10 |
11 | @abstractmethod
12 | def get_transforms(self):
13 | pass
14 |
15 |
16 | class EncodeTransforms(TransformsConfig):
17 |
18 | def __init__(self, opts):
19 | super(EncodeTransforms, self).__init__(opts)
20 |
21 | def get_transforms(self):
22 | transforms_dict = {
23 | 'transform_gt_train': transforms.Compose([
24 | transforms.Resize((256, 256)),
25 | # transforms.RandomHorizontalFlip(0.5),
26 | transforms.ToTensor(),
27 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
28 | 'transform_source': None,
29 | 'transform_test': transforms.Compose([
30 | transforms.Resize((256, 256)),
31 | transforms.ToTensor(),
32 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
33 | 'transform_inference': transforms.Compose([
34 | transforms.Resize((256, 256)),
35 | transforms.ToTensor(),
36 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
37 | }
38 | return transforms_dict
39 |
40 |
41 | class FrontalizationTransforms(TransformsConfig):
42 |
43 | def __init__(self, opts):
44 | super(FrontalizationTransforms, self).__init__(opts)
45 |
46 | def get_transforms(self):
47 | transforms_dict = {
48 | 'transform_gt_train': transforms.Compose([
49 | transforms.Resize((256, 256)),
50 | transforms.RandomHorizontalFlip(0.5),
51 | transforms.ToTensor(),
52 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
53 | 'transform_source': transforms.Compose([
54 | transforms.Resize((256, 256)),
55 | transforms.RandomHorizontalFlip(0.5),
56 | transforms.ToTensor(),
57 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
58 | 'transform_test': transforms.Compose([
59 | transforms.Resize((256, 256)),
60 | transforms.ToTensor(),
61 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
62 | 'transform_inference': transforms.Compose([
63 | transforms.Resize((256, 256)),
64 | transforms.ToTensor(),
65 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
66 | }
67 | return transforms_dict
68 |
69 |
70 | class SketchToImageTransforms(TransformsConfig):
71 |
72 | def __init__(self, opts):
73 | super(SketchToImageTransforms, self).__init__(opts)
74 |
75 | def get_transforms(self):
76 | transforms_dict = {
77 | 'transform_gt_train': transforms.Compose([
78 | transforms.Resize((256, 256)),
79 | transforms.ToTensor(),
80 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
81 | 'transform_source': transforms.Compose([
82 | transforms.Resize((256, 256)),
83 | transforms.ToTensor()]),
84 | 'transform_test': transforms.Compose([
85 | transforms.Resize((256, 256)),
86 | transforms.ToTensor(),
87 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
88 | 'transform_inference': transforms.Compose([
89 | transforms.Resize((256, 256)),
90 | transforms.ToTensor()]),
91 | }
92 | return transforms_dict
93 |
94 |
95 | class SegToImageTransforms(TransformsConfig):
96 |
97 | def __init__(self, opts):
98 | super(SegToImageTransforms, self).__init__(opts)
99 |
100 | def get_transforms(self):
101 | transforms_dict = {
102 | 'transform_gt_train': transforms.Compose([
103 | transforms.Resize((256, 256)),
104 | transforms.ToTensor(),
105 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
106 | 'transform_source': transforms.Compose([
107 | transforms.Resize((256, 256)),
108 | augmentations.ToOneHot(self.opts.label_nc),
109 | transforms.ToTensor()]),
110 | 'transform_test': transforms.Compose([
111 | transforms.Resize((256, 256)),
112 | transforms.ToTensor(),
113 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
114 | 'transform_inference': transforms.Compose([
115 | transforms.Resize((256, 256)),
116 | augmentations.ToOneHot(self.opts.label_nc),
117 | transforms.ToTensor()])
118 | }
119 | return transforms_dict
120 |
121 |
122 | class SuperResTransforms(TransformsConfig):
123 |
124 | def __init__(self, opts):
125 | super(SuperResTransforms, self).__init__(opts)
126 |
127 | def get_transforms(self):
128 | if self.opts.resize_factors is None:
129 | self.opts.resize_factors = '1,2,4,8,16,32'
130 | factors = [int(f) for f in self.opts.resize_factors.split(",")]
131 | print("Performing down-sampling with factors: {}".format(factors))
132 | transforms_dict = {
133 | 'transform_gt_train': transforms.Compose([
134 | transforms.Resize((256, 256)),
135 | transforms.ToTensor(),
136 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
137 | 'transform_source': transforms.Compose([
138 | transforms.Resize((256, 256)),
139 | augmentations.BilinearResize(factors=factors),
140 | transforms.Resize((256, 256)),
141 | transforms.ToTensor(),
142 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
143 | 'transform_test': transforms.Compose([
144 | transforms.Resize((256, 256)),
145 | transforms.ToTensor(),
146 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
147 | 'transform_inference': transforms.Compose([
148 | transforms.Resize((256, 256)),
149 | augmentations.BilinearResize(factors=factors),
150 | transforms.Resize((256, 256)),
151 | transforms.ToTensor(),
152 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
153 | }
154 | return transforms_dict
155 |
--------------------------------------------------------------------------------
/datasets/CelebA-HD/camera_pose/0.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/datasets/CelebA-HD/camera_pose/0.npy
--------------------------------------------------------------------------------
/datasets/CelebA-HD/camera_pose/1.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/datasets/CelebA-HD/camera_pose/1.npy
--------------------------------------------------------------------------------
/datasets/CelebA-HD/camera_pose/2.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/datasets/CelebA-HD/camera_pose/2.npy
--------------------------------------------------------------------------------
/datasets/CelebA-HD/camera_pose/3.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/datasets/CelebA-HD/camera_pose/3.npy
--------------------------------------------------------------------------------
/datasets/CelebA-HD/camera_pose/4.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/datasets/CelebA-HD/camera_pose/4.npy
--------------------------------------------------------------------------------
/datasets/CelebA-HD/camera_pose/5.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/datasets/CelebA-HD/camera_pose/5.npy
--------------------------------------------------------------------------------
/datasets/CelebA-HD/camera_pose/6.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/datasets/CelebA-HD/camera_pose/6.npy
--------------------------------------------------------------------------------
/datasets/CelebA-HD/camera_pose/7.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/datasets/CelebA-HD/camera_pose/7.npy
--------------------------------------------------------------------------------
/datasets/CelebA-HD/camera_pose/8.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/datasets/CelebA-HD/camera_pose/8.npy
--------------------------------------------------------------------------------
/datasets/CelebA-HD/camera_pose/9.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/datasets/CelebA-HD/camera_pose/9.npy
--------------------------------------------------------------------------------
/datasets/CelebA-HD/final_crops/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/datasets/CelebA-HD/final_crops/0.jpg
--------------------------------------------------------------------------------
/datasets/CelebA-HD/final_crops/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/datasets/CelebA-HD/final_crops/1.jpg
--------------------------------------------------------------------------------
/datasets/CelebA-HD/final_crops/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/datasets/CelebA-HD/final_crops/2.jpg
--------------------------------------------------------------------------------
/datasets/CelebA-HD/final_crops/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/datasets/CelebA-HD/final_crops/3.jpg
--------------------------------------------------------------------------------
/datasets/CelebA-HD/final_crops/4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/datasets/CelebA-HD/final_crops/4.jpg
--------------------------------------------------------------------------------
/datasets/CelebA-HD/final_crops/5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/datasets/CelebA-HD/final_crops/5.jpg
--------------------------------------------------------------------------------
/datasets/CelebA-HD/final_crops/6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/datasets/CelebA-HD/final_crops/6.jpg
--------------------------------------------------------------------------------
/datasets/CelebA-HD/final_crops/7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/datasets/CelebA-HD/final_crops/7.jpg
--------------------------------------------------------------------------------
/datasets/CelebA-HD/final_crops/8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/datasets/CelebA-HD/final_crops/8.jpg
--------------------------------------------------------------------------------
/datasets/CelebA-HD/final_crops/9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/datasets/CelebA-HD/final_crops/9.jpg
--------------------------------------------------------------------------------
/datasets/__pycache__/augmentations.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/datasets/__pycache__/augmentations.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/images_dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/datasets/__pycache__/images_dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/augmentations.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from torchvision import transforms
6 |
7 |
8 | class ToOneHot(object):
9 | """ Convert the input PIL image to a one-hot torch tensor """
10 | def __init__(self, n_classes=None):
11 | self.n_classes = n_classes
12 |
13 | def onehot_initialization(self, a):
14 | if self.n_classes is None:
15 | self.n_classes = len(np.unique(a))
16 | out = np.zeros(a.shape + (self.n_classes, ), dtype=int)
17 | out[self.__all_idx(a, axis=2)] = 1
18 | return out
19 |
20 | def __all_idx(self, idx, axis):
21 | grid = np.ogrid[tuple(map(slice, idx.shape))]
22 | grid.insert(axis, idx)
23 | return tuple(grid)
24 |
25 | def __call__(self, img):
26 | img = np.array(img)
27 | one_hot = self.onehot_initialization(img)
28 | return one_hot
29 |
30 |
31 | class BilinearResize(object):
32 | def __init__(self, factors=[1, 2, 4, 8, 16, 32]):
33 | self.factors = factors
34 |
35 | def __call__(self, image):
36 | factor = np.random.choice(self.factors, size=1)[0]
37 | D = BicubicDownSample(factor=factor, cuda=False)
38 | img_tensor = transforms.ToTensor()(image).unsqueeze(0)
39 | img_tensor_lr = D(img_tensor)[0].clamp(0, 1)
40 | img_low_res = transforms.ToPILImage()(img_tensor_lr)
41 | return img_low_res
42 |
43 |
44 | class BicubicDownSample(nn.Module):
45 | def bicubic_kernel(self, x, a=-0.50):
46 | """
47 | This equation is exactly copied from the website below:
48 | https://clouard.users.greyc.fr/Pantheon/experiments/rescaling/index-en.html#bicubic
49 | """
50 | abs_x = torch.abs(x)
51 | if abs_x <= 1.:
52 | return (a + 2.) * torch.pow(abs_x, 3.) - (a + 3.) * torch.pow(abs_x, 2.) + 1
53 | elif 1. < abs_x < 2.:
54 | return a * torch.pow(abs_x, 3) - 5. * a * torch.pow(abs_x, 2.) + 8. * a * abs_x - 4. * a
55 | else:
56 | return 0.0
57 |
58 | def __init__(self, factor=4, cuda=True, padding='reflect'):
59 | super().__init__()
60 | self.factor = factor
61 | size = factor * 4
62 | k = torch.tensor([self.bicubic_kernel((i - torch.floor(torch.tensor(size / 2)) + 0.5) / factor)
63 | for i in range(size)], dtype=torch.float32)
64 | k = k / torch.sum(k)
65 | k1 = torch.reshape(k, shape=(1, 1, size, 1))
66 | self.k1 = torch.cat([k1, k1, k1], dim=0)
67 | k2 = torch.reshape(k, shape=(1, 1, 1, size))
68 | self.k2 = torch.cat([k2, k2, k2], dim=0)
69 | self.cuda = '.cuda' if cuda else ''
70 | self.padding = padding
71 | for param in self.parameters():
72 | param.requires_grad = False
73 |
74 | def forward(self, x, nhwc=False, clip_round=False, byte_output=False):
75 | filter_height = self.factor * 4
76 | filter_width = self.factor * 4
77 | stride = self.factor
78 |
79 | pad_along_height = max(filter_height - stride, 0)
80 | pad_along_width = max(filter_width - stride, 0)
81 | filters1 = self.k1.type('torch{}.FloatTensor'.format(self.cuda))
82 | filters2 = self.k2.type('torch{}.FloatTensor'.format(self.cuda))
83 |
84 | # compute actual padding values for each side
85 | pad_top = pad_along_height // 2
86 | pad_bottom = pad_along_height - pad_top
87 | pad_left = pad_along_width // 2
88 | pad_right = pad_along_width - pad_left
89 |
90 | # apply mirror padding
91 | if nhwc:
92 | x = torch.transpose(torch.transpose(x, 2, 3), 1, 2) # NHWC to NCHW
93 |
94 | # downscaling performed by 1-d convolution
95 | x = F.pad(x, (0, 0, pad_top, pad_bottom), self.padding)
96 | x = F.conv2d(input=x, weight=filters1, stride=(stride, 1), groups=3)
97 | if clip_round:
98 | x = torch.clamp(torch.round(x), 0.0, 255.)
99 |
100 | x = F.pad(x, (pad_left, pad_right, 0, 0), self.padding)
101 | x = F.conv2d(input=x, weight=filters2, stride=(1, stride), groups=3)
102 | if clip_round:
103 | x = torch.clamp(torch.round(x), 0.0, 255.)
104 |
105 | if nhwc:
106 | x = torch.transpose(torch.transpose(x, 1, 3), 1, 2)
107 | if byte_output:
108 | return x.type('torch{}.ByteTensor'.format(self.cuda))
109 | else:
110 | return x
111 |
--------------------------------------------------------------------------------
/datasets/images_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | from torch.utils.data import Dataset
3 | from PIL import Image
4 | from utils import data_utils
5 | import random
6 | import numpy as np
7 | import torch
8 | import torchvision.transforms as transforms
9 |
10 | TRANSFORM = transforms.Compose([
11 | transforms.Resize((512, 512)),
12 | transforms.ToTensor(),
13 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]
14 | )
15 |
16 |
17 | def gen_mask(p):
18 | if p is None:
19 | mask_rect = torch.zeros([1, 512, 512])
20 | num = 75
21 | for i in range(50, 100):
22 | for j in range(32 * 2 + num, 220 * 2 - num):
23 | mask_rect[0][i][j] += 1
24 | return mask_rect
25 | index = [1]
26 | mask = torch.zeros([1, 512, 512])
27 | for i in index:
28 | mask += p == i
29 | mask_rect = torch.zeros([1, 512, 512])
30 | num = 75
31 | for i in range(35 * 2 + num, 223 * 2 - num):
32 | for j in range(32 * 2 + num, 220 * 2 - num):
33 | mask_rect[0][i][j] += 1
34 |
35 | return mask * mask_rect
36 |
37 |
38 | def load_parameter(param_path):
39 | parameter = torch.zeros([1, 25])
40 | parameter_np = np.load(param_path)
41 | for i in range(parameter_np.__len__()):
42 | parameter[0, i] += parameter_np[i]
43 | return parameter
44 |
45 |
46 | class ImagesDataset(Dataset):
47 |
48 | def __init__(self, source_root, target_root, opts):
49 | self.camera_pose_root = source_root[:-11] + 'camera_pose'
50 |
51 | self.source_paths = sorted(data_utils.make_dataset(source_root))
52 | self.target_paths = sorted(data_utils.make_dataset(target_root))
53 | self.target_num = len(self.source_paths)
54 | self.opts = opts
55 |
56 | def __len__(self):
57 | return len(self.source_paths)
58 |
59 | def __getitem__(self, index):
60 | to_index = (random.randint(0, self.target_num)) % self.target_num
61 |
62 | from_im = Image.open(self.source_paths[index]).convert('RGB')
63 | from_im = TRANSFORM(from_im)
64 |
65 | to_im = Image.open(self.source_paths[to_index]).convert('RGB')
66 | to_im = TRANSFORM(to_im)
67 |
68 | from_camera_parameter = load_parameter(
69 | os.path.join(self.camera_pose_root, self.source_paths[index].split('/')[-1].split('.')[0] + '.npy'))
70 | to_camera_parameter = load_parameter(
71 | os.path.join(self.camera_pose_root, self.source_paths[to_index].split('/')[-1].split('.')[0] + '.npy'))
72 |
73 | try:
74 | to_label_path = os.path.join('datasets/EG3D/labels',
75 | self.source_paths[to_index].split('/')[-1].split('.')[0] + '.png')
76 | to_label = Image.open(to_label_path).convert('L')
77 | to_label_tensor = TRANSFORM(to_label) * 255.0
78 | to_face_mask = gen_mask(to_label_tensor)
79 | except:
80 | to_face_mask = gen_mask(None)
81 |
82 | return from_im, to_im, from_camera_parameter, to_camera_parameter, to_face_mask
83 |
--------------------------------------------------------------------------------
/dnnlib/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | from .util import EasyDict, make_cache_dir_path
12 |
--------------------------------------------------------------------------------
/dnnlib/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/dnnlib/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/dnnlib/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/dnnlib/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/dnnlib/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/dnnlib/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/dnnlib/__pycache__/util.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/dnnlib/__pycache__/util.cpython-37.pyc
--------------------------------------------------------------------------------
/dnnlib/__pycache__/util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/dnnlib/__pycache__/util.cpython-38.pyc
--------------------------------------------------------------------------------
/dnnlib/__pycache__/util.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/dnnlib/__pycache__/util.cpython-39.pyc
--------------------------------------------------------------------------------
/eg3d_c_vision_experiments.ipynb:
--------------------------------------------------------------------------------
1 | {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"gpuType":"T4","mount_file_id":"10tpI22Lq2EHE7jCKN-fKoaEAhU-1tChY","authorship_tag":"ABX9TyPFnHtZXrzVVSd/z48Q4K3k"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","source":["DATASET PREPROCESSING NOTEBOOK USING EG3D BY PHILLIP T. CHANANDA\n","\n","CREDIT: NVIDIA LAB'S EG3D REPOSITORY\n","LINK: https://github.com/NVlabs/eg3d\n","PAPER: Efficient Geometry-aware {3D} Generative Adversarial Networks\n","AUTHORS: Eric R. Chan and Connor Z. Lin and Matthew A. Chan and Koki Nagano and Boxiao Pan and Shalini De Mello and Orazio Gallo and Leonidas Guibas and Jonathan Tremblay and Sameh Khamis and Tero Karras and Gordon Wetzstein\n","\n"],"metadata":{"id":"f210pMbSwbPH"}},{"cell_type":"code","execution_count":null,"metadata":{"id":"pvc3IQsdmKCz"},"outputs":[],"source":["# MOUNT YOUR GOOGLE DRIVE\n","# from google.colab import drive\n","# drive.mount(\"/content/drive/\")"]},{"cell_type":"code","source":["# GIT CLONE THE REPOSITORY"],"metadata":{"id":"TgbUXFwC1weA"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["%cd \"/content/drive/MyDrive/\"\n","!ls"],"metadata":{"id":"6VIP0kEOveFa"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["!git clone https://github.com/NVlabs/eg3d.git\n","%cd eg3d"],"metadata":{"id":"jG-LTHQMCt5z"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# INSTALL CONDA\n","!pip install -q condacolab\n","import condacolab\n","condacolab.install()"],"metadata":{"id":"HeZY2rjaI-nK"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["!conda --version"],"metadata":{"id":"1izhpsNjKQOm"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["%cd /content/drive/MyDrive/eg3d/eg3d\n","!ls"],"metadata":{"id":"aUbYeZ29EvE5"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["!mamba env update -f environment.yml"],"metadata":{"id":"-cwViFQnGhQN"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["%%shell\n","source activate eg3d"],"metadata":{"id":"I8062a7-K2T_"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# INSTALL DEPENDENCIES\n","!pip install numpy kornia dominate tensorflow tensorboard scipy opencv-python scikit-image ninja trimesh\n","!pip install mtcnn\n","# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu119\n","!pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu121"],"metadata":{"id":"e-UxPRRVMMOn"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# INSTALL NVIDIFFRAST\n","%cd /content/drive/MyDrive/eg3d/dataset_preprocessing/ffhq/Deep3DFaceRecon_pytorch/nvdiffrast\n","!pip install ."],"metadata":{"id":"FXzcMIuzL6d5"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["%cd /content/drive/MyDrive/eg3d/dataset_preprocessing/ffhq/Deep3DFaceRecon_pytorch"],"metadata":{"id":"0CN6yyMiTaKp"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["%cd /content/drive/MyDrive/eg3d/dataset_preprocessing/ffhq"],"metadata":{"id":"wWe3RVQlIdXZ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# INITIALIZE PATHS\n","# change to name of your dataset\n","dataset_name = \"all\"\n","dataset_path = 
\"/content/drive/MyDrive/eg3d/dataset_preprocessing/ffhq/Deep3DFaceRecon_pytorch/datasets/\"+dataset_name\n","print(dataset_path)"],"metadata":{"id":"rLhZt6WPPSCA"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# RUN MTCNN NEEDED FOR Deep3DFaceRecon\n","!python batch_mtcnn.py --in_root {dataset_path}"],"metadata":{"id":"mIPJmpTEOnDT"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# !pip install --upgrade torch torchvision"],"metadata":{"id":"u9ZedzPDFLjD"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["%cd /content/drive/MyDrive/eg3d/dataset_preprocessing/ffhq/Deep3DFaceRecon_pytorch\n","!python test.py --img_folder={dataset_path} --gpu_ids=0 --name=pretrained --epoch=20"],"metadata":{"id":"ErqRv8ALLzw5"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# CROP THE IMAGES\n","%cd /content/drive/MyDrive/eg3d/dataset_preprocessing/ffhq\n","!python crop_images_in_the_wild.py --indir={dataset_path}"],"metadata":{"id":"MFKLh05NYWDh"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# CONVERT POSE TO NEW FORMAT\n","%cd /content/drive/MyDrive/eg3d/dataset_preprocessing/ffhq/\n","!python 3dface2idr_mat.py --in_root /content/drive/MyDrive/eg3d/dataset_preprocessing/ffhq/Deep3DFaceRecon_pytorch/checkpoints/pretrained/results/{dataset_name}/epoch_20_000000 --out_path /content/drive/MyDrive/eg3d/dataset_preprocessing/ffhq/Deep3DFaceRecon_pytorch/datasets/{dataset_name}/crop/cameras.json"],"metadata":{"id":"tuhy0RLkYdj6"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# PREPROCESS FACE CAMERAS\n","!python preprocess_face_cameras.py --source /content/drive/MyDrive/eg3d/dataset_preprocessing/ffhq/Deep3DFaceRecon_pytorch/datasets/{dataset_name}/crop --dest /content/drive/MyDrive/eg3d/dataset_preprocessing/ffhq/Deep3DFaceRecon_pytorch/datasets/{dataset_name}/preprocessed_cameras --mode orig"],"metadata":{"id":"DfkMdXF3ei_x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# CONCATENATE CAMERA_POSE\n","cameras_path = dataset_path+ \"/crop/cameras.json\"\n","print(cameras_path)\n","!python concate_camera_poses.py --cameras_path={cameras_path} --save_to={dataset_path}"],"metadata":{"id":"92t2dPh9dKrO"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# MAKING SURE THE CAMERA_POSE FILES CORRESPONDS TO THE FINAL_CROPS FILES\n","import os\n","import cv2\n","import numpy as np\n","# SET YOUR DESIRED FINAL DATASET PATH\n","destination_folder = \"/content/drive/MyDrive/3dSwap/datasets/\"\n","cameras = dataset_path+\"/camera_pose/\"\n","crops = dataset_path+\"/crop/\"\n","final_crops = \"final_crops\"\n","camera_poses = \"camera_pose\"\n","\n","# CHECK IF CAMERA_POSE HAVE MATCHING FINAL_CROPS\n","for camera in os.listdir(cameras):\n"," new_camera = camera[:-4]\n"," for crop in os.listdir(crops):\n"," new_crop = crop[:-4]\n"," new_path = destination_folder + dataset_name + \"/\"\n"," if os.path.exists(new_path)==False:\n"," os.mkdir(new_path)\n"," if new_camera == new_crop:\n"," image = cv2.imread(crops+crop)\n"," camera_npy_file = np.load(cameras+camera)\n"," if os.path.exists(new_path+final_crops) or os.path.exists(new_path+camera_poses):\n"," print(\"Path --- \"+new_path+final_crops+\"/\"+crop)\n"," cv2.imwrite(new_path+final_crops+\"/\"+crop, image)\n"," np.save(new_path+camera_poses+\"/\"+camera, camera_npy_file)\n"," else:\n"," print(\"Path --- \"+new_path+final_crops+\"/\"+crop)\n"," os.mkdir(new_path+final_crops)\n"," 
os.makedirs(new_path+camera_poses)\n"," cv2.imwrite(new_path+final_crops+\"/\"+crop, image)\n"," np.save(new_path+camera_poses+\"/\"+camera, camera_npy_file)"],"metadata":{"id":"CQlC0a_8-q59"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# !pip install click\n","# %cd /content/drive/MyDrive/eg3d/eg3d\n","# !python dataset_tool.py --source {dataset_path} --dest {dataset_path}"],"metadata":{"id":"iKNmoeA4RqIn"},"execution_count":null,"outputs":[]}]}
--------------------------------------------------------------------------------
/images/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/images/teaser.png
--------------------------------------------------------------------------------
/models/__pycache__/discriminator.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/__pycache__/discriminator.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/faceswap_coach.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/__pycache__/faceswap_coach.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/id_loss.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/__pycache__/id_loss.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/inversion_coach.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/__pycache__/inversion_coach.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/networks.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/__pycache__/networks.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/psp.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/__pycache__/psp.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/psp.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/__pycache__/psp.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/w_norm.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/__pycache__/w_norm.cpython-38.pyc
--------------------------------------------------------------------------------
/models/discriminator.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import numpy as np
3 | import functools
4 |
5 |
6 | class NLayerDiscriminator(nn.Module):
7 | def __init__(self, input_nc, ndf=64, n_layers=6, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False):
8 | super(NLayerDiscriminator, self).__init__()
9 | self.getIntermFeat = getIntermFeat
10 | self.n_layers = n_layers
11 |
12 | kw = 4
13 | padw = int(np.ceil((kw-1.0)/2))
14 | sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]
15 |
16 | nf = ndf
17 | for n in range(1, n_layers):
18 | nf_prev = nf
19 | nf = min(nf * 2, 512)
20 | sequence += [[
21 | nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
22 | norm_layer(nf), nn.LeakyReLU(0.2, True)
23 | ]]
24 |
25 | nf_prev = nf
26 | nf = min(nf * 2, 512)
27 | sequence += [[
28 | nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
29 | norm_layer(nf),
30 | nn.LeakyReLU(0.2, True)
31 | ]]
32 |
33 | sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
34 |
35 | if use_sigmoid:
36 | sequence += [[nn.Sigmoid()]]
37 |
38 | if getIntermFeat:
39 | for n in range(len(sequence)):
40 | setattr(self, 'encoder'+str(n), nn.Sequential(*sequence[n]))
41 | else:
42 | sequence_stream = []
43 | for n in range(len(sequence)):
44 | sequence_stream += sequence[n]
45 | self.model = nn.Sequential(*sequence_stream)
46 |
47 | def forward(self, x):
48 | if self.getIntermFeat:
49 | res = [x]
50 | for n in range(self.n_layers+2):
51 | model = getattr(self, 'encoder'+str(n))
52 | res.append(model(res[-1]))
53 | return res[1:]
54 | else:
55 | return self.model(x)
56 |
57 |
58 | class MultiscaleDiscriminator(nn.Module):
59 |
60 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
61 | use_sigmoid=False, num_D=3, getIntermFeat=False):
62 | super(MultiscaleDiscriminator, self).__init__()
63 | self.num_D = num_D
64 | self.n_layers = n_layers
65 | self.getIntermFeat = getIntermFeat
66 |
67 | for i in range(num_D):
68 | netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat)
69 | if getIntermFeat:
70 | for j in range(n_layers + 2):
71 | setattr(self, 'scale' + str(i) + '_layer' + str(j), getattr(netD, 'encoder' + str(j)))
72 | else:
73 | setattr(self, 'layer' + str(i), netD.model)
74 |
75 | self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
76 |
77 | def singleD_forward(self, model, input):
78 | if self.getIntermFeat:
79 | result = [input]
80 | for i in range(len(model)):
81 | result.append(model[i](result[-1]))
82 | return result[1:]
83 | else:
84 | return [model(input)]
85 |
86 | def forward(self, input):
87 | num_D = self.num_D
88 | result = []
89 |
90 | input_downsampled = input
91 | for i in range(num_D):
92 | if self.getIntermFeat:
93 | model = [getattr(self, 'scale' + str(num_D - 1 - i) + '_layer' + str(j)) for j in
94 | range(self.n_layers + 2)]
95 | else:
96 | model = getattr(self, 'layer' + str(num_D - 1 - i))
97 | out = self.singleD_forward(model, input_downsampled)
98 | result.append(out)
99 |
100 | if i != (num_D - 1):
101 | input_downsampled = self.downsample(input_downsampled)
102 |
103 | return result
104 |
105 |
106 | def get_norm_layer(norm_type='instance'):
107 | if norm_type == 'batch':
108 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
109 | elif norm_type == 'instance':
110 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
111 | else:
112 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
113 | return norm_layer
114 |
115 |
116 |
--------------------------------------------------------------------------------
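A hedged usage sketch for models/discriminator.py above (not part of the repository; the 3-channel 256x256 input and the 'instance' norm choice are illustrative assumptions). get_norm_layer selects the normalization layer and MultiscaleDiscriminator returns one patch-level prediction per scale:

import torch
from models.discriminator import MultiscaleDiscriminator, get_norm_layer

# Build a 3-scale PatchGAN-style discriminator over RGB images.
norm_layer = get_norm_layer('instance')
netD = MultiscaleDiscriminator(input_nc=3, ndf=64, n_layers=3,
                               norm_layer=norm_layer, num_D=3)

fake = torch.randn(1, 3, 256, 256)     # dummy image batch
outputs = netD(fake)                   # list with num_D entries; input is average-pooled between scales
for scale, out in enumerate(outputs):
    print(scale, out[0].shape)         # patch-level logits at each scale

--------------------------------------------------------------------------------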
/models/encoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/encoders/__init__.py
--------------------------------------------------------------------------------
/models/encoders/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/encoders/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/models/encoders/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/encoders/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/models/encoders/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/encoders/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/models/encoders/__pycache__/encoder128.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/encoders/__pycache__/encoder128.cpython-37.pyc
--------------------------------------------------------------------------------
/models/encoders/__pycache__/helpers.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/encoders/__pycache__/helpers.cpython-37.pyc
--------------------------------------------------------------------------------
/models/encoders/__pycache__/helpers.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/encoders/__pycache__/helpers.cpython-38.pyc
--------------------------------------------------------------------------------
/models/encoders/__pycache__/helpers.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/encoders/__pycache__/helpers.cpython-39.pyc
--------------------------------------------------------------------------------
/models/encoders/__pycache__/model_irse.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/encoders/__pycache__/model_irse.cpython-37.pyc
--------------------------------------------------------------------------------
/models/encoders/__pycache__/model_irse.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/encoders/__pycache__/model_irse.cpython-38.pyc
--------------------------------------------------------------------------------
/models/encoders/__pycache__/model_irse.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/encoders/__pycache__/model_irse.cpython-39.pyc
--------------------------------------------------------------------------------
/models/encoders/__pycache__/psp_encoders.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/encoders/__pycache__/psp_encoders.cpython-37.pyc
--------------------------------------------------------------------------------
/models/encoders/__pycache__/psp_encoders.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/encoders/__pycache__/psp_encoders.cpython-38.pyc
--------------------------------------------------------------------------------
/models/encoders/__pycache__/psp_encoders.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/encoders/__pycache__/psp_encoders.cpython-39.pyc
--------------------------------------------------------------------------------
/models/encoders/helpers.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import torch
3 | from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
4 |
5 | """
6 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
7 | """
8 |
9 |
10 | class Flatten(Module):
11 | def forward(self, input):
12 | return input.view(input.size(0), -1)
13 |
14 |
15 | def l2_norm(input, axis=1):
16 | norm = torch.norm(input, 2, axis, True)
17 | output = torch.div(input, norm)
18 | return output
19 |
20 |
21 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
22 | """ A named tuple describing a ResNet block. """
23 |
24 |
25 | def get_block(in_channel, depth, num_units, stride=2):
26 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
27 |
28 |
29 | def get_blocks(num_layers):
30 | if num_layers == 50:
31 | blocks = [
32 | get_block(in_channel=64, depth=64, num_units=3),
33 | get_block(in_channel=64, depth=128, num_units=4),
34 | get_block(in_channel=128, depth=256, num_units=14),
35 | get_block(in_channel=256, depth=512, num_units=3)
36 | ]
37 | elif num_layers == 100:
38 | blocks = [
39 | get_block(in_channel=64, depth=64, num_units=3),
40 | get_block(in_channel=64, depth=128, num_units=13),
41 | get_block(in_channel=128, depth=256, num_units=30),
42 | get_block(in_channel=256, depth=512, num_units=3)
43 | ]
44 | elif num_layers == 152:
45 | blocks = [
46 | get_block(in_channel=64, depth=64, num_units=3),
47 | get_block(in_channel=64, depth=128, num_units=8),
48 | get_block(in_channel=128, depth=256, num_units=36),
49 | get_block(in_channel=256, depth=512, num_units=3)
50 | ]
51 | else:
52 | raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
53 | return blocks
54 |
55 |
56 | class SEModule(Module):
57 | def __init__(self, channels, reduction):
58 | super(SEModule, self).__init__()
59 | self.avg_pool = AdaptiveAvgPool2d(1)
60 | self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
61 | self.relu = ReLU(inplace=True)
62 | self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
63 | self.sigmoid = Sigmoid()
64 |
65 | def forward(self, x):
66 | module_input = x
67 | x = self.avg_pool(x)
68 | x = self.fc1(x)
69 | x = self.relu(x)
70 | x = self.fc2(x)
71 | x = self.sigmoid(x)
72 | return module_input * x
73 |
74 |
75 | class bottleneck_IR(Module):
76 | def __init__(self, in_channel, depth, stride):
77 | super(bottleneck_IR, self).__init__()
78 | if in_channel == depth:
79 | self.shortcut_layer = MaxPool2d(1, stride)
80 | else:
81 | self.shortcut_layer = Sequential(
82 | Conv2d(in_channel, depth, (1, 1), stride, bias=False),
83 | BatchNorm2d(depth)
84 | )
85 | self.res_layer = Sequential(
86 | BatchNorm2d(in_channel),
87 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
88 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
89 | )
90 |
91 | def forward(self, x):
92 | shortcut = self.shortcut_layer(x)
93 | res = self.res_layer(x)
94 | return res + shortcut
95 |
96 |
97 | class bottleneck_IR_SE(Module):
98 | def __init__(self, in_channel, depth, stride):
99 | super(bottleneck_IR_SE, self).__init__()
100 | if in_channel == depth:
101 | self.shortcut_layer = MaxPool2d(1, stride)
102 | else:
103 | self.shortcut_layer = Sequential(
104 | Conv2d(in_channel, depth, (1, 1), stride, bias=False),
105 | BatchNorm2d(depth)
106 | )
107 | self.res_layer = Sequential(
108 | BatchNorm2d(in_channel),
109 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
110 | PReLU(depth),
111 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
112 | BatchNorm2d(depth),
113 | SEModule(depth, 16)
114 | )
115 |
116 | def forward(self, x):
117 | shortcut = self.shortcut_layer(x)
118 | res = self.res_layer(x)
119 | return res + shortcut
120 |
--------------------------------------------------------------------------------
/models/encoders/model_irse.py:
--------------------------------------------------------------------------------
1 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
2 | from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
3 |
4 | """
5 | Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
6 | """
7 |
8 |
9 | class Backbone(Module):
10 | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
11 | super(Backbone, self).__init__()
12 | assert input_size in [112, 224], "input_size should be 112 or 224"
13 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
14 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
15 | blocks = get_blocks(num_layers)
16 | if mode == 'ir':
17 | unit_module = bottleneck_IR
18 | elif mode == 'ir_se':
19 | unit_module = bottleneck_IR_SE
20 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
21 | BatchNorm2d(64),
22 | PReLU(64))
23 | if input_size == 112:
24 | self.output_layer = Sequential(BatchNorm2d(512),
25 | Dropout(drop_ratio),
26 | Flatten(),
27 | Linear(512 * 7 * 7, 512),
28 | BatchNorm1d(512, affine=affine))
29 | else:
30 | self.output_layer = Sequential(BatchNorm2d(512),
31 | Dropout(drop_ratio),
32 | Flatten(),
33 | Linear(512 * 14 * 14, 512),
34 | BatchNorm1d(512, affine=affine))
35 |
36 | modules = []
37 | for block in blocks:
38 | for bottleneck in block:
39 | modules.append(unit_module(bottleneck.in_channel,
40 | bottleneck.depth,
41 | bottleneck.stride))
42 | self.body = Sequential(*modules)
43 |
44 | def forward(self, x):
45 | x = self.input_layer(x)
46 | x = self.body(x)
47 | x = self.output_layer(x)
48 | return l2_norm(x)
49 |
50 |
51 | def IR_50(input_size):
52 | """Constructs a ir-50 model."""
53 | model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
54 | return model
55 |
56 |
57 | def IR_101(input_size):
58 | """Constructs a ir-101 model."""
59 | model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
60 | return model
61 |
62 |
63 | def IR_152(input_size):
64 | """Constructs a ir-152 model."""
65 | model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
66 | return model
67 |
68 |
69 | def IR_SE_50(input_size):
70 | """Constructs a ir_se-50 model."""
71 | model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
72 | return model
73 |
74 |
75 | def IR_SE_101(input_size):
76 | """Constructs a ir_se-101 model."""
77 | model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
78 | return model
79 |
80 |
81 | def IR_SE_152(input_size):
82 | """Constructs a ir_se-152 model."""
83 | model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
84 | return model
85 |
--------------------------------------------------------------------------------
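A brief sketch of how the backbones in models/encoders/model_irse.py are typically used: IR_SE_50 on aligned 112x112 face crops yields a 512-d, L2-normalized identity embedding. The random input below is only for illustration; in practice pretrained ArcFace weights (e.g. models/model_ir_se50.pth, as used by id_loss.py) would be loaded first:

import torch
from models.encoders.model_irse import IR_SE_50

net = IR_SE_50(input_size=112).eval()   # untrained here; load pretrained weights in practice
with torch.no_grad():
    faces = torch.randn(2, 3, 112, 112) # batch of aligned face crops
    emb = net(faces)                    # (2, 512), L2-normalized by l2_norm()
print(emb.shape, emb.norm(dim=1))       # norms are ~1.0

--------------------------------------------------------------------------------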
/models/id_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from models.encoders.model_irse import Backbone
4 | import sys
5 |
6 | sys.path.append(".")
7 | sys.path.append("..")
8 |
9 |
10 | class IDLoss(nn.Module):
11 | def __init__(self):
12 | super(IDLoss, self).__init__()
13 | self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.4, mode='ir_se')
14 | self.facenet.load_state_dict(torch.load('models/model_ir_se50.pth'))
15 | self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
16 |
17 | def extract_feats(self, x):
18 | factor = int(x.shape[-1] / 256)
19 | x = x[:, :, 35 * factor:223 * factor, 32 * factor:220 * factor]
20 |
21 | x = self.face_pool(x)
22 | x_feats = self.facenet(x)
23 | return x_feats
24 |
25 | def forward(self, x, y_hat):
26 | n_samples = x.shape[0]
27 | x_feats = self.extract_feats(x)
28 | y_hat_feats = self.extract_feats(y_hat)
29 | x_feats = x_feats.detach()
30 | loss = 0
31 | count = 0
32 | for i in range(n_samples):
33 | diff_target = y_hat_feats[i].dot(x_feats[i])
34 | loss += 1 - diff_target
35 | count += 1
36 |
37 | return loss / count
38 |
--------------------------------------------------------------------------------
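IDLoss in models/id_loss.py compares ArcFace embeddings of two images and returns the mean (1 - cosine similarity) over the batch; extract_feats crops the face region with a factor of x.shape[-1] / 256, so inputs are expected at a multiple of 256 pixels. A hedged sketch (it assumes the pretrained weights hard-coded at models/model_ir_se50.pth are present):

import torch
from models.id_loss import IDLoss

id_loss = IDLoss().eval()            # loads models/model_ir_se50.pth in __init__
x     = torch.rand(2, 3, 512, 512)   # source images (identity to preserve)
y_hat = torch.rand(2, 3, 512, 512)   # generated / swapped images
loss = id_loss(x, y_hat)             # scalar: mean over batch of 1 - <f(x), f(y_hat)>

--------------------------------------------------------------------------------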
/models/networks.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from .stylegan2.model import EqualLinear
3 | import torch
4 | from .discriminator import MultiscaleDiscriminator
5 |
6 |
7 | def define_mlp(layers_num):
8 | layers = [EqualLinear(1024, 512)]
9 | for _ in range(layers_num - 1):
10 | layers.append(EqualLinear(512, 512))
11 | mlp = nn.Sequential(*layers)
12 | return mlp.cuda()
13 |
14 | def define_D(input_nc, n_layers=3, norm_layer=torch.nn.BatchNorm2d):
15 | netD = MultiscaleDiscriminator(input_nc, n_layers=n_layers, norm_layer=norm_layer, use_sigmoid=False)
16 | netD.cuda()
17 |
18 | def weights_init(m):
19 | classname = m.__class__.__name__
20 | if classname.find('Conv2d') != -1:
21 | m.weight.data.normal_(0.0, 0.02)
22 | elif classname.find('BatchNorm2d') != -1:
23 | m.weight.data.normal_(1.0, 0.02)
24 | m.bias.data.fill_(0)
25 |
26 | netD.apply(weights_init)
27 | return netD
--------------------------------------------------------------------------------
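Both helpers in models/networks.py move their networks to CUDA, and importing the module pulls in the stylegan2 ops, which are JIT-compiled on first import. define_mlp builds the small MLPs that pSp.FaceSwap uses to predict blending coefficients from a pair of style codes; a hedged shape sketch (a CUDA device is assumed):

import torch
from models.networks import define_mlp

mlp = define_mlp(layers_num=4)                  # EqualLinear 1024->512 followed by 3x 512->512
x_code = torch.randn(1, 1, 512, device='cuda')  # one W+ layer from the source face
y_code = torch.randn(1, 1, 512, device='cuda')  # the matching layer from the target face
rho = mlp(torch.cat([x_code, y_code], dim=2))   # (1, 1, 512) blending logits, see pSp.FaceSwap

--------------------------------------------------------------------------------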
/models/psp.py:
--------------------------------------------------------------------------------
1 | """
2 | This file defines the core research contribution
3 | """
4 | import matplotlib
5 | import numpy as np
6 |
7 | matplotlib.use('Agg')
8 | import math
9 | import cv2
10 |
11 | import torch
12 | from torch import nn
13 | from models.encoders import psp_encoders
14 | from .networks import define_D, define_mlp
15 | import torch.nn.functional as F
16 |
17 | import dnnlib
18 | from utils import legacy
19 |
20 | from training.triplane import TriPlaneGenerator
21 | from torch_utils import misc
22 | from tqdm import tqdm
23 |
24 | from lpips import LPIPS
25 | from models.id_loss import IDLoss
26 | from models.w_norm import WNormLoss
27 |
28 |
29 | def get_keys(d, name):
30 | if 'state_dict' in d:
31 | d = d['state_dict']
32 | d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
33 | return d_filt
34 |
35 |
36 | class pSp(nn.Module):
37 |
38 | def __init__(self, opts, train_faceswap=True):
39 | super(pSp, self).__init__()
40 | self.opts = opts
41 |
42 | self.encoder = self.set_encoder().to(self.opts.device)
43 |
44 | if train_faceswap:
45 | encoder_ckpt = torch.load('checkpoints/encoder.pt')
46 | self.encoder.load_state_dict(get_keys(encoder_ckpt, 'encoder'), strict=True)
47 |
48 | for i in range(5):
49 | mlp = define_mlp(4)
50 | setattr(self, f'MLP{i}', mlp.train())
51 |
52 | with dnnlib.util.open_url('checkpoints/ffhq512-128.pkl') as f:
53 | G = legacy.load_network_pkl(f)['G_ema'].to(self.opts.device)
54 | G_new = TriPlaneGenerator(*G.init_args, **G.init_kwargs).eval().requires_grad_(False).to(self.opts.device)
55 | misc.copy_params_and_buffers(G, G_new, require_all=True)
56 | G_new.neural_rendering_resolution = G.neural_rendering_resolution
57 | G_new.rendering_kwargs = G.rendering_kwargs
58 | self.decoder = G_new
59 |
60 | self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
61 |
62 | def detector(self, frame):
63 | rects = self.align.getAllFaceBoundingBoxes(frame)
64 | landmarks = {}
65 | if len(rects) > 0:
66 | bb = self.align.findLandmarks(frame, rects[0])
67 | for i in range(68):
68 | landmarks[i] = bb[i]
69 | return landmarks
70 |
71 | def set_encoder(self):
72 | if self.opts.encoder_type == 'GradualStyleEncoder':
73 | encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se')
74 | elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoW':
75 | encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoW(50, 'ir_se', self.opts)
76 | elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoWPlus':
77 | encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoWPlus(50, 'ir_se', self.opts)
78 | else:
79 | raise Exception('{} is not a valid encoder type'.format(self.opts.encoder_type))
80 | return encoder
81 |
82 | def Inversion(self, x, cp, rand_cp, w_avg):
83 | cp = cp.squeeze(dim=1)
84 | rand_cp = rand_cp.squeeze(dim=1).repeat(x.shape[0], 1)
85 | x = F.interpolate(x, size=[256, 256], mode='bilinear', align_corners=True)
86 | wx = self.encoder(x) + w_avg
87 | x_prime = self.decoder.synthesis(wx, cp)['image']
88 | x_hat = self.decoder.synthesis(wx, rand_cp)['image']
89 | wx_hat = self.encoder(F.interpolate(x_hat, size=[256, 256], mode='bilinear', align_corners=True)) + w_avg
90 | x_hat_prime = self.decoder.synthesis(wx_hat, cp)['image']
91 |
92 | return x_prime, x_hat, x_hat_prime, wx, wx_hat
93 |
94 | def ger_average_color(self, mask, arms):
95 | color = torch.zeros(arms.shape).cuda()
96 | mask = mask.repeat([arms.shape[0], 1, 1, 1])
97 | for i in range(arms.shape[0]):
98 | count = len(torch.nonzero(mask[i, :, :, :]))
99 | if count < 10:
100 | color[i, 0, :, :] = 0
101 | color[i, 1, :, :] = 0
102 | color[i, 2, :, :] = 0
103 |
104 | else:
105 | color[i, 0, :, :] = arms[i, 0, :, :].sum() / count
106 | color[i, 1, :, :] = arms[i, 1, :, :].sum() / count
107 | color[i, 2, :, :] = arms[i, 2, :, :].sum() / count
108 | return color
109 |
110 | def my_acti(self, w):
111 | return 1 / (1 + torch.exp(-100 * (w - 0.5)))
112 |
113 | def FaceSwap(self, x, y, x_cp, y_cp, w_avg):
114 | with torch.no_grad():
115 | x_cp = x_cp.squeeze(1)
116 | y_cp = y_cp.squeeze(1)
117 | x = F.interpolate(x, size=[256, 256], mode='bilinear', align_corners=True)
118 | y = F.interpolate(y, size=[256, 256], mode='bilinear', align_corners=True)
119 | x_ws = self.encoder(x) + w_avg
120 | y_ws = self.encoder(y) + w_avg
121 | x_rec = self.decoder.synthesis(x_ws, x_cp)['image']
122 | y_rec = self.decoder.synthesis(y_ws, y_cp)['image']
123 |
124 | x_codes, y_codes = [], []
125 |
126 | start_index = 5
127 | index_length = 5
128 |
129 | for i in range(start_index, start_index + index_length):
130 | x_codes.append(x_ws[:, i: i + 1])
131 | y_codes.append(y_ws[:, i: i + 1])
132 |
133 | yhat_codes = []
134 | yhat_codes.append(y_ws[:, :start_index])
135 | for i in range(start_index, start_index + index_length):
136 | i = i - start_index
137 | MLP = getattr(self, f'MLP{i}')
138 | rho = MLP(torch.cat([x_codes[i], y_codes[i]], dim=2))
139 | rho = (rho - rho.min()) / (rho.max() - rho.min())
140 | rho = self.my_acti(rho)
141 | yhat_codes.append(y_codes[i] * rho + x_codes[i] * (1 - rho))
142 | yhat_codes.append(y_ws[:, start_index + index_length:])
143 |
144 | ws = torch.cat(yhat_codes, dim=1)
145 | y_hat = self.decoder.synthesis(ws, y_cp)['image']
146 | y_rand = self.decoder.synthesis(ws, x_cp)['image']
147 |
148 | return x_rec, y_rec, y_hat, y_rand
149 |
150 | def set_opts(self, opts):
151 | self.opts = opts
152 |
153 | def __load_latent_avg(self, ckpt, repeat=None):
154 | if 'latent_avg' in ckpt:
155 | self.latent_avg_2d = ckpt['latent_avg'].to(self.opts.device)
156 | if repeat is not None:
157 | self.latent_avg_2d = self.latent_avg_2d.repeat(repeat, 1)
158 | else:
159 | self.latent_avg = None
160 |
--------------------------------------------------------------------------------
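In pSp.FaceSwap above, the MLP output rho is min-max normalized and then pushed through my_acti, a steep sigmoid centered at 0.5, so each channel of the blended W+ layers comes almost entirely from either the target (y) or the source (x). A small self-contained sketch of just that gate (values are illustrative):

import torch

def my_acti(w, k=100.0):                           # same steep sigmoid as pSp.my_acti (k is hard-coded to 100)
    return 1.0 / (1.0 + torch.exp(-k * (w - 0.5)))

raw = torch.tensor([0.10, 0.45, 0.55, 0.90])       # pretend MLP outputs for 4 style channels
rho = (raw - raw.min()) / (raw.max() - raw.min())  # min-max normalise to [0, 1]
gate = my_acti(rho)                                # close to 0 or 1 for almost every channel
y_code, x_code = torch.ones(4), torch.zeros(4)
blended = y_code * gate + x_code * (1 - gate)      # per-channel pick between target and source
print(gate, blended)

--------------------------------------------------------------------------------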
/models/stylegan2/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/stylegan2/__init__.py
--------------------------------------------------------------------------------
/models/stylegan2/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/stylegan2/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/models/stylegan2/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/stylegan2/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/models/stylegan2/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/stylegan2/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/models/stylegan2/__pycache__/model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/stylegan2/__pycache__/model.cpython-37.pyc
--------------------------------------------------------------------------------
/models/stylegan2/__pycache__/model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/stylegan2/__pycache__/model.cpython-38.pyc
--------------------------------------------------------------------------------
/models/stylegan2/__pycache__/model.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/stylegan2/__pycache__/model.cpython-39.pyc
--------------------------------------------------------------------------------
/models/stylegan2/op/__init__.py:
--------------------------------------------------------------------------------
1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu
2 | from .upfirdn2d import upfirdn2d
3 |
--------------------------------------------------------------------------------
/models/stylegan2/op/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/stylegan2/op/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/models/stylegan2/op/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/stylegan2/op/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/models/stylegan2/op/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/stylegan2/op/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/models/stylegan2/op/__pycache__/fused_act.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/stylegan2/op/__pycache__/fused_act.cpython-37.pyc
--------------------------------------------------------------------------------
/models/stylegan2/op/__pycache__/fused_act.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/stylegan2/op/__pycache__/fused_act.cpython-38.pyc
--------------------------------------------------------------------------------
/models/stylegan2/op/__pycache__/fused_act.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/stylegan2/op/__pycache__/fused_act.cpython-39.pyc
--------------------------------------------------------------------------------
/models/stylegan2/op/__pycache__/upfirdn2d.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/stylegan2/op/__pycache__/upfirdn2d.cpython-37.pyc
--------------------------------------------------------------------------------
/models/stylegan2/op/__pycache__/upfirdn2d.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/stylegan2/op/__pycache__/upfirdn2d.cpython-38.pyc
--------------------------------------------------------------------------------
/models/stylegan2/op/__pycache__/upfirdn2d.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/models/stylegan2/op/__pycache__/upfirdn2d.cpython-39.pyc
--------------------------------------------------------------------------------
/models/stylegan2/op/fused_act.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torch import nn
5 | from torch.autograd import Function
6 | from torch.utils.cpp_extension import load
7 |
8 | module_path = os.path.dirname(__file__)
9 | fused = load(
10 | 'fused',
11 | sources=[
12 | os.path.join(module_path, 'fused_bias_act.cpp'),
13 | os.path.join(module_path, 'fused_bias_act_kernel.cu'),
14 | ],
15 | )
16 |
17 |
18 | class FusedLeakyReLUFunctionBackward(Function):
19 | @staticmethod
20 | def forward(ctx, grad_output, out, negative_slope, scale):
21 | ctx.save_for_backward(out)
22 | ctx.negative_slope = negative_slope
23 | ctx.scale = scale
24 |
25 | empty = grad_output.new_empty(0)
26 |
27 | grad_input = fused.fused_bias_act(
28 | grad_output, empty, out, 3, 1, negative_slope, scale
29 | )
30 |
31 | dim = [0]
32 |
33 | if grad_input.ndim > 2:
34 | dim += list(range(2, grad_input.ndim))
35 |
36 | grad_bias = grad_input.sum(dim).detach()
37 |
38 | return grad_input, grad_bias
39 |
40 | @staticmethod
41 | def backward(ctx, gradgrad_input, gradgrad_bias):
42 | out, = ctx.saved_tensors
43 | gradgrad_out = fused.fused_bias_act(
44 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
45 | )
46 |
47 | return gradgrad_out, None, None, None
48 |
49 |
50 | class FusedLeakyReLUFunction(Function):
51 | @staticmethod
52 | def forward(ctx, input, bias, negative_slope, scale):
53 | empty = input.new_empty(0)
54 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
55 | ctx.save_for_backward(out)
56 | ctx.negative_slope = negative_slope
57 | ctx.scale = scale
58 |
59 | return out
60 |
61 | @staticmethod
62 | def backward(ctx, grad_output):
63 | out, = ctx.saved_tensors
64 |
65 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
66 | grad_output, out, ctx.negative_slope, ctx.scale
67 | )
68 |
69 | return grad_input, grad_bias, None, None
70 |
71 |
72 | class FusedLeakyReLU(nn.Module):
73 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
74 | super().__init__()
75 |
76 | self.bias = nn.Parameter(torch.zeros(channel))
77 | self.negative_slope = negative_slope
78 | self.scale = scale
79 |
80 | def forward(self, input):
81 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
82 |
83 |
84 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
85 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
86 |
--------------------------------------------------------------------------------
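Note that the load(...) call in fused_act.py JIT-compiles the CUDA extension the first time the module is imported, so a CUDA toolkit plus ninja are required even just to import it. A hedged usage sketch of the two entry points:

import torch
from models.stylegan2.op.fused_act import FusedLeakyReLU, fused_leaky_relu

act = FusedLeakyReLU(channel=64).cuda()        # learnable per-channel bias + leaky ReLU, scaled by sqrt(2)
x = torch.randn(1, 64, 32, 32, device='cuda')
y1 = act(x)                                    # module form
y2 = fused_leaky_relu(x, act.bias)             # functional form; same result with the default slope/scale

--------------------------------------------------------------------------------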
/models/stylegan2/op/fused_bias_act.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 |
4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
5 | int act, int grad, float alpha, float scale);
6 |
7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
10 |
11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
12 | int act, int grad, float alpha, float scale) {
13 | CHECK_CUDA(input);
14 | CHECK_CUDA(bias);
15 |
16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
17 | }
18 |
19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
21 | }
--------------------------------------------------------------------------------
/models/stylegan2/op/fused_bias_act_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 |
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAContext.h>
12 | #include <ATen/cuda/CUDAApplyUtils.cuh>
13 |
14 | #include <cuda.h>
15 | #include <cuda_runtime.h>
16 |
17 |
18 | template <typename scalar_t>
19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
22 |
23 | scalar_t zero = 0.0;
24 |
25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
26 | scalar_t x = p_x[xi];
27 |
28 | if (use_bias) {
29 | x += p_b[(xi / step_b) % size_b];
30 | }
31 |
32 | scalar_t ref = use_ref ? p_ref[xi] : zero;
33 |
34 | scalar_t y;
35 |
36 | switch (act * 10 + grad) {
37 | default:
38 | case 10: y = x; break;
39 | case 11: y = x; break;
40 | case 12: y = 0.0; break;
41 |
42 | case 30: y = (x > 0.0) ? x : x * alpha; break;
43 | case 31: y = (ref > 0.0) ? x : x * alpha; break;
44 | case 32: y = 0.0; break;
45 | }
46 |
47 | out[xi] = y * scale;
48 | }
49 | }
50 |
51 |
52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
53 | int act, int grad, float alpha, float scale) {
54 | int curDevice = -1;
55 | cudaGetDevice(&curDevice);
56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
57 |
58 | auto x = input.contiguous();
59 | auto b = bias.contiguous();
60 | auto ref = refer.contiguous();
61 |
62 | int use_bias = b.numel() ? 1 : 0;
63 | int use_ref = ref.numel() ? 1 : 0;
64 |
65 | int size_x = x.numel();
66 | int size_b = b.numel();
67 | int step_b = 1;
68 |
69 | for (int i = 1 + 1; i < x.dim(); i++) {
70 | step_b *= x.size(i);
71 | }
72 |
73 | int loop_x = 4;
74 | int block_size = 4 * 32;
75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
76 |
77 | auto y = torch::empty_like(x);
78 |
79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
80 | fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
81 | y.data_ptr<scalar_t>(),
82 | x.data_ptr<scalar_t>(),
83 | b.data_ptr<scalar_t>(),
84 | ref.data_ptr<scalar_t>(),
85 | act,
86 | grad,
87 | alpha,
88 | scale,
89 | loop_x,
90 | size_x,
91 | step_b,
92 | size_b,
93 | use_bias,
94 | use_ref
95 | );
96 | });
97 |
98 | return y;
99 | }
--------------------------------------------------------------------------------
/models/stylegan2/op/upfirdn2d.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 |
4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
5 | int up_x, int up_y, int down_x, int down_y,
6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1);
7 |
8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
11 |
12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
13 | int up_x, int up_y, int down_x, int down_y,
14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
15 | CHECK_CUDA(input);
16 | CHECK_CUDA(kernel);
17 |
18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
19 | }
20 |
21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
23 | }
--------------------------------------------------------------------------------
/models/stylegan2/op/upfirdn2d.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch.nn import functional as F  # needed by upfirdn2d_native below
3 | import torch
4 | from torch.autograd import Function
5 | from torch.utils.cpp_extension import load
6 |
7 | module_path = os.path.dirname(__file__)
8 | upfirdn2d_op = load(
9 | 'upfirdn2d',
10 | sources=[
11 | os.path.join(module_path, 'upfirdn2d.cpp'),
12 | os.path.join(module_path, 'upfirdn2d_kernel.cu'),
13 | ],
14 | )
15 |
16 |
17 | class UpFirDn2dBackward(Function):
18 | @staticmethod
19 | def forward(
20 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
21 | ):
22 | up_x, up_y = up
23 | down_x, down_y = down
24 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
25 |
26 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
27 |
28 | grad_input = upfirdn2d_op.upfirdn2d(
29 | grad_output,
30 | grad_kernel,
31 | down_x,
32 | down_y,
33 | up_x,
34 | up_y,
35 | g_pad_x0,
36 | g_pad_x1,
37 | g_pad_y0,
38 | g_pad_y1,
39 | )
40 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
41 |
42 | ctx.save_for_backward(kernel)
43 |
44 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
45 |
46 | ctx.up_x = up_x
47 | ctx.up_y = up_y
48 | ctx.down_x = down_x
49 | ctx.down_y = down_y
50 | ctx.pad_x0 = pad_x0
51 | ctx.pad_x1 = pad_x1
52 | ctx.pad_y0 = pad_y0
53 | ctx.pad_y1 = pad_y1
54 | ctx.in_size = in_size
55 | ctx.out_size = out_size
56 |
57 | return grad_input
58 |
59 | @staticmethod
60 | def backward(ctx, gradgrad_input):
61 | kernel, = ctx.saved_tensors
62 |
63 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
64 |
65 | gradgrad_out = upfirdn2d_op.upfirdn2d(
66 | gradgrad_input,
67 | kernel,
68 | ctx.up_x,
69 | ctx.up_y,
70 | ctx.down_x,
71 | ctx.down_y,
72 | ctx.pad_x0,
73 | ctx.pad_x1,
74 | ctx.pad_y0,
75 | ctx.pad_y1,
76 | )
77 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
78 | gradgrad_out = gradgrad_out.view(
79 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
80 | )
81 |
82 | return gradgrad_out, None, None, None, None, None, None, None, None
83 |
84 |
85 | class UpFirDn2d(Function):
86 | @staticmethod
87 | def forward(ctx, input, kernel, up, down, pad):
88 | up_x, up_y = up
89 | down_x, down_y = down
90 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
91 |
92 | kernel_h, kernel_w = kernel.shape
93 | batch, channel, in_h, in_w = input.shape
94 | ctx.in_size = input.shape
95 |
96 | input = input.reshape(-1, in_h, in_w, 1)
97 |
98 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
99 |
100 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
101 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
102 | ctx.out_size = (out_h, out_w)
103 |
104 | ctx.up = (up_x, up_y)
105 | ctx.down = (down_x, down_y)
106 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
107 |
108 | g_pad_x0 = kernel_w - pad_x0 - 1
109 | g_pad_y0 = kernel_h - pad_y0 - 1
110 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
111 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
112 |
113 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
114 |
115 | out = upfirdn2d_op.upfirdn2d(
116 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
117 | )
118 | # out = out.view(major, out_h, out_w, minor)
119 | out = out.view(-1, channel, out_h, out_w)
120 |
121 | return out
122 |
123 | @staticmethod
124 | def backward(ctx, grad_output):
125 | kernel, grad_kernel = ctx.saved_tensors
126 |
127 | grad_input = UpFirDn2dBackward.apply(
128 | grad_output,
129 | kernel,
130 | grad_kernel,
131 | ctx.up,
132 | ctx.down,
133 | ctx.pad,
134 | ctx.g_pad,
135 | ctx.in_size,
136 | ctx.out_size,
137 | )
138 |
139 | return grad_input, None, None, None, None
140 |
141 |
142 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
143 | out = UpFirDn2d.apply(
144 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
145 | )
146 |
147 | return out
148 |
149 |
150 | def upfirdn2d_native(
151 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
152 | ):
153 | _, in_h, in_w, minor = input.shape
154 | kernel_h, kernel_w = kernel.shape
155 |
156 | out = input.view(-1, in_h, 1, in_w, 1, minor)
157 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
158 | out = out.view(-1, in_h * up_y, in_w * up_x, minor)
159 |
160 | out = F.pad(
161 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
162 | )
163 | out = out[
164 | :,
165 | max(-pad_y0, 0): out.shape[1] - max(-pad_y1, 0),
166 | max(-pad_x0, 0): out.shape[2] - max(-pad_x1, 0),
167 | :,
168 | ]
169 |
170 | out = out.permute(0, 3, 1, 2)
171 | out = out.reshape(
172 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
173 | )
174 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
175 | out = F.conv2d(out, w)
176 | out = out.reshape(
177 | -1,
178 | minor,
179 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
180 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
181 | )
182 | out = out.permute(0, 2, 3, 1)
183 |
184 | return out[:, ::down_y, ::down_x, :]
185 |
--------------------------------------------------------------------------------
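upfirdn2d (upsample, FIR filter, downsample) is the primitive StyleGAN2 uses for its blur, upsample and downsample layers. A hedged sketch of a same-resolution blur with the usual [1, 3, 3, 1] separable kernel (CUDA only, since the op is the JIT-compiled extension loaded above):

import torch
from models.stylegan2.op.upfirdn2d import upfirdn2d

k = torch.tensor([1., 3., 3., 1.], device='cuda')
kernel = k[None, :] * k[:, None]                # 4x4 binomial blur kernel
kernel = kernel / kernel.sum()                  # normalise so overall brightness is preserved

x = torch.randn(1, 3, 64, 64, device='cuda')
pad = (2, 1)                                    # split (kernel_size - 1) as (pad0, pad1) to keep 64x64
y = upfirdn2d(x, kernel, up=1, down=1, pad=pad)
print(y.shape)                                  # torch.Size([1, 3, 64, 64])

--------------------------------------------------------------------------------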
/models/w_norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class WNormLoss(nn.Module):
6 |
7 | def __init__(self, start_from_latent_avg=True):
8 | super(WNormLoss, self).__init__()
9 | self.start_from_latent_avg = start_from_latent_avg
10 | self.norm = nn.BatchNorm2d(1)
11 |
12 | def forward(self, latent, latent_avg=None):
13 | if self.start_from_latent_avg:
14 | latent = latent - latent_avg
15 | return torch.sum(latent.norm(2, dim=(1, 2))) / latent.shape[0]
16 |
--------------------------------------------------------------------------------
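WNormLoss penalizes how far the predicted latents drift from the generator's average latent (when start_from_latent_avg is True). A hedged sketch; the 14 x 512 W+ shape is an assumption for a 512-resolution generator:

import torch
from models.w_norm import WNormLoss

w_norm = WNormLoss(start_from_latent_avg=True)
latent     = torch.randn(4, 14, 512)   # batch of predicted W+ codes
latent_avg = torch.zeros(14, 512)      # average latent of the generator
loss = w_norm(latent, latent_avg)      # mean L2 norm of (latent - latent_avg) per sample

--------------------------------------------------------------------------------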
/options/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/options/__init__.py
--------------------------------------------------------------------------------
/options/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/options/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/options/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/options/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/options/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/options/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/options/__pycache__/train_options.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/options/__pycache__/train_options.cpython-37.pyc
--------------------------------------------------------------------------------
/options/__pycache__/train_options.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/options/__pycache__/train_options.cpython-38.pyc
--------------------------------------------------------------------------------
/options/__pycache__/train_options.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/options/__pycache__/train_options.cpython-39.pyc
--------------------------------------------------------------------------------
/options/test_options.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 |
4 | class TestOptions:
5 |
6 | def __init__(self):
7 | self.parser = ArgumentParser()
8 | self.initialize()
9 |
10 | def initialize(self):
11 | # arguments for inference script
12 | self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory')
13 | self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to pSp model checkpoint')
14 | self.parser.add_argument('--data_path', type=str, default='gt_images', help='Path to directory of images to evaluate')
15 | self.parser.add_argument('--couple_outputs', action='store_true', help='Whether to also save inputs + outputs side-by-side')
16 | self.parser.add_argument('--resize_outputs', action='store_true', help='Whether to resize outputs to 256x256 or keep at 1024x1024')
17 |
18 | self.parser.add_argument('--test_batch_size', default=2, type=int, help='Batch size for testing and inference')
19 | self.parser.add_argument('--test_workers', default=2, type=int, help='Number of test/inference dataloader workers')
20 |
21 | # arguments for style-mixing script
22 | self.parser.add_argument('--n_images', type=int, default=None, help='Number of images to output. If None, run on all data')
23 | self.parser.add_argument('--n_outputs_to_generate', type=int, default=5, help='Number of outputs to generate per input image.')
24 | self.parser.add_argument('--mix_alpha', type=float, default=None, help='Alpha value for style-mixing')
25 | self.parser.add_argument('--latent_mask', type=str, default=None, help='Comma-separated list of latents to perform style-mixing with')
26 |
27 | # arguments for super-resolution
28 | self.parser.add_argument('--resize_factors', type=str, default=None,
29 | help='Downsampling factor for super-res (should be a single value for inference).')
30 |
31 | def parse(self):
32 | opts = self.parser.parse_args()
33 | return opts
--------------------------------------------------------------------------------
/options/train_options.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 |
4 | class TrainOptions:
5 |
6 | def __init__(self):
7 | self.parser = ArgumentParser()
8 | self.initialize()
9 |
10 | def initialize(self):
11 | self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory')
12 | self.parser.add_argument('--dataset_type', default='ffhq_encode', type=str, help='Type of dataset/experiment to run')
13 | self.parser.add_argument('--encoder_type', default='GradualStyleEncoder', type=str, help='Which encoder to use')
14 | self.parser.add_argument('--input_nc', default=3, type=int, help='Number of input image channels to the psp encoder')
15 | self.parser.add_argument('--label_nc', default=0, type=int, help='Number of input label channels to the psp encoder')
16 | self.parser.add_argument('--output_size', default=1024, type=int, help='Output size of generator')
17 |
18 | self.parser.add_argument('--batch_size', default=1, type=int, help='Batch size for training')
19 | self.parser.add_argument('--test_batch_size', default=2, type=int, help='Batch size for testing and inference')
20 | self.parser.add_argument('--workers', default=0, type=int, help='Number of train dataloader workers')
21 | self.parser.add_argument('--test_workers', default=2, type=int, help='Number of test/inference dataloader workers')
22 |
23 | self.parser.add_argument('--learning_rate', default=0.000025, type=float, help='Optimizer learning rate')
24 | self.parser.add_argument('--optim_name', default='ranger', type=str, help='Which optimizer to use')
25 | self.parser.add_argument('--train_decoder', default=False, type=bool, help='Whether to train the decoder model')
26 | self.parser.add_argument('--start_from_latent_avg', default=True, type=bool, help='Whether to add average latent vector to generate codes from encoder.')
27 | self.parser.add_argument('--learn_in_w', action='store_true', help='Whether to learn in w space instead of w+')
28 |
29 | self.parser.add_argument('--lpips_lambda', default=0.8, type=float, help='LPIPS loss multiplier factor')
30 | self.parser.add_argument('--id_lambda', default=0, type=float, help='ID loss multiplier factor')
31 | self.parser.add_argument('--l2_lambda', default=1.0, type=float, help='L2 loss multiplier factor')
32 | self.parser.add_argument('--w_norm_lambda', default=0, type=float, help='W-norm loss multiplier factor')
33 | self.parser.add_argument('--lpips_lambda_crop', default=0, type=float, help='LPIPS loss multiplier factor for inner image region')
34 | self.parser.add_argument('--l2_lambda_crop', default=0, type=float, help='L2 loss multiplier factor for inner image region')
35 | self.parser.add_argument('--moco_lambda', default=0, type=float, help='Moco-based feature similarity loss multiplier factor')
36 |
37 | self.parser.add_argument('--stylegan_weights', default='checkpoints/stylegan2-ffhq-config-f.pt', type=str, help='Path to StyleGAN model weights')
38 | self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to pSp model checkpoint')
39 |
40 | self.parser.add_argument('--max_steps', default=1000000, type=int, help='Maximum number of training steps')
41 | self.parser.add_argument('--image_interval', default=100, type=int, help='Interval for logging train images during training')
42 | self.parser.add_argument('--board_interval', default=50, type=int, help='Interval for logging metrics to tensorboard')
43 | self.parser.add_argument('--val_interval', default=1000, type=int, help='Validation interval')
44 | self.parser.add_argument('--save_interval', default=10000, type=int, help='Model checkpoint interval')
45 |
46 | # arguments for weights & biases support
47 | self.parser.add_argument('--use_wandb', action="store_true", help='Whether to use Weights & Biases to track experiment.')
48 |
49 | # arguments for super-resolution
50 | self.parser.add_argument('--resize_factors', type=str, default=None, help='For super-res, comma-separated resize factors to use for inference.')
51 |
52 | self.parser.add_argument('--local_rank', default=None, type=int)
53 |
54 | def parse(self):
55 | opts = self.parser.parse_args()
56 | return opts
57 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.4.0
2 | appdirs==1.4.4
3 | cachetools==5.3.1
4 | certifi==2023.7.22
5 | charset-normalizer==2.0.12
6 | click==8.1.6
7 | cycler==0.11.0
8 | docker-pycreds==0.4.0
9 | gitdb==4.0.10
10 | GitPython==3.1.32
11 | glfw==2.2.0
12 | google-auth==2.22.0
13 | google-auth-oauthlib==1.0.0
14 | grpcio==1.56.2
15 | idna==3.4
16 | imageio==2.31.1
17 | imageio-ffmpeg==0.4.3
18 | imgui==1.3.0
19 | importlib-metadata==6.8.0
20 | kiwisolver==1.4.4
21 | lazy_loader==0.3
22 | lpips==0.1.4
23 | Markdown==3.4.3
24 | MarkupSafe==2.1.3
25 | matplotlib==3.4.2
26 | mrcfile==1.4.3
27 | networkx==3.1
28 | ninja==1.10.2
29 | numpy==1.22.4
30 | oauthlib==3.2.2
31 | opencv-python==4.8.0.74
32 | packaging==23.1
33 | pathtools==0.1.2
34 | Pillow==10.0.0
35 | plyfile==1.0.1
36 | protobuf==4.23.4
37 | psutil==5.9.5
38 | pyasn1==0.5.0
39 | pyasn1-modules==0.3.0
40 | PyOpenGL==3.1.5
41 | pyparsing==3.1.0
42 | pyspng==0.1.1
43 | python-dateutil==2.8.2
44 | PyWavelets==1.4.1
45 | PyYAML==6.0.1
46 | requests==2.26.0
47 | requests-oauthlib==1.3.1
48 | rsa==4.9
49 | scikit-image==0.21.0
50 | scipy==1.10.1
51 | sentry-sdk==1.28.1
52 | setproctitle==1.3.2
53 | six==1.16.0
54 | smmap==5.0.0
55 | tensorboard==2.13.0
56 | tensorboard-data-server==0.7.1
57 | tifffile==2023.7.10
58 | torch==1.8.1+cu111
59 | torchaudio==0.8.1
60 | torchvision==0.9.1+cu111
61 | tqdm==4.62.2
62 | trimesh==3.22.5
63 | typing_extensions==4.7.1
64 | urllib3==1.26.16
65 | wandb==0.15.5
66 | Werkzeug==2.3.6
67 | zipp==3.16.2
68 |
--------------------------------------------------------------------------------
/run_3dSwap.py:
--------------------------------------------------------------------------------
1 | import os
2 | from models.faceswap_coach import FaceSwapCoach
3 | import argparse
4 |
5 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
6 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
7 |
8 |
9 | def run_3dSwap(args):
10 | coach = FaceSwapCoach()
11 | coach.run(args)
12 |
13 |
14 | if __name__ == '__main__':
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('--from_index', type=int, default=0)
17 | parser.add_argument('--to_index', type=int, default=1)
18 | parser.add_argument('--epoch', type=int, default=500)
19 | parser.add_argument('--lr', type=float, default=3e-4)
20 | parser.add_argument('--dataroot', type=str, default='datasets/CelebA-HD')
21 | args = parser.parse_args()
22 |
23 | run_3dSwap(args)
24 |
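
Note (illustrative, not part of the repository): the same entry point can be driven programmatically by building an `argparse.Namespace` that mirrors the CLI defaults above.

# Minimal sketch: invoke FaceSwapCoach directly with the defaults the CLI would supply.
from argparse import Namespace
from models.faceswap_coach import FaceSwapCoach

args = Namespace(from_index=0, to_index=1, epoch=500, lr=3e-4,
                 dataroot='datasets/CelebA-HD')
FaceSwapCoach().run(args)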
--------------------------------------------------------------------------------
/run_inversion.py:
--------------------------------------------------------------------------------
1 | import os
2 | from models.inversion_coach import InversionCoach
3 | import argparse
4 |
5 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
6 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
7 |
8 |
9 | def run_inversion(args):
10 | coach = InversionCoach()
11 | coach.run(args)
12 |
13 |
14 | if __name__ == '__main__':
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('--index', type=int, default=0)
17 | parser.add_argument('--epoch', type=int, default=500)
18 | parser.add_argument('--lr', type=float, default=3e-4)
19 | parser.add_argument('--dataroot', type=str, default='datasets/CelebA-HD')
20 | args = parser.parse_args()
21 |
22 | run_inversion(args)

23 |
--------------------------------------------------------------------------------
/torch_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | # empty
12 |
--------------------------------------------------------------------------------
/torch_utils/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/torch_utils/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/torch_utils/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/torch_utils/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/torch_utils/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/torch_utils/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/torch_utils/__pycache__/custom_ops.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/torch_utils/__pycache__/custom_ops.cpython-38.pyc
--------------------------------------------------------------------------------
/torch_utils/__pycache__/custom_ops.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/torch_utils/__pycache__/custom_ops.cpython-39.pyc
--------------------------------------------------------------------------------
/torch_utils/__pycache__/misc.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/torch_utils/__pycache__/misc.cpython-37.pyc
--------------------------------------------------------------------------------
/torch_utils/__pycache__/misc.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/torch_utils/__pycache__/misc.cpython-38.pyc
--------------------------------------------------------------------------------
/torch_utils/__pycache__/misc.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/torch_utils/__pycache__/misc.cpython-39.pyc
--------------------------------------------------------------------------------
/torch_utils/__pycache__/persistence.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/torch_utils/__pycache__/persistence.cpython-38.pyc
--------------------------------------------------------------------------------
/torch_utils/__pycache__/persistence.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/torch_utils/__pycache__/persistence.cpython-39.pyc
--------------------------------------------------------------------------------
/torch_utils/__pycache__/training_stats.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/torch_utils/__pycache__/training_stats.cpython-39.pyc
--------------------------------------------------------------------------------
/torch_utils/custom_ops.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | import glob
12 | import hashlib
13 | import importlib
14 | import os
15 | import re
16 | import shutil
17 | import uuid
18 |
19 | import torch
20 | import torch.utils.cpp_extension
21 | from torch.utils.file_baton import FileBaton
22 |
23 | #----------------------------------------------------------------------------
24 | # Global options.
25 |
26 | verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
27 |
28 | #----------------------------------------------------------------------------
29 | # Internal helper funcs.
30 |
31 | def _find_compiler_bindir():
32 | patterns = [
33 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
34 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
35 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
36 | 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
37 | ]
38 | for pattern in patterns:
39 | matches = sorted(glob.glob(pattern))
40 | if len(matches):
41 | return matches[-1]
42 | return None
43 |
44 | #----------------------------------------------------------------------------
45 |
46 | def _get_mangled_gpu_name():
47 | name = torch.cuda.get_device_name().lower()
48 | out = []
49 | for c in name:
50 | if re.match('[a-z0-9_-]+', c):
51 | out.append(c)
52 | else:
53 | out.append('-')
54 | return ''.join(out)
55 |
56 | #----------------------------------------------------------------------------
57 | # Main entry point for compiling and loading C++/CUDA plugins.
58 |
59 | _cached_plugins = dict()
60 |
61 | def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs):
62 | assert verbosity in ['none', 'brief', 'full']
63 | if headers is None:
64 | headers = []
65 | if source_dir is not None:
66 | sources = [os.path.join(source_dir, fname) for fname in sources]
67 | headers = [os.path.join(source_dir, fname) for fname in headers]
68 |
69 | # Already cached?
70 | if module_name in _cached_plugins:
71 | return _cached_plugins[module_name]
72 |
73 | # Print status.
74 | if verbosity == 'full':
75 | print(f'Setting up PyTorch plugin "{module_name}"...')
76 | elif verbosity == 'brief':
77 | print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
78 | verbose_build = (verbosity == 'full')
79 |
80 | # Compile and load.
81 | try: # pylint: disable=too-many-nested-blocks
82 | # Make sure we can find the necessary compiler binaries.
83 | if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
84 | compiler_bindir = _find_compiler_bindir()
85 | if compiler_bindir is None:
86 | raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
87 | os.environ['PATH'] += ';' + compiler_bindir
88 |
89 | # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
90 | # break the build or unnecessarily restrict what's available to nvcc.
91 | # Unset it to let nvcc decide based on what's available on the
92 | # machine.
93 | os.environ['TORCH_CUDA_ARCH_LIST'] = ''
94 |
95 | # Incremental build md5sum trickery. Copies all the input source files
96 | # into a cached build directory under a combined md5 digest of the input
97 | # source files. Copying is done only if the combined digest has changed.
98 | # This keeps input file timestamps and filenames the same as in previous
99 | # extension builds, allowing for fast incremental rebuilds.
100 | #
101 | # This optimization is done only in case all the source files reside in
102 | # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
103 | # environment variable is set (we take this as a signal that the user
104 | # actually cares about this.)
105 | #
106 | # EDIT: We now do it regardless of TORCH_EXTENSIONS_DIR, in order to work
107 | # around the *.cu dependency bug in ninja config.
108 | #
109 | all_source_files = sorted(sources + headers)
110 | all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files)
111 | if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ):
112 |
113 | # Compute combined hash digest for all source files.
114 | hash_md5 = hashlib.md5()
115 | for src in all_source_files:
116 | with open(src, 'rb') as f:
117 | hash_md5.update(f.read())
118 |
119 | # Select cached build directory name.
120 | source_digest = hash_md5.hexdigest()
121 | build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
122 | cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')
123 |
124 | if not os.path.isdir(cached_build_dir):
125 | tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
126 | os.makedirs(tmpdir)
127 | for src in all_source_files:
128 | shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src)))
129 | try:
130 | os.replace(tmpdir, cached_build_dir) # atomic
131 | except OSError:
132 | # source directory already exists, delete tmpdir and its contents.
133 | shutil.rmtree(tmpdir)
134 | if not os.path.isdir(cached_build_dir): raise
135 |
136 | # Compile.
137 | cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources]
138 | torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir,
139 | verbose=verbose_build, sources=cached_sources, **build_kwargs)
140 | else:
141 | torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
142 |
143 | # Load.
144 | module = importlib.import_module(module_name)
145 |
146 | except:
147 | if verbosity == 'brief':
148 | print('Failed!')
149 | raise
150 |
151 | # Print status and add to cache dict.
152 | if verbosity == 'full':
153 | print(f'Done setting up PyTorch plugin "{module_name}".')
154 | elif verbosity == 'brief':
155 | print('Done.')
156 | _cached_plugins[module_name] = module
157 | return module
158 |
159 | #----------------------------------------------------------------------------
160 |
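
Note (illustrative, not part of the file): a typical call into `get_plugin()`, mirroring how the ops in torch_utils/ops load their C++/CUDA sources. The module name, source directory, and extra flags here are assumptions.

# Sketch: compile and load a CUDA extension through the caching logic above.
import os
from torch_utils import custom_ops

_plugin = custom_ops.get_plugin(
    module_name='bias_act_plugin',                   # assumed name
    sources=['bias_act.cpp', 'bias_act.cu'],
    headers=['bias_act.h'],
    source_dir=os.path.join('torch_utils', 'ops'),   # directory holding the sources
    extra_cuda_cflags=['--use_fast_math'],           # forwarded to torch.utils.cpp_extension.load
)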
--------------------------------------------------------------------------------
/torch_utils/ops/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | # empty
12 |
--------------------------------------------------------------------------------
/torch_utils/ops/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/torch_utils/ops/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/torch_utils/ops/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/torch_utils/ops/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/torch_utils/ops/__pycache__/bias_act.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/torch_utils/ops/__pycache__/bias_act.cpython-38.pyc
--------------------------------------------------------------------------------
/torch_utils/ops/__pycache__/bias_act.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/torch_utils/ops/__pycache__/bias_act.cpython-39.pyc
--------------------------------------------------------------------------------
/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-38.pyc
--------------------------------------------------------------------------------
/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-39.pyc
--------------------------------------------------------------------------------
/torch_utils/ops/__pycache__/conv2d_resample.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/torch_utils/ops/__pycache__/conv2d_resample.cpython-38.pyc
--------------------------------------------------------------------------------
/torch_utils/ops/__pycache__/conv2d_resample.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/torch_utils/ops/__pycache__/conv2d_resample.cpython-39.pyc
--------------------------------------------------------------------------------
/torch_utils/ops/__pycache__/filtered_lrelu.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/torch_utils/ops/__pycache__/filtered_lrelu.cpython-38.pyc
--------------------------------------------------------------------------------
/torch_utils/ops/__pycache__/filtered_lrelu.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/torch_utils/ops/__pycache__/filtered_lrelu.cpython-39.pyc
--------------------------------------------------------------------------------
/torch_utils/ops/__pycache__/fma.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/torch_utils/ops/__pycache__/fma.cpython-38.pyc
--------------------------------------------------------------------------------
/torch_utils/ops/__pycache__/fma.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/torch_utils/ops/__pycache__/fma.cpython-39.pyc
--------------------------------------------------------------------------------
/torch_utils/ops/__pycache__/upfirdn2d.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/torch_utils/ops/__pycache__/upfirdn2d.cpython-38.pyc
--------------------------------------------------------------------------------
/torch_utils/ops/__pycache__/upfirdn2d.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/torch_utils/ops/__pycache__/upfirdn2d.cpython-39.pyc
--------------------------------------------------------------------------------
/torch_utils/ops/bias_act.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
4 | *
5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
6 | * property and proprietary rights in and to this material, related
7 | * documentation and any modifications thereto. Any use, reproduction,
8 | * disclosure or distribution of this material and related documentation
9 | * without an express license agreement from NVIDIA CORPORATION or
10 | * its affiliates is strictly prohibited.
11 | */
12 |
13 | #include <torch/extension.h>
14 | #include <ATen/cuda/CUDAContext.h>
15 | #include <c10/cuda/CUDAGuard.h>
16 | #include "bias_act.h"
17 |
18 | //------------------------------------------------------------------------
19 |
20 | static bool has_same_layout(torch::Tensor x, torch::Tensor y)
21 | {
22 | if (x.dim() != y.dim())
23 | return false;
24 | for (int64_t i = 0; i < x.dim(); i++)
25 | {
26 | if (x.size(i) != y.size(i))
27 | return false;
28 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
29 | return false;
30 | }
31 | return true;
32 | }
33 |
34 | //------------------------------------------------------------------------
35 |
36 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
37 | {
38 | // Validate arguments.
39 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
40 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
41 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
42 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
43 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
44 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
45 | TORCH_CHECK(b.dim() == 1, "b must have rank 1");
46 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
47 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
48 | TORCH_CHECK(grad >= 0, "grad must be non-negative");
49 |
50 | // Validate layout.
51 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
52 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
53 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
54 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
55 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
56 |
57 | // Create output tensor.
58 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
59 | torch::Tensor y = torch::empty_like(x);
60 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
61 |
62 | // Initialize CUDA kernel parameters.
63 | bias_act_kernel_params p;
64 | p.x = x.data_ptr();
65 | p.b = (b.numel()) ? b.data_ptr() : NULL;
66 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
67 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
68 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
69 | p.y = y.data_ptr();
70 | p.grad = grad;
71 | p.act = act;
72 | p.alpha = alpha;
73 | p.gain = gain;
74 | p.clamp = clamp;
75 | p.sizeX = (int)x.numel();
76 | p.sizeB = (int)b.numel();
77 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
78 |
79 | // Choose CUDA kernel.
80 | void* kernel;
81 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
82 | {
83 | kernel = choose_bias_act_kernel<scalar_t>(p);
84 | });
85 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
86 |
87 | // Launch CUDA kernel.
88 | p.loopX = 4;
89 | int blockSize = 4 * 32;
90 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
91 | void* args[] = {&p};
92 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
93 | return y;
94 | }
95 |
96 | //------------------------------------------------------------------------
97 |
98 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
99 | {
100 | m.def("bias_act", &bias_act);
101 | }
102 |
103 | //------------------------------------------------------------------------
104 |
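
Note (illustrative, not part of the file): once built, the extension exposes `bias_act` with the positional signature above. A hedged forward-pass example using the leaky-ReLU activation (code 3, see bias_act.cu), assuming the module was loaded as `_plugin` via a get_plugin() call like the sketch after custom_ops.py:

# Sketch: forward pass (grad=0) of bias + leaky ReLU via the compiled op.
import torch

x = torch.randn(4, 8, device='cuda')
b = torch.randn(8, device='cuda')
empty = torch.empty(0, device='cuda', dtype=x.dtype)   # unused xref/yref/dy slots
# args: x, b, xref, yref, dy, grad, dim, act, alpha, gain, clamp (clamp < 0 disables clamping)
y = _plugin.bias_act(x, b, empty, empty, empty, 0, 1, 3, 0.2, 1.0, -1.0)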
--------------------------------------------------------------------------------
/torch_utils/ops/bias_act.cu:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
4 | *
5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
6 | * property and proprietary rights in and to this material, related
7 | * documentation and any modifications thereto. Any use, reproduction,
8 | * disclosure or distribution of this material and related documentation
9 | * without an express license agreement from NVIDIA CORPORATION or
10 | * its affiliates is strictly prohibited.
11 | */
12 |
13 | #include <c10/util/Half.h>
14 | #include "bias_act.h"
15 |
16 | //------------------------------------------------------------------------
17 | // Helpers.
18 |
19 | template <class T> struct InternalType;
20 | template <> struct InternalType<double>    { typedef double scalar_t; };
21 | template <> struct InternalType<float>     { typedef float  scalar_t; };
22 | template <> struct InternalType<c10::Half> { typedef float  scalar_t; };
23 |
24 | //------------------------------------------------------------------------
25 | // CUDA kernel.
26 |
27 | template <class T, int A>
28 | __global__ void bias_act_kernel(bias_act_kernel_params p)
29 | {
30 | typedef typename InternalType::scalar_t scalar_t;
31 | int G = p.grad;
32 | scalar_t alpha = (scalar_t)p.alpha;
33 | scalar_t gain = (scalar_t)p.gain;
34 | scalar_t clamp = (scalar_t)p.clamp;
35 | scalar_t one = (scalar_t)1;
36 | scalar_t two = (scalar_t)2;
37 | scalar_t expRange = (scalar_t)80;
38 | scalar_t halfExpRange = (scalar_t)40;
39 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
40 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
41 |
42 | // Loop over elements.
43 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
44 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
45 | {
46 | // Load.
47 | scalar_t x = (scalar_t)((const T*)p.x)[xi];
48 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
49 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
50 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
51 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
52 | scalar_t yy = (gain != 0) ? yref / gain : 0;
53 | scalar_t y = 0;
54 |
55 | // Apply bias.
56 | ((G == 0) ? x : xref) += b;
57 |
58 | // linear
59 | if (A == 1)
60 | {
61 | if (G == 0) y = x;
62 | if (G == 1) y = x;
63 | }
64 |
65 | // relu
66 | if (A == 2)
67 | {
68 | if (G == 0) y = (x > 0) ? x : 0;
69 | if (G == 1) y = (yy > 0) ? x : 0;
70 | }
71 |
72 | // lrelu
73 | if (A == 3)
74 | {
75 | if (G == 0) y = (x > 0) ? x : x * alpha;
76 | if (G == 1) y = (yy > 0) ? x : x * alpha;
77 | }
78 |
79 | // tanh
80 | if (A == 4)
81 | {
82 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
83 | if (G == 1) y = x * (one - yy * yy);
84 | if (G == 2) y = x * (one - yy * yy) * (-two * yy);
85 | }
86 |
87 | // sigmoid
88 | if (A == 5)
89 | {
90 | if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
91 | if (G == 1) y = x * yy * (one - yy);
92 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
93 | }
94 |
95 | // elu
96 | if (A == 6)
97 | {
98 | if (G == 0) y = (x >= 0) ? x : exp(x) - one;
99 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
100 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
101 | }
102 |
103 | // selu
104 | if (A == 7)
105 | {
106 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
107 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
108 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
109 | }
110 |
111 | // softplus
112 | if (A == 8)
113 | {
114 | if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
115 | if (G == 1) y = x * (one - exp(-yy));
116 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
117 | }
118 |
119 | // swish
120 | if (A == 9)
121 | {
122 | if (G == 0)
123 | y = (x < -expRange) ? 0 : x / (exp(-x) + one);
124 | else
125 | {
126 | scalar_t c = exp(xref);
127 | scalar_t d = c + one;
128 | if (G == 1)
129 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
130 | else
131 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
132 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
133 | }
134 | }
135 |
136 | // Apply gain.
137 | y *= gain * dy;
138 |
139 | // Clamp.
140 | if (clamp >= 0)
141 | {
142 | if (G == 0)
143 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
144 | else
145 | y = (yref > -clamp & yref < clamp) ? y : 0;
146 | }
147 |
148 | // Store.
149 | ((T*)p.y)[xi] = (T)y;
150 | }
151 | }
152 |
153 | //------------------------------------------------------------------------
154 | // CUDA kernel selection.
155 |
156 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
157 | {
158 | if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
159 | if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
160 | if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
161 | if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
162 | if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
163 | if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
164 | if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
165 | if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
166 | if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
167 | return NULL;
168 | }
169 |
170 | //------------------------------------------------------------------------
171 | // Template specializations.
172 |
173 | template void* choose_bias_act_kernel<double>    (const bias_act_kernel_params& p);
174 | template void* choose_bias_act_kernel<float>     (const bias_act_kernel_params& p);
175 | template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
176 |
177 | //------------------------------------------------------------------------
178 |
--------------------------------------------------------------------------------
/torch_utils/ops/bias_act.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
4 | *
5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
6 | * property and proprietary rights in and to this material, related
7 | * documentation and any modifications thereto. Any use, reproduction,
8 | * disclosure or distribution of this material and related documentation
9 | * without an express license agreement from NVIDIA CORPORATION or
10 | * its affiliates is strictly prohibited.
11 | */
12 |
13 | //------------------------------------------------------------------------
14 | // CUDA kernel parameters.
15 |
16 | struct bias_act_kernel_params
17 | {
18 | const void* x; // [sizeX]
19 | const void* b; // [sizeB] or NULL
20 | const void* xref; // [sizeX] or NULL
21 | const void* yref; // [sizeX] or NULL
22 | const void* dy; // [sizeX] or NULL
23 | void* y; // [sizeX]
24 |
25 | int grad;
26 | int act;
27 | float alpha;
28 | float gain;
29 | float clamp;
30 |
31 | int sizeX;
32 | int sizeB;
33 | int stepB;
34 | int loopX;
35 | };
36 |
37 | //------------------------------------------------------------------------
38 | // CUDA kernel selection.
39 |
40 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
41 |
42 | //------------------------------------------------------------------------
43 |
--------------------------------------------------------------------------------
/torch_utils/ops/conv2d_resample.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | """2D convolution with optional up/downsampling."""
12 |
13 | import torch
14 |
15 | from .. import misc
16 | from . import conv2d_gradfix
17 | from . import upfirdn2d
18 | from .upfirdn2d import _parse_padding
19 | from .upfirdn2d import _get_filter_size
20 |
21 | #----------------------------------------------------------------------------
22 |
23 | def _get_weight_shape(w):
24 | with misc.suppress_tracer_warnings(): # this value will be treated as a constant
25 | shape = [int(sz) for sz in w.shape]
26 | misc.assert_shape(w, shape)
27 | return shape
28 |
29 | #----------------------------------------------------------------------------
30 |
31 | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
32 | """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
33 | """
34 | _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w)
35 |
36 | # Flip weight if requested.
37 | # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
38 | if not flip_weight and (kw > 1 or kh > 1):
39 | w = w.flip([2, 3])
40 |
41 | # Execute using conv2d_gradfix.
42 | op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
43 | return op(x, w, stride=stride, padding=padding, groups=groups)
44 |
45 | #----------------------------------------------------------------------------
46 |
47 | @misc.profiled_function
48 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
49 | r"""2D convolution with optional up/downsampling.
50 |
51 | Padding is performed only once at the beginning, not between the operations.
52 |
53 | Args:
54 | x: Input tensor of shape
55 | `[batch_size, in_channels, in_height, in_width]`.
56 | w: Weight tensor of shape
57 | `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
58 | f: Low-pass filter for up/downsampling. Must be prepared beforehand by
59 | calling upfirdn2d.setup_filter(). None = identity (default).
60 | up: Integer upsampling factor (default: 1).
61 | down: Integer downsampling factor (default: 1).
62 | padding: Padding with respect to the upsampled image. Can be a single number
63 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
64 | (default: 0).
65 | groups: Split input channels into N groups (default: 1).
66 | flip_weight: False = convolution, True = correlation (default: True).
67 | flip_filter: False = convolution, True = correlation (default: False).
68 |
69 | Returns:
70 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
71 | """
72 | # Validate arguments.
73 | assert isinstance(x, torch.Tensor) and (x.ndim == 4)
74 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
75 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
76 | assert isinstance(up, int) and (up >= 1)
77 | assert isinstance(down, int) and (down >= 1)
78 | assert isinstance(groups, int) and (groups >= 1)
79 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
80 | fw, fh = _get_filter_size(f)
81 | px0, px1, py0, py1 = _parse_padding(padding)
82 |
83 | # Adjust padding to account for up/downsampling.
84 | if up > 1:
85 | px0 += (fw + up - 1) // 2
86 | px1 += (fw - up) // 2
87 | py0 += (fh + up - 1) // 2
88 | py1 += (fh - up) // 2
89 | if down > 1:
90 | px0 += (fw - down + 1) // 2
91 | px1 += (fw - down) // 2
92 | py0 += (fh - down + 1) // 2
93 | py1 += (fh - down) // 2
94 |
95 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
96 | if kw == 1 and kh == 1 and (down > 1 and up == 1):
97 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
98 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
99 | return x
100 |
101 | # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
102 | if kw == 1 and kh == 1 and (up > 1 and down == 1):
103 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
104 | x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
105 | return x
106 |
107 | # Fast path: downsampling only => use strided convolution.
108 | if down > 1 and up == 1:
109 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
110 | x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
111 | return x
112 |
113 | # Fast path: upsampling with optional downsampling => use transpose strided convolution.
114 | if up > 1:
115 | if groups == 1:
116 | w = w.transpose(0, 1)
117 | else:
118 | w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
119 | w = w.transpose(1, 2)
120 | w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
121 | px0 -= kw - 1
122 | px1 -= kw - up
123 | py0 -= kh - 1
124 | py1 -= kh - up
125 | pxt = max(min(-px0, -px1), 0)
126 | pyt = max(min(-py0, -py1), 0)
127 | x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
128 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
129 | if down > 1:
130 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
131 | return x
132 |
133 | # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
134 | if up == 1 and down == 1:
135 | if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
136 | return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
137 |
138 | # Fallback: Generic reference implementation.
139 | x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
140 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
141 | if down > 1:
142 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
143 | return x
144 |
145 | #----------------------------------------------------------------------------
146 |
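
Note (illustrative, not part of the file): a minimal call that exercises the upsampling fast path; the filter is prepared with `upfirdn2d.setup_filter()` as the docstring requires.

# Sketch: 3x3 convolution fused with 2x upsampling through a separable low-pass filter.
import torch
from torch_utils.ops import upfirdn2d
from torch_utils.ops.conv2d_resample import conv2d_resample

x = torch.randn(1, 16, 64, 64, device='cuda')
w = torch.randn(32, 16, 3, 3, device='cuda')
f = upfirdn2d.setup_filter([1, 3, 3, 1], device='cuda')
y = conv2d_resample(x=x, w=w, f=f, up=2, padding=1)   # expected shape [1, 32, 128, 128]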
--------------------------------------------------------------------------------
/torch_utils/ops/filtered_lrelu.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
4 | *
5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
6 | * property and proprietary rights in and to this material, related
7 | * documentation and any modifications thereto. Any use, reproduction,
8 | * disclosure or distribution of this material and related documentation
9 | * without an express license agreement from NVIDIA CORPORATION or
10 | * its affiliates is strictly prohibited.
11 | */
12 |
13 | #include <cuda_runtime.h>
14 |
15 | //------------------------------------------------------------------------
16 | // CUDA kernel parameters.
17 |
18 | struct filtered_lrelu_kernel_params
19 | {
20 | // These parameters decide which kernel to use.
21 | int up; // upsampling ratio (1, 2, 4)
22 | int down; // downsampling ratio (1, 2, 4)
23 | int2 fuShape; // [size, 1] | [size, size]
24 | int2 fdShape; // [size, 1] | [size, size]
25 |
26 | int _dummy; // Alignment.
27 |
28 | // Rest of the parameters.
29 | const void* x; // Input tensor.
30 | void* y; // Output tensor.
31 | const void* b; // Bias tensor.
32 | unsigned char* s; // Sign tensor in/out. NULL if unused.
33 | const float* fu; // Upsampling filter.
34 | const float* fd; // Downsampling filter.
35 |
36 | int2 pad0; // Left/top padding.
37 | float gain; // Additional gain factor.
38 | float slope; // Leaky ReLU slope on negative side.
39 | float clamp; // Clamp after nonlinearity.
40 | int flip; // Filter kernel flip for gradient computation.
41 |
42 | int tilesXdim; // Original number of horizontal output tiles.
43 | int tilesXrep; // Number of horizontal tiles per CTA.
44 | int blockZofs; // Block z offset to support large minibatch, channel dimensions.
45 |
46 | int4 xShape; // [width, height, channel, batch]
47 | int4 yShape; // [width, height, channel, batch]
48 | int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused.
49 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
50 | int swLimit; // Active width of sign tensor in bytes.
51 |
52 | longlong4 xStride; // Strides of all tensors except signs, same component order as shapes.
53 | longlong4 yStride; //
54 | int64_t bStride; //
55 | longlong3 fuStride; //
56 | longlong3 fdStride; //
57 | };
58 |
59 | struct filtered_lrelu_act_kernel_params
60 | {
61 | void* x; // Input/output, modified in-place.
62 | unsigned char* s; // Sign tensor in/out. NULL if unused.
63 |
64 | float gain; // Additional gain factor.
65 | float slope; // Leaky ReLU slope on negative side.
66 | float clamp; // Clamp after nonlinearity.
67 |
68 | int4 xShape; // [width, height, channel, batch]
69 | longlong4 xStride; // Input/output tensor strides, same order as in shape.
70 | int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused.
71 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
72 | };
73 |
74 | //------------------------------------------------------------------------
75 | // CUDA kernel specialization.
76 |
77 | struct filtered_lrelu_kernel_spec
78 | {
79 | void* setup; // Function for filter kernel setup.
80 | void* exec; // Function for main operation.
81 | int2 tileOut; // Width/height of launch tile.
82 | int numWarps; // Number of warps per thread block, determines launch block size.
83 | int xrep; // For processing multiple horizontal tiles per thread block.
84 | int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants.
85 | };
86 |
87 | //------------------------------------------------------------------------
88 | // CUDA kernel selection.
89 |
90 | template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
91 | template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void);
92 | template <bool signWrite, bool signRead> cudaError_t copy_filters(cudaStream_t stream);
93 |
94 | //------------------------------------------------------------------------
95 |
--------------------------------------------------------------------------------
/torch_utils/ops/filtered_lrelu_ns.cu:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
4 | *
5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
6 | * property and proprietary rights in and to this material, related
7 | * documentation and any modifications thereto. Any use, reproduction,
8 | * disclosure or distribution of this material and related documentation
9 | * without an express license agreement from NVIDIA CORPORATION or
10 | * its affiliates is strictly prohibited.
11 | */
12 |
13 | #include "filtered_lrelu.cu"
14 |
15 | // Template/kernel specializations for no signs mode (no gradients required).
16 |
17 | // Full op, 32-bit indexing.
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
20 |
21 | // Full op, 64-bit indexing.
22 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
23 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
24 |
25 | // Activation/signs only for generic variant. 64-bit indexing.
26 | template void* choose_filtered_lrelu_act_kernel<c10::Half, false, false>(void);
27 | template void* choose_filtered_lrelu_act_kernel<float,     false, false>(void);
28 | template void* choose_filtered_lrelu_act_kernel<double,    false, false>(void);
29 |
30 | // Copy filters to constant memory.
31 | template cudaError_t copy_filters<false, false>(cudaStream_t stream);
32 |
--------------------------------------------------------------------------------
/torch_utils/ops/filtered_lrelu_rd.cu:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
4 | *
5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
6 | * property and proprietary rights in and to this material, related
7 | * documentation and any modifications thereto. Any use, reproduction,
8 | * disclosure or distribution of this material and related documentation
9 | * without an express license agreement from NVIDIA CORPORATION or
10 | * its affiliates is strictly prohibited.
11 | */
12 |
13 | #include "filtered_lrelu.cu"
14 |
15 | // Template/kernel specializations for sign read mode.
16 |
17 | // Full op, 32-bit indexing.
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
20 |
21 | // Full op, 64-bit indexing.
22 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
23 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
24 |
25 | // Activation/signs only for generic variant. 64-bit indexing.
26 | template void* choose_filtered_lrelu_act_kernel<c10::Half, false, true>(void);
27 | template void* choose_filtered_lrelu_act_kernel<float,     false, true>(void);
28 | template void* choose_filtered_lrelu_act_kernel<double,    false, true>(void);
29 |
30 | // Copy filters to constant memory.
31 | template cudaError_t copy_filters<false, true>(cudaStream_t stream);
32 |
--------------------------------------------------------------------------------
/torch_utils/ops/filtered_lrelu_wr.cu:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
4 | *
5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
6 | * property and proprietary rights in and to this material, related
7 | * documentation and any modifications thereto. Any use, reproduction,
8 | * disclosure or distribution of this material and related documentation
9 | * without an express license agreement from NVIDIA CORPORATION or
10 | * its affiliates is strictly prohibited.
11 | */
12 |
13 | #include "filtered_lrelu.cu"
14 |
15 | // Template/kernel specializations for sign write mode.
16 |
17 | // Full op, 32-bit indexing.
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
20 |
21 | // Full op, 64-bit indexing.
22 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
23 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
24 |
25 | // Activation/signs only for generic variant. 64-bit indexing.
26 | template void* choose_filtered_lrelu_act_kernel<c10::Half, true, false>(void);
27 | template void* choose_filtered_lrelu_act_kernel<float,     true, false>(void);
28 | template void* choose_filtered_lrelu_act_kernel<double,    true, false>(void);
29 |
30 | // Copy filters to constant memory.
31 | template cudaError_t copy_filters<true, false>(cudaStream_t stream);
32 |
--------------------------------------------------------------------------------
/torch_utils/ops/fma.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
12 |
13 | import torch
14 |
15 | #----------------------------------------------------------------------------
16 |
17 | def fma(a, b, c): # => a * b + c
18 | return _FusedMultiplyAdd.apply(a, b, c)
19 |
20 | #----------------------------------------------------------------------------
21 |
22 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
23 | @staticmethod
24 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ
25 | out = torch.addcmul(c, a, b)
26 | ctx.save_for_backward(a, b)
27 | ctx.c_shape = c.shape
28 | return out
29 |
30 | @staticmethod
31 | def backward(ctx, dout): # pylint: disable=arguments-differ
32 | a, b = ctx.saved_tensors
33 | c_shape = ctx.c_shape
34 | da = None
35 | db = None
36 | dc = None
37 |
38 | if ctx.needs_input_grad[0]:
39 | da = _unbroadcast(dout * b, a.shape)
40 |
41 | if ctx.needs_input_grad[1]:
42 | db = _unbroadcast(dout * a, b.shape)
43 |
44 | if ctx.needs_input_grad[2]:
45 | dc = _unbroadcast(dout, c_shape)
46 |
47 | return da, db, dc
48 |
49 | #----------------------------------------------------------------------------
50 |
51 | def _unbroadcast(x, shape):
52 | extra_dims = x.ndim - len(shape)
53 | assert extra_dims >= 0
54 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
55 | if len(dim):
56 | x = x.sum(dim=dim, keepdim=True)
57 | if extra_dims:
58 | x = x.reshape(-1, *x.shape[extra_dims+1:])
59 | assert x.shape == shape
60 | return x
61 |
62 | #----------------------------------------------------------------------------
63 |
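
Note (illustrative, not part of the file): a quick consistency check of `fma()` against the plain expression, including a broadcast `c`.

# Sketch: fma(a, b, c) should match a * b + c and support gradients through the custom Function.
import torch
from torch_utils.ops.fma import fma

a = torch.randn(4, 3, requires_grad=True)
b = torch.randn(4, 3, requires_grad=True)
c = torch.randn(3, requires_grad=True)        # broadcast over the leading dimension
out = fma(a, b, c)
assert torch.allclose(out, a * b + c)
out.sum().backward()                          # exercises _FusedMultiplyAdd.backward and _unbroadcast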
--------------------------------------------------------------------------------
/torch_utils/ops/grid_sample_gradfix.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | """Custom replacement for `torch.nn.functional.grid_sample` that
12 | supports arbitrarily high order gradients between the input and output.
13 | Only works on 2D images and assumes
14 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
15 |
16 | import torch
17 |
18 | # pylint: disable=redefined-builtin
19 | # pylint: disable=arguments-differ
20 | # pylint: disable=protected-access
21 |
22 | #----------------------------------------------------------------------------
23 |
24 | enabled = False # Enable the custom op by setting this to true.
25 |
26 | #----------------------------------------------------------------------------
27 |
28 | def grid_sample(input, grid):
29 | if _should_use_custom_op():
30 | return _GridSample2dForward.apply(input, grid)
31 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
32 |
33 | #----------------------------------------------------------------------------
34 |
35 | def _should_use_custom_op():
36 | return enabled
37 |
38 | #----------------------------------------------------------------------------
39 |
40 | class _GridSample2dForward(torch.autograd.Function):
41 | @staticmethod
42 | def forward(ctx, input, grid):
43 | assert input.ndim == 4
44 | assert grid.ndim == 4
45 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
46 | ctx.save_for_backward(input, grid)
47 | return output
48 |
49 | @staticmethod
50 | def backward(ctx, grad_output):
51 | input, grid = ctx.saved_tensors
52 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
53 | return grad_input, grad_grid
54 |
55 | #----------------------------------------------------------------------------
56 |
57 | class _GridSample2dBackward(torch.autograd.Function):
58 | @staticmethod
59 | def forward(ctx, grad_output, input, grid):
60 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
61 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
62 | ctx.save_for_backward(grid)
63 | return grad_input, grad_grid
64 |
65 | @staticmethod
66 | def backward(ctx, grad2_grad_input, grad2_grad_grid):
67 | _ = grad2_grad_grid # unused
68 | grid, = ctx.saved_tensors
69 | grad2_grad_output = None
70 | grad2_input = None
71 | grad2_grid = None
72 |
73 | if ctx.needs_input_grad[0]:
74 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
75 |
76 | assert not ctx.needs_input_grad[2]
77 | return grad2_grad_output, grad2_input, grad2_grid
78 |
79 | #----------------------------------------------------------------------------
80 |
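
Note (illustrative, not part of the file): the custom op is opt-in via the module-level `enabled` flag; a small usage sketch follows.

# Sketch: route grid sampling through the double-backward-capable path.
import torch
from torch_utils.ops import grid_sample_gradfix

grid_sample_gradfix.enabled = True
img = torch.randn(1, 3, 32, 32, requires_grad=True)
grid = torch.rand(1, 16, 16, 2) * 2 - 1             # normalized sampling coords in [-1, 1]
out = grid_sample_gradfix.grid_sample(img, grid)    # shape [1, 3, 16, 16]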
--------------------------------------------------------------------------------
/torch_utils/ops/upfirdn2d.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
4 | *
5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
6 | * property and proprietary rights in and to this material, related
7 | * documentation and any modifications thereto. Any use, reproduction,
8 | * disclosure or distribution of this material and related documentation
9 | * without an express license agreement from NVIDIA CORPORATION or
10 | * its affiliates is strictly prohibited.
11 | */
12 |
13 | #include <torch/extension.h>
14 | #include <ATen/cuda/CUDAContext.h>
15 | #include <c10/cuda/CUDAGuard.h>
16 | #include "upfirdn2d.h"
17 |
18 | //------------------------------------------------------------------------
19 |
20 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
21 | {
22 | // Validate arguments.
23 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
24 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
25 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
26 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
27 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
28 | TORCH_CHECK(x.numel() > 0, "x has zero size");
29 | TORCH_CHECK(f.numel() > 0, "f has zero size");
30 | TORCH_CHECK(x.dim() == 4, "x must be rank 4");
31 | TORCH_CHECK(f.dim() == 2, "f must be rank 2");
32 | TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large");
33 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
34 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
35 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
36 |
37 | // Create output tensor.
38 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
39 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
40 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
41 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
42 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
43 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
44 | TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large");
45 |
46 | // Initialize CUDA kernel parameters.
47 | upfirdn2d_kernel_params p;
48 | p.x = x.data_ptr();
49 | p.f = f.data_ptr();
50 | p.y = y.data_ptr();
51 | p.up = make_int2(upx, upy);
52 | p.down = make_int2(downx, downy);
53 | p.pad0 = make_int2(padx0, pady0);
54 | p.flip = (flip) ? 1 : 0;
55 | p.gain = gain;
56 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
57 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
58 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
59 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
60 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
61 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
62 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
63 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
64 |
65 | // Choose CUDA kernel.
66 | upfirdn2d_kernel_spec spec;
67 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
68 | {
69 | spec = choose_upfirdn2d_kernel(p);
70 | });
71 |
72 | // Set looping options.
73 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
74 | p.loopMinor = spec.loopMinor;
75 | p.loopX = spec.loopX;
76 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
77 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
78 |
79 | // Compute grid size.
80 | dim3 blockSize, gridSize;
81 | if (spec.tileOutW < 0) // large
82 | {
83 | blockSize = dim3(4, 32, 1);
84 | gridSize = dim3(
85 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
86 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
87 | p.launchMajor);
88 | }
89 | else // small
90 | {
91 | blockSize = dim3(256, 1, 1);
92 | gridSize = dim3(
93 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
94 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
95 | p.launchMajor);
96 | }
97 |
98 | // Launch CUDA kernel.
99 | void* args[] = {&p};
100 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
101 | return y;
102 | }
103 |
104 | //------------------------------------------------------------------------
105 |
106 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
107 | {
108 | m.def("upfirdn2d", &upfirdn2d);
109 | }
110 |
111 | //------------------------------------------------------------------------
112 |
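The binding above exposes the classic upsample-filter-downsample primitive to Python. As a point of reference, its semantics can be sketched in pure PyTorch: zero-insertion upsampling, padding (or cropping for negative pads), a per-channel 2-D FIR convolution, then strided downsampling. This is a slow but readable sketch, with scalar up/down factors and a two-element pad as simplifying assumptions relative to the per-axis parameters in the C++ signature:

import torch

def upfirdn2d_ref(x, f, up=1, down=1, pad=(0, 0), flip=False, gain=1.0):
    # x: [N, C, H, W] input, f: [fh, fw] float32 FIR filter.
    n, c, in_h, in_w = x.shape
    padx0 = pady0 = pad[0]
    padx1 = pady1 = pad[1]

    # Upsample by inserting zeros between samples.
    x = x.reshape(n, c, in_h, 1, in_w, 1)
    x = torch.nn.functional.pad(x, [0, up - 1, 0, 0, 0, up - 1])
    x = x.reshape(n, c, in_h * up, in_w * up)

    # Pad (or crop, for negative padding).
    x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
    x = x[:, :, max(-pady0, 0): x.shape[2] - max(-pady1, 0),
             max(-padx0, 0): x.shape[3] - max(-padx1, 0)]

    # Convolve with the (optionally flipped) filter, one group per channel.
    w = f * gain
    if not flip:
        w = w.flip([0, 1])
    w = w[None, None].repeat(c, 1, 1, 1).to(x.dtype)
    x = torch.nn.functional.conv2d(x, w, groups=c)

    # Downsample by keeping every down-th sample.
    return x[:, :, ::down, ::down]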
--------------------------------------------------------------------------------
/torch_utils/ops/upfirdn2d.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
4 | *
5 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
6 | * property and proprietary rights in and to this material, related
7 | * documentation and any modifications thereto. Any use, reproduction,
8 | * disclosure or distribution of this material and related documentation
9 | * without an express license agreement from NVIDIA CORPORATION or
10 | * its affiliates is strictly prohibited.
11 | */
12 |
13 | #include <cuda_runtime.h>
14 |
15 | //------------------------------------------------------------------------
16 | // CUDA kernel parameters.
17 |
18 | struct upfirdn2d_kernel_params
19 | {
20 | const void* x;
21 | const float* f;
22 | void* y;
23 |
24 | int2 up;
25 | int2 down;
26 | int2 pad0;
27 | int flip;
28 | float gain;
29 |
30 | int4 inSize; // [width, height, channel, batch]
31 | int4 inStride;
32 | int2 filterSize; // [width, height]
33 | int2 filterStride;
34 | int4 outSize; // [width, height, channel, batch]
35 | int4 outStride;
36 | int sizeMinor;
37 | int sizeMajor;
38 |
39 | int loopMinor;
40 | int loopMajor;
41 | int loopX;
42 | int launchMinor;
43 | int launchMajor;
44 | };
45 |
46 | //------------------------------------------------------------------------
47 | // CUDA kernel specialization.
48 |
49 | struct upfirdn2d_kernel_spec
50 | {
51 | void* kernel;
52 | int tileOutW;
53 | int tileOutH;
54 | int loopMinor;
55 | int loopX;
56 | };
57 |
58 | //------------------------------------------------------------------------
59 | // CUDA kernel selection.
60 |
61 | template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
62 |
63 | //------------------------------------------------------------------------
64 |
--------------------------------------------------------------------------------
/training/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | # empty
12 |
--------------------------------------------------------------------------------
/training/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/training/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/training/__pycache__/networks_stylegan2.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/training/__pycache__/networks_stylegan2.cpython-38.pyc
--------------------------------------------------------------------------------
/training/__pycache__/networks_stylegan3.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/training/__pycache__/networks_stylegan3.cpython-38.pyc
--------------------------------------------------------------------------------
/training/__pycache__/ranger.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/training/__pycache__/ranger.cpython-38.pyc
--------------------------------------------------------------------------------
/training/__pycache__/superresolution.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/training/__pycache__/superresolution.cpython-38.pyc
--------------------------------------------------------------------------------
/training/__pycache__/triplane.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/training/__pycache__/triplane.cpython-38.pyc
--------------------------------------------------------------------------------
/training/crosssection_utils.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | import torch
12 |
13 | def sample_cross_section(G, ws, resolution=256, w=1.2):
14 | axis=0
15 | A, B = torch.meshgrid(torch.linspace(w/2, -w/2, resolution, device=ws.device), torch.linspace(-w/2, w/2, resolution, device=ws.device), indexing='ij')
16 | A, B = A.reshape(-1, 1), B.reshape(-1, 1)
17 | C = torch.zeros_like(A)
18 | coordinates = [A, B]
19 | coordinates.insert(axis, C)
20 | coordinates = torch.cat(coordinates, dim=-1).expand(ws.shape[0], -1, -1)
21 |
22 | sigma = G.sample_mixed(coordinates, torch.randn_like(coordinates), ws)['sigma']
23 | return sigma.reshape(-1, 1, resolution, resolution)
24 |
25 | # if __name__ == '__main__':
26 | # sample_crossection(None)
--------------------------------------------------------------------------------
/training/projectors/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/training/projectors/__init__.py
--------------------------------------------------------------------------------
/training/projectors/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/training/projectors/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/training/projectors/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/training/projectors/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/training/projectors/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/training/projectors/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/training/projectors/__pycache__/w_plus_projector.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/training/projectors/__pycache__/w_plus_projector.cpython-38.pyc
--------------------------------------------------------------------------------
/training/projectors/__pycache__/w_plus_projector.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/training/projectors/__pycache__/w_plus_projector.cpython-39.pyc
--------------------------------------------------------------------------------
/training/projectors/__pycache__/w_projector.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/training/projectors/__pycache__/w_projector.cpython-36.pyc
--------------------------------------------------------------------------------
/training/projectors/__pycache__/w_projector.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/training/projectors/__pycache__/w_projector.cpython-38.pyc
--------------------------------------------------------------------------------
/training/projectors/__pycache__/w_projector.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/training/projectors/__pycache__/w_projector.cpython-39.pyc
--------------------------------------------------------------------------------
/training/projectors/w_plus_projector.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Project given image to the latent space of pretrained network pickle."""
10 |
11 | import copy
12 | import wandb
13 | import numpy as np
14 | import torch
15 | import torch.nn.functional as F
16 | from tqdm import tqdm
17 | from configs import global_config, hyperparameters
18 | import dnnlib
19 | from utils.log_utils import log_image_from_w
20 | import PIL.Image
21 |
22 |
23 |
24 | def project(
25 | G,
26 | target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
27 | parameters,
28 | *,
29 | num_steps=1000,
30 | w_avg_samples=10000,
31 | initial_learning_rate=0.01,
32 | initial_noise_factor=0.05,
33 | lr_rampdown_length=0.25,
34 | lr_rampup_length=0.05,
35 | noise_ramp_length=0.75,
36 | regularize_noise_weight=1e5,
37 | verbose=False,
38 | device: torch.device,
39 | use_wandb=False,
40 | initial_w=None,
41 | image_log_step=global_config.image_rec_result_log_snapshot,
42 | w_name: str
43 | ):
44 | assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)
45 |
46 | def logprint(*args):
47 | if verbose:
48 | print(*args)
49 |
50 | G = copy.deepcopy(G).eval().requires_grad_(False).to(device).float() # type: ignore
51 |
52 | # Compute w stats.
53 | logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
54 | z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
55 | w_samples = G.mapping(torch.from_numpy(z_samples).to(device), parameters[3].repeat([w_avg_samples, 1]), truncation_psi=parameters[1], truncation_cutoff=parameters[2])
56 | w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32) # [N, 1, C]
57 | w_avg = np.mean(w_samples, axis=0, keepdims=True) # [1, 1, C]
58 | w_avg_tensor = torch.from_numpy(w_avg).to(global_config.device)
59 | w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5
60 |
61 | start_w = initial_w if initial_w is not None else w_avg
62 |
63 | # Setup noise inputs.
64 | noise_bufs = {name: buf for (name, buf) in G.backbone.synthesis.named_buffers() if 'noise_const' in name}
65 |
66 | # Load VGG16 feature detector.
67 | url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
68 | with dnnlib.util.open_url(url) as f:
69 | vgg16 = torch.jit.load(f).eval().to(device)
70 |
71 | # Features for target image.
72 | target_images = target.unsqueeze(0).to(device).to(torch.float32)
73 | if target_images.shape[2] > 256:
74 | target_images = F.interpolate(target_images, size=(256, 256), mode='area')
75 | target_features = vgg16(target_images, resize_images=False, return_lpips=True)
76 |
77 | start_w = np.repeat(start_w, G.backbone.mapping.num_ws, axis=1)
78 | w_opt = torch.tensor(start_w, dtype=torch.float32, device=device,
79 | requires_grad=True) # pylint: disable=not-callable
80 |
81 | optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999),
82 | lr=hyperparameters.first_inv_lr)
83 |
84 | # Init noise.
85 | for buf in noise_bufs.values():
86 | buf[:] = torch.randn_like(buf)
87 | buf.requires_grad = True
88 |
89 | for step in tqdm(range(num_steps)):
90 |
91 | # Learning rate schedule.
92 | t = step / num_steps
93 | w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
94 | lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
95 | lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
96 | lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
97 | lr = initial_learning_rate * lr_ramp
98 | for param_group in optimizer.param_groups:
99 | param_group['lr'] = lr
100 |
101 | # Synth images from opt_w.
102 | w_noise = torch.randn_like(w_opt) * w_noise_scale
103 | ws = (w_opt + w_noise)
104 |
105 | synth_images = G.synthesis(ws, parameters[0])['image']
106 |
107 | # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
108 | synth_images = (synth_images + 1) * (255 / 2)
109 | if synth_images.shape[2] > 256:
110 | synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')
111 |
112 | # Features for synth images.
113 | synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
114 | dist = (target_features - synth_features).square().sum()
115 |
116 | img = (synth_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
117 | PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'test_inversion/' + str(step) + '.png')
118 |
119 | # Noise regularization.
120 | reg_loss = 0.0
121 | for v in noise_bufs.values():
122 | noise = v[None, None, :, :] # must be [1,1,H,W] for F.avg_pool2d()
123 | while True:
124 | reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2
125 | reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2
126 | if noise.shape[2] <= 8:
127 | break
128 | noise = F.avg_pool2d(noise, kernel_size=2)
129 | loss = dist + reg_loss * regularize_noise_weight
130 |
131 | if step % image_log_step == 0:
132 | with torch.no_grad():
133 | if use_wandb:
134 | global_config.training_step += 1
135 | wandb.log({f'first projection _{w_name}': loss.detach().cpu()}, step=global_config.training_step)
136 | log_image_from_w(w_opt, G, w_name)
137 |
138 | # Step
139 | optimizer.zero_grad(set_to_none=True)
140 | loss.backward()
141 | optimizer.step()
142 | logprint(f'step {step + 1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}')
143 |
144 | # Normalize noise.
145 | with torch.no_grad():
146 | for buf in noise_bufs.values():
147 | buf -= buf.mean()
148 | buf *= buf.square().mean().rsqrt()
149 |
150 | del G
151 | return w_opt
152 |
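From the way `parameters` is indexed above, it is expected to hold, in order: the rendering camera passed to G.synthesis, a truncation psi, a truncation cutoff, and the conditioning pose fed to G.mapping. A hypothetical call, with all names and values illustrative rather than taken from the repository:

import torch
from training.projectors import w_plus_projector

# Assumptions: `G` is an already-loaded EG3D-style generator and `target` is a
# [C, H, W] tensor in [0, 255] at G.img_resolution; both are placeholders here.
camera = torch.zeros(1, 25, device='cuda')       # rendering camera for G.synthesis
cond_pose = torch.zeros(1, 25, device='cuda')    # conditioning pose for G.mapping
parameters = [camera, 0.7, 14, cond_pose]        # [render cam, psi, cutoff, cond cam]

# Note: this projector writes one PNG per step into ./test_inversion/,
# so that directory needs to exist before calling project().
w_plus = w_plus_projector.project(
    G, target, parameters,
    num_steps=500,
    device=torch.device('cuda'),
    w_name='example',
)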
--------------------------------------------------------------------------------
/training/projectors/w_projector.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Project given image to the latent space of pretrained network pickle."""
10 |
11 | import copy
12 | import wandb
13 | import numpy as np
14 | import torch
15 | import torch.nn.functional as F
16 | from tqdm import tqdm
17 | from configs import global_config, hyperparameters
18 | from utils import log_utils
19 | import dnnlib
20 |
21 |
22 | def project(
23 | G,
24 | target: torch.Tensor,# [C,H,W] and dynamic range [0,255], W & H must match G output resolution
25 | parameters,
26 | *,
27 | num_steps=1000,
28 | w_avg_samples=10000,
29 | initial_learning_rate=0.01,
30 | initial_noise_factor=0.05,
31 | lr_rampdown_length=0.25,
32 | lr_rampup_length=0.05,
33 | noise_ramp_length=0.75,
34 | regularize_noise_weight=1e5,
35 | verbose=False,
36 | device: torch.device,
37 | use_wandb=False,
38 | initial_w=None,
39 | image_log_step=global_config.image_rec_result_log_snapshot,
40 | w_name: str
41 | ):
42 | assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)
43 |
44 | def logprint(*args):
45 | if verbose:
46 | print(*args)
47 |
48 | G = copy.deepcopy(G).eval().requires_grad_(False).to(device).float() # type: ignore
49 |
50 | # Compute w stats.
51 | logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
52 | z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
53 | w_samples = G.mapping(torch.from_numpy(z_samples).to(device), parameters[3].repeat([w_avg_samples, 1]), truncation_psi=parameters[1], truncation_cutoff=parameters[2]) # [N, L, C]
54 | w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32) # [N, 1, C]
55 | w_avg = np.mean(w_samples, axis=0, keepdims=True) # [1, 1, C]
56 | w_avg_tensor = torch.from_numpy(w_avg).to(global_config.device)
57 | w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5
58 |
59 | start_w = initial_w if initial_w is not None else w_avg
60 |
61 | # Setup noise inputs.
62 | noise_bufs = {name: buf for (name, buf) in G.backbone.synthesis.named_buffers() if 'noise_const' in name}
63 |
64 | # Load VGG16 feature detector.
65 | url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
66 | with dnnlib.util.open_url(url) as f:
67 | vgg16 = torch.jit.load(f).eval().to(device)
68 |
69 | # Features for target image.
70 | target_images = target.unsqueeze(0).to(device).to(torch.float32)
71 | if target_images.shape[2] > 256:
72 | target_images = F.interpolate(target_images, size=(256, 256), mode='area')
73 | target_features = vgg16(target_images, resize_images=False, return_lpips=True)
74 |
75 | w_opt = torch.tensor(start_w, dtype=torch.float32, device=device,
76 | requires_grad=True) # pylint: disable=not-callable
77 | optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999),
78 | lr=hyperparameters.first_inv_lr)
79 |
80 | # Init noise.
81 | for buf in noise_bufs.values():
82 | buf[:] = torch.randn_like(buf)
83 | buf.requires_grad = True
84 |
85 | for step in tqdm(range(num_steps)):
86 |
87 | # Learning rate schedule.
88 | t = step / num_steps
89 | w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
90 | lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
91 | lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
92 | lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
93 | lr = initial_learning_rate * lr_ramp
94 | for param_group in optimizer.param_groups:
95 | param_group['lr'] = lr
96 |
97 | # Synth images from opt_w.
98 | w_noise = torch.randn_like(w_opt) * w_noise_scale
99 | ws = (w_opt + w_noise).repeat([1, 14, 1])
100 | synth_images = G.synthesis(ws, parameters[0])['image']
101 |
102 | # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
103 | synth_images = (synth_images + 1) * (255 / 2)
104 | if synth_images.shape[2] > 256:
105 | synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')
106 |
107 | # Features for synth images.
108 | synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
109 | dist = (target_features - synth_features).square().sum()
110 |
111 | # Noise regularization.
112 | reg_loss = 0.0
113 | for v in noise_bufs.values():
114 | noise = v[None, None, :, :] # must be [1,1,H,W] for F.avg_pool2d()
115 | while True:
116 | reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2
117 | reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2
118 | if noise.shape[2] <= 8:
119 | break
120 | noise = F.avg_pool2d(noise, kernel_size=2)
121 | loss = dist + reg_loss * regularize_noise_weight
122 |
123 | if step % image_log_step == 0:
124 | with torch.no_grad():
125 | if use_wandb:
126 | global_config.training_step += 1
127 | wandb.log({f'first projection _{w_name}': loss.detach().cpu()}, step=global_config.training_step)
128 | log_utils.log_image_from_w(w_opt.repeat([1, G.mapping.num_ws, 1]), G, w_name)
129 |
130 | # Step
131 | optimizer.zero_grad(set_to_none=True)
132 | loss.backward()
133 | optimizer.step()
134 | logprint(f'step {step + 1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}')
135 |
136 | # Normalize noise.
137 | with torch.no_grad():
138 | for buf in noise_bufs.values():
139 | buf -= buf.mean()
140 | buf *= buf.square().mean().rsqrt()
141 |
142 | del G
143 | return w_opt.repeat([1, 14, 1])
144 |
--------------------------------------------------------------------------------
/training/ranger.py:
--------------------------------------------------------------------------------
1 | # Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer.
2 |
3 | # https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
4 | # and/or
5 | # https://github.com/lessw2020/Best-Deep-Learning-Optimizers
6 |
7 | # Ranger has now been used to capture 12 records on the FastAI leaderboard.
8 |
9 | # This version = 20.4.11
10 |
11 | # Credits:
12 | # Gradient Centralization --> https://arxiv.org/abs/2004.01461v2 (a new optimization technique for DNNs), github: https://github.com/Yonghongwei/Gradient-Centralization
13 | # RAdam --> https://github.com/LiyuanLucasLiu/RAdam
14 | # Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code.
15 | # Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610
16 |
17 | # summary of changes:
18 | # 4/11/20 - add gradient centralization option. Set new testing benchmark for accuracy with it, toggle with use_gc flag at init.
19 | # full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights),
20 | # supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues.
21 | # changes 8/31/19 - fix references to *self*.N_sma_threshold;
22 | # changed eps to 1e-5 as better default than 1e-8.
23 |
24 | import math
25 | import torch
26 | from torch.optim.optimizer import Optimizer
27 |
28 |
29 | class Ranger(Optimizer):
30 |
31 | def __init__(self, params, lr=1e-3, # lr
32 | alpha=0.5, k=6, N_sma_threshhold=5, # Ranger options
33 | betas=(.95, 0.999), eps=1e-5, weight_decay=0, # Adam options
34 | use_gc=True, gc_conv_only=False
35 | # Gradient centralization on or off, applied to conv layers only or conv + fc layers
36 | ):
37 |
38 | # parameter checks
39 | if not 0.0 <= alpha <= 1.0:
40 | raise ValueError(f'Invalid slow update rate: {alpha}')
41 | if not 1 <= k:
42 | raise ValueError(f'Invalid lookahead steps: {k}')
43 | if not lr > 0:
44 | raise ValueError(f'Invalid Learning Rate: {lr}')
45 | if not eps > 0:
46 | raise ValueError(f'Invalid eps: {eps}')
47 |
48 | # parameter comments:
49 | # beta1 (momentum) of .95 seems to work better than .90...
50 | # N_sma_threshold of 5 seems better in testing than 4.
51 | # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you.
52 |
53 | # prep defaults and init torch.optim base
54 | defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold,
55 | eps=eps, weight_decay=weight_decay)
56 | super().__init__(params, defaults)
57 |
58 | # adjustable threshold
59 | self.N_sma_threshhold = N_sma_threshhold
60 |
61 | # look ahead params
62 |
63 | self.alpha = alpha
64 | self.k = k
65 |
66 | # radam buffer for state
67 | self.radam_buffer = [[None, None, None] for ind in range(10)]
68 |
69 | # gc on or off
70 | self.use_gc = use_gc
71 |
72 | # level of gradient centralization
73 | self.gc_gradient_threshold = 3 if gc_conv_only else 1
74 |
75 | def __setstate__(self, state):
76 | super(Ranger, self).__setstate__(state)
77 |
78 | def step(self, closure=None):
79 | loss = None
80 |
81 | # Evaluate averages and grad, update param tensors
82 | for group in self.param_groups:
83 |
84 | for p in group['params']:
85 | if p.grad is None:
86 | continue
87 | grad = p.grad.data.float()
88 |
89 | if grad.is_sparse:
90 | raise RuntimeError('Ranger optimizer does not support sparse gradients')
91 |
92 | p_data_fp32 = p.data.float()
93 |
94 | state = self.state[p] # get state dict for this param
95 |
96 | if len(state) == 0: # if first time to run...init dictionary with our desired entries
97 | # if self.first_run_check==0:
98 | # self.first_run_check=1
99 | # print("Initializing slow buffer...should not see this at load from saved model!")
100 | state['step'] = 0
101 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
102 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
103 |
104 | # look ahead weight storage now in state dict
105 | state['slow_buffer'] = torch.empty_like(p.data)
106 | state['slow_buffer'].copy_(p.data)
107 |
108 | else:
109 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
110 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
111 |
112 | # begin computations
113 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
114 | beta1, beta2 = group['betas']
115 |
116 | # GC operation for Conv layers and FC layers
117 | if grad.dim() > self.gc_gradient_threshold:
118 | grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True))
119 |
120 | state['step'] += 1
121 |
122 | # compute variance mov avg
123 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
124 | # compute mean moving avg
125 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
126 |
127 | buffered = self.radam_buffer[int(state['step'] % 10)]
128 |
129 | if state['step'] == buffered[0]:
130 | N_sma, step_size = buffered[1], buffered[2]
131 | else:
132 | buffered[0] = state['step']
133 | beta2_t = beta2 ** state['step']
134 | N_sma_max = 2 / (1 - beta2) - 1
135 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
136 | buffered[1] = N_sma
137 | if N_sma > self.N_sma_threshhold:
138 | step_size = math.sqrt(
139 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
140 | N_sma_max - 2)) / (1 - beta1 ** state['step'])
141 | else:
142 | step_size = 1.0 / (1 - beta1 ** state['step'])
143 | buffered[2] = step_size
144 |
145 | if group['weight_decay'] != 0:
146 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
147 |
148 | # apply lr
149 | if N_sma > self.N_sma_threshhold:
150 | denom = exp_avg_sq.sqrt().add_(group['eps'])
151 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom)
152 | else:
153 | p_data_fp32.add_(-step_size * group['lr'], exp_avg)
154 |
155 | p.data.copy_(p_data_fp32)
156 |
157 | # integrated look ahead...
158 | # we do it at the param level instead of group level
159 | if state['step'] % group['k'] == 0:
160 | slow_p = state['slow_buffer'] # get access to slow param tensor
161 | slow_p.add_(self.alpha, p.data - slow_p) # (fast weights - slow weights) * alpha
162 | p.data.copy_(slow_p) # copy interpolated weights to RAdam param tensor
163 |
164 | return loss
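A minimal usage sketch, directly following the constructor signature above; the model is an illustrative stand-in. Note that the step() body uses the older positional add_/addcmul_ call signatures, so it assumes a PyTorch 1.x-era runtime:

import torch
from training.ranger import Ranger

model = torch.nn.Linear(16, 1)
optimizer = Ranger(model.parameters(), lr=1e-3, alpha=0.5, k=6, use_gc=True)

for _ in range(10):
    loss = model(torch.randn(8, 16)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()   # RAdam update, plus a Lookahead sync every k steps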
--------------------------------------------------------------------------------
/training/volumetric_rendering/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | # empty
--------------------------------------------------------------------------------
/training/volumetric_rendering/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/training/volumetric_rendering/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/training/volumetric_rendering/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/training/volumetric_rendering/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/training/volumetric_rendering/__pycache__/math_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/training/volumetric_rendering/__pycache__/math_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/training/volumetric_rendering/__pycache__/math_utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/training/volumetric_rendering/__pycache__/math_utils.cpython-39.pyc
--------------------------------------------------------------------------------
/training/volumetric_rendering/__pycache__/ray_marcher.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/training/volumetric_rendering/__pycache__/ray_marcher.cpython-38.pyc
--------------------------------------------------------------------------------
/training/volumetric_rendering/__pycache__/ray_marcher.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/training/volumetric_rendering/__pycache__/ray_marcher.cpython-39.pyc
--------------------------------------------------------------------------------
/training/volumetric_rendering/__pycache__/ray_sampler.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/training/volumetric_rendering/__pycache__/ray_sampler.cpython-38.pyc
--------------------------------------------------------------------------------
/training/volumetric_rendering/__pycache__/ray_sampler.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/training/volumetric_rendering/__pycache__/ray_sampler.cpython-39.pyc
--------------------------------------------------------------------------------
/training/volumetric_rendering/__pycache__/renderer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/training/volumetric_rendering/__pycache__/renderer.cpython-38.pyc
--------------------------------------------------------------------------------
/training/volumetric_rendering/__pycache__/renderer.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/training/volumetric_rendering/__pycache__/renderer.cpython-39.pyc
--------------------------------------------------------------------------------
/training/volumetric_rendering/math_utils.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) 2022 Petr Kellnhofer
4 |
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 |
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 |
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | import torch
24 |
25 | def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor:
26 | """
27 | Left-multiplies MxM @ NxM. Returns NxM.
28 | """
29 | res = torch.matmul(vectors4, matrix.T)
30 | return res
31 |
32 |
33 | def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor:
34 | """
35 | Normalize vector lengths.
36 | """
37 | return vectors / (torch.norm(vectors, dim=-1, keepdim=True))
38 |
39 | def torch_dot(x: torch.Tensor, y: torch.Tensor):
40 | """
41 | Dot product of two tensors.
42 | """
43 | return (x * y).sum(-1)
44 |
45 |
46 | def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length):
47 | """
48 | Author: Petr Kellnhofer
49 | Intersects rays with the [-1, 1] NDC volume.
50 | Returns min and max distance of entry.
51 | Returns -1 for no intersection.
52 | https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection
53 | """
54 | o_shape = rays_o.shape
55 | rays_o = rays_o.detach().reshape(-1, 3)
56 | rays_d = rays_d.detach().reshape(-1, 3)
57 |
58 |
59 | bb_min = [-1*(box_side_length/2), -1*(box_side_length/2), -1*(box_side_length/2)]
60 | bb_max = [1*(box_side_length/2), 1*(box_side_length/2), 1*(box_side_length/2)]
61 | bounds = torch.tensor([bb_min, bb_max], dtype=rays_o.dtype, device=rays_o.device)
62 | is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device)
63 |
64 | # Precompute inverse for stability.
65 | invdir = 1 / rays_d
66 | sign = (invdir < 0).long()
67 |
68 | # Intersect with YZ plane.
69 | tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]
70 | tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]
71 |
72 | # Intersect with XZ plane.
73 | tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]
74 | tymax = (bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]
75 |
76 | # Resolve parallel rays.
77 | is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False
78 |
79 | # Use the shortest intersection.
80 | tmin = torch.max(tmin, tymin)
81 | tmax = torch.min(tmax, tymax)
82 |
83 | # Intersect with XY plane.
84 | tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]
85 | tzmax = (bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]
86 |
87 | # Resolve parallel rays.
88 | is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False
89 |
90 | # Use the shortest intersection.
91 | tmin = torch.max(tmin, tzmin)
92 | tmax = torch.min(tmax, tzmax)
93 |
94 | # Mark invalid.
95 | tmin[torch.logical_not(is_valid)] = -1
96 | tmax[torch.logical_not(is_valid)] = -2
97 |
98 | return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1)
99 |
100 |
101 | def linspace(start: torch.Tensor, stop: torch.Tensor, num: int):
102 | """
103 | Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to end, inclusive.
104 | Replicates the multi-dimensional behaviour of numpy.linspace in PyTorch.
105 | """
106 | # create a tensor of 'num' steps from 0 to 1
107 | steps = torch.arange(num, dtype=torch.float32, device=start.device) / (num - 1)
108 |
109 | # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcasting
110 | # - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript
111 | # "cannot statically infer the expected size of a list in this context", hence the code below
112 | for i in range(start.ndim):
113 | steps = steps.unsqueeze(-1)
114 |
115 | # the output starts at 'start' and increments until 'stop' in each dimension
116 | out = start[None] + steps * (stop - start)[None]
117 |
118 | return out
119 |
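For instance, the batched linspace above broadcasts over arbitrary start/stop shapes, which is how per-ray depth values are generated; a small sketch with illustrative shapes:

import torch
from training.volumetric_rendering.math_utils import linspace

start = torch.zeros(2, 3)               # e.g. per-ray near depths
stop = torch.full((2, 3), 4.0)          # e.g. per-ray far depths
depths = linspace(start, stop, num=5)   # shape [5, 2, 3]
print(depths[:, 0, 0])                  # tensor([0., 1., 2., 3., 4.])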
--------------------------------------------------------------------------------
/training/volumetric_rendering/ray_marcher.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | """
12 | The ray marcher takes the raw output of the implicit representation and uses the volume rendering equation to produce composited colors and depths.
13 | Based off of the implementation in MipNeRF (this one doesn't do any cone tracing though!)
14 | """
15 |
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.functional as F
19 |
20 | class MipRayMarcher2(nn.Module):
21 | def __init__(self):
22 | super().__init__()
23 |
24 |
25 | def run_forward(self, colors, densities, depths, rendering_options):
26 | deltas = depths[:, :, 1:] - depths[:, :, :-1]
27 | colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2
28 | densities_mid = (densities[:, :, :-1] + densities[:, :, 1:]) / 2
29 | depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2
30 |
31 |
32 | if rendering_options['clamp_mode'] == 'softplus':
33 | densities_mid = F.softplus(densities_mid - 1) # activation bias of -1 makes things initialize better
34 | else:
35 | assert False, "MipRayMarcher only supports `clamp_mode`=`softplus`!"
36 |
37 | density_delta = densities_mid * deltas
38 |
39 | alpha = 1 - torch.exp(-density_delta)
40 |
41 | alpha_shifted = torch.cat([torch.ones_like(alpha[:, :, :1]), 1-alpha + 1e-10], -2)
42 | weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1]
43 |
44 | composite_rgb = torch.sum(weights * colors_mid, -2)
45 | weight_total = weights.sum(2)
46 | composite_depth = torch.sum(weights * depths_mid, -2) / weight_total
47 |
48 | # clip the composite to min/max range of depths
49 | composite_depth = torch.nan_to_num(composite_depth, float('inf'))
50 | composite_depth = torch.clamp(composite_depth, torch.min(depths), torch.max(depths))
51 |
52 | if rendering_options.get('white_back', False):
53 | composite_rgb = composite_rgb + 1 - weight_total
54 |
55 | composite_rgb = composite_rgb * 2 - 1 # Scale to (-1, 1)
56 |
57 | return composite_rgb, composite_depth, weights
58 |
59 |
60 | def forward(self, colors, densities, depths, rendering_options):
61 | composite_rgb, composite_depth, weights = self.run_forward(colors, densities, depths, rendering_options)
62 |
63 | return composite_rgb, composite_depth, weights
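In equation form, run_forward above is standard alpha compositing over the midpoint samples: with densities \(\sigma_i\), interval lengths \(\delta_i\), colors \(c_i\) and depths \(d_i\),

\[
\alpha_i = 1 - e^{-\sigma_i \delta_i}, \qquad
w_i = \alpha_i \prod_{j < i} (1 - \alpha_j), \qquad
C = \sum_i w_i\, c_i, \qquad
D = \frac{\sum_i w_i\, d_i}{\sum_i w_i},
\]

after which an optional white background fills the remaining weight and the composited RGB is rescaled from [0, 1] to [-1, 1].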
--------------------------------------------------------------------------------
/training/volumetric_rendering/ray_sampler.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | """
12 | The ray sampler is a module that takes in camera matrices and a resolution and returns batches of rays.
13 | Expects cam2world matrices that use the OpenCV camera coordinate system conventions.
14 | """
15 |
16 | import torch
17 |
18 | class RaySampler(torch.nn.Module):
19 | def __init__(self):
20 | super().__init__()
21 | self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None
22 |
23 |
24 | def forward(self, cam2world_matrix, intrinsics, resolution):
25 | """
26 | Create batches of rays and return origins and directions.
27 |
28 | cam2world_matrix: (N, 4, 4)
29 | intrinsics: (N, 3, 3)
30 | resolution: int
31 |
32 | ray_origins: (N, M, 3)
33 | ray_dirs: (N, M, 3)
34 | """
35 | N, M = cam2world_matrix.shape[0], resolution**2
36 | cam_locs_world = cam2world_matrix[:, :3, 3]
37 | fx = intrinsics[:, 0, 0]
38 | fy = intrinsics[:, 1, 1]
39 | cx = intrinsics[:, 0, 2]
40 | cy = intrinsics[:, 1, 2]
41 | sk = intrinsics[:, 0, 1]
42 |
43 | uv = torch.stack(torch.meshgrid(torch.arange(resolution, dtype=torch.float32, device=cam2world_matrix.device), torch.arange(resolution, dtype=torch.float32, device=cam2world_matrix.device))) * (1./resolution) + (0.5/resolution)
44 | uv = uv.flip(0).reshape(2, -1).transpose(1, 0)
45 | uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1)
46 |
47 | x_cam = uv[:, :, 0].view(N, -1)
48 | y_cam = uv[:, :, 1].view(N, -1)
49 | z_cam = torch.ones((N, M), device=cam2world_matrix.device)
50 |
51 | x_lift = (x_cam - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y_cam/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z_cam
52 | y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam
53 |
54 | cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1)
55 |
56 | world_rel_points = torch.bmm(cam2world_matrix, cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3]
57 |
58 | ray_dirs = world_rel_points - cam_locs_world[:, None, :]
59 | ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2)
60 |
61 | ray_origins = cam_locs_world.unsqueeze(1).repeat(1, ray_dirs.shape[1], 1)
62 |
63 | return ray_origins, ray_dirs
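The per-pixel lifting above inverts the intrinsics \(K = \begin{pmatrix} f_x & s & c_x \\ 0 & f_y & c_y \\ 0 & 0 & 1 \end{pmatrix}\) at unit camera depth:

\[
x_{\mathrm{lift}} = \frac{x_{\mathrm{cam}} - c_x + \tfrac{c_y\, s}{f_y} - \tfrac{s\, y_{\mathrm{cam}}}{f_y}}{f_x}\, z_{\mathrm{cam}},
\qquad
y_{\mathrm{lift}} = \frac{y_{\mathrm{cam}} - c_y}{f_y}\, z_{\mathrm{cam}},
\qquad z_{\mathrm{cam}} = 1,
\]

and each ray direction is the normalized difference between the world-space point cam2world \(\cdot\, [x_{\mathrm{lift}}, y_{\mathrm{lift}}, 1, 1]^\top\) and the camera origin.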
--------------------------------------------------------------------------------
/utils/__pycache__/camera_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/utils/__pycache__/camera_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/data_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/utils/__pycache__/data_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/legacy.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyx0208/3dSwap/e451c50c089132317b496beea3da30b9299807e1/utils/__pycache__/legacy.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/camera_utils.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 | """
12 | Helper functions for constructing camera parameter matrices. Primarily used in visualization and inference scripts.
13 | """
14 |
15 | import math
16 |
17 | import torch
18 | import torch.nn as nn
19 |
20 | from training.volumetric_rendering import math_utils
21 |
22 | class GaussianCameraPoseSampler:
23 | """
24 | Samples pitch and yaw from a Gaussian distribution and returns a camera pose.
25 | Camera is specified as looking at the origin.
26 | If horizontal and vertical stddev (specified in radians) are zero, gives a
27 | deterministic camera pose with yaw=horizontal_mean, pitch=vertical_mean.
28 | The coordinate system is specified with y-up, z-forward, x-left.
29 | Horizontal mean is the azimuthal angle (rotation around y axis) in radians,
30 | vertical mean is the polar angle (angle from the y axis) in radians.
31 | A point along the z-axis has azimuthal_angle=0, polar_angle=pi/2.
32 |
33 | Example:
34 | For a camera pose looking at the origin with the camera at position [0, 0, 1]:
35 | cam2world = GaussianCameraPoseSampler.sample(math.pi/2, math.pi/2, radius=1)
36 | """
37 |
38 | @staticmethod
39 | def sample(horizontal_mean, vertical_mean, horizontal_stddev=0, vertical_stddev=0, radius=1, batch_size=1, device='cpu'):
40 | h = torch.randn((batch_size, 1), device=device) * horizontal_stddev + horizontal_mean
41 | v = torch.randn((batch_size, 1), device=device) * vertical_stddev + vertical_mean
42 | v = torch.clamp(v, 1e-5, math.pi - 1e-5)
43 |
44 | theta = h
45 | v = v / math.pi
46 | phi = torch.arccos(1 - 2*v)
47 |
48 | camera_origins = torch.zeros((batch_size, 3), device=device)
49 |
50 | camera_origins[:, 0:1] = radius*torch.sin(phi) * torch.cos(math.pi-theta)
51 | camera_origins[:, 2:3] = radius*torch.sin(phi) * torch.sin(math.pi-theta)
52 | camera_origins[:, 1:2] = radius*torch.cos(phi)
53 |
54 | forward_vectors = math_utils.normalize_vecs(-camera_origins)
55 | return create_cam2world_matrix(forward_vectors, camera_origins)
56 |
57 |
58 | class LookAtPoseSampler:
59 | """
60 | Same as GaussianCameraPoseSampler, except the
61 | camera is specified as looking at 'lookat_position', a 3-vector.
62 |
63 | Example:
64 | For a camera pose looking at the origin with the camera at position [0, 0, 1]:
65 | cam2world = LookAtPoseSampler.sample(math.pi/2, math.pi/2, torch.tensor([0, 0, 0]), radius=1)
66 | """
67 |
68 | @staticmethod
69 | def sample(horizontal_mean, vertical_mean, lookat_position, horizontal_stddev=0, vertical_stddev=0, radius=1, batch_size=1, device='cpu'):
70 | h = torch.randn((batch_size, 1), device=device) * horizontal_stddev + horizontal_mean
71 | v = torch.randn((batch_size, 1), device=device) * vertical_stddev + vertical_mean
72 | v = torch.clamp(v, 1e-5, math.pi - 1e-5)
73 |
74 | theta = h
75 | v = v / math.pi
76 | phi = torch.arccos(1 - 2*v)
77 |
78 | camera_origins = torch.zeros((batch_size, 3), device=device)
79 |
80 | camera_origins[:, 0:1] = radius*torch.sin(phi) * torch.cos(math.pi-theta)
81 | camera_origins[:, 2:3] = radius*torch.sin(phi) * torch.sin(math.pi-theta)
82 | camera_origins[:, 1:2] = radius*torch.cos(phi)
83 |
84 | # forward_vectors = math_utils.normalize_vecs(-camera_origins)
85 | forward_vectors = math_utils.normalize_vecs(lookat_position - camera_origins)
86 | return create_cam2world_matrix(forward_vectors, camera_origins)
87 |
88 | class UniformCameraPoseSampler:
89 | """
90 | Same as GaussianCameraPoseSampler, except the
91 | pose is sampled from a uniform distribution with range +-[horizontal/vertical]_stddev.
92 |
93 | Example:
94 | For a batch of random camera poses looking at the origin with yaw sampled from [-pi/2, +pi/2] radians:
95 |
96 | cam2worlds = UniformCameraPoseSampler.sample(math.pi/2, math.pi/2, horizontal_stddev=math.pi/2, radius=1, batch_size=16)
97 | """
98 |
99 | @staticmethod
100 | def sample(horizontal_mean, vertical_mean, horizontal_stddev=0, vertical_stddev=0, radius=1, batch_size=1, device='cpu'):
101 | h = (torch.rand((batch_size, 1), device=device) * 2 - 1) * horizontal_stddev + horizontal_mean
102 | v = (torch.rand((batch_size, 1), device=device) * 2 - 1) * vertical_stddev + vertical_mean
103 | v = torch.clamp(v, 1e-5, math.pi - 1e-5)
104 |
105 | theta = h
106 | v = v / math.pi
107 | phi = torch.arccos(1 - 2*v)
108 |
109 | camera_origins = torch.zeros((batch_size, 3), device=device)
110 |
111 | camera_origins[:, 0:1] = radius*torch.sin(phi) * torch.cos(math.pi-theta)
112 | camera_origins[:, 2:3] = radius*torch.sin(phi) * torch.sin(math.pi-theta)
113 | camera_origins[:, 1:2] = radius*torch.cos(phi)
114 |
115 | forward_vectors = math_utils.normalize_vecs(-camera_origins)
116 | return create_cam2world_matrix(forward_vectors, camera_origins)
117 |
118 | def create_cam2world_matrix(forward_vector, origin):
119 | """
120 | Takes in the direction the camera is pointing and the camera origin and returns a cam2world matrix.
121 | Works on batches of forward_vectors, origins. Assumes y-axis is up and that there is no camera roll.
122 | """
123 |
124 | forward_vector = math_utils.normalize_vecs(forward_vector)
125 | up_vector = torch.tensor([0, 1, 0], dtype=torch.float, device=origin.device).expand_as(forward_vector)
126 |
127 | right_vector = -math_utils.normalize_vecs(torch.cross(up_vector, forward_vector, dim=-1))
128 | up_vector = math_utils.normalize_vecs(torch.cross(forward_vector, right_vector, dim=-1))
129 |
130 | rotation_matrix = torch.eye(4, device=origin.device).unsqueeze(0).repeat(forward_vector.shape[0], 1, 1)
131 | rotation_matrix[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), axis=-1)
132 |
133 | translation_matrix = torch.eye(4, device=origin.device).unsqueeze(0).repeat(forward_vector.shape[0], 1, 1)
134 | translation_matrix[:, :3, 3] = origin
135 | cam2world = (translation_matrix @ rotation_matrix)[:, :, :]
136 | assert(cam2world.shape[1:] == (4, 4))
137 | return cam2world
138 |
139 |
140 | def FOV_to_intrinsics(fov_degrees, device='cpu'):
141 | """
142 | Creates a 3x3 camera intrinsics matrix from the camera field of view, specified in degrees.
143 | Note the intrinsics are returned as normalized by image size, rather than in pixel units.
144 | Assumes principal point is at image center.
145 | """
146 |
147 | focal_length = float(1 / (math.tan(fov_degrees * 3.14159 / 360) * 1.414))
148 | intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
149 | return intrinsics
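A short sketch of how these helpers are typically combined into the 25-dimensional camera conditioning vector used by EG3D-style generators (flattened 4x4 cam2world followed by the flattened 3x3 intrinsics). The layout and the radius/FOV values follow EG3D's sampling scripts and are assumptions with respect to this repository:

import math
import torch
from utils.camera_utils import LookAtPoseSampler, FOV_to_intrinsics

device = torch.device('cpu')
cam_pivot = torch.tensor([0.0, 0.0, 0.0], device=device)      # look-at point (assumed)
cam2world = LookAtPoseSampler.sample(math.pi / 2, math.pi / 2, cam_pivot,
                                     radius=2.7, device=device)   # [1, 4, 4]
intrinsics = FOV_to_intrinsics(18.837, device=device)             # [3, 3], normalized
camera_params = torch.cat([cam2world.reshape(-1, 16),
                           intrinsics.reshape(-1, 9)], dim=1)     # [1, 25]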
--------------------------------------------------------------------------------
/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Code adopted from pix2pixHD:
3 | https://github.com/NVIDIA/pix2pixHD/blob/master/data/image_folder.py
4 | """
5 | import os
6 |
7 | IMG_EXTENSIONS = [
8 | '.jpg', '.JPG', '.jpeg', '.JPEG',
9 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff', '.npy',
10 | '.pt'
11 | ]
12 |
13 |
14 | def is_image_file(filename):
15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
16 |
17 |
18 | def make_dataset(dir):
19 | images = []
20 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
21 | for root, _, fnames in sorted(os.walk(dir)):
22 | for fname in fnames:
23 | if is_image_file(fname):
24 | path = os.path.join(root, fname)
25 | images.append(path)
26 | return images
27 |
--------------------------------------------------------------------------------
/utils/shape_utils.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3 | #
4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | # property and proprietary rights in and to this material, related
6 | # documentation and any modifications thereto. Any use, reproduction,
7 | # disclosure or distribution of this material and related documentation
8 | # without an express license agreement from NVIDIA CORPORATION or
9 | # its affiliates is strictly prohibited.
10 |
11 |
12 | """
13 | Utils for extracting 3D shapes using marching cubes. Based on code from DeepSDF (Park et al.)
14 |
15 | Takes as input an .mrc file and extracts a mesh.
16 |
17 | Ex.
18 | python shape_utils.py my_shape.mrc
19 | Ex.
20 | python shape_utils.py myshapes_directory --level=12
21 | """
22 |
23 |
24 | import time
25 | import plyfile
26 | import glob
27 | import logging
28 | import numpy as np
29 | import os
30 | import random
31 | import torch
32 | import torch.utils.data
33 | import trimesh
34 | import skimage.measure
35 | import argparse
36 | import mrcfile
37 | from tqdm import tqdm
38 |
39 |
40 | def convert_sdf_samples_to_ply(
41 | numpy_3d_sdf_tensor,
42 | voxel_grid_origin,
43 | voxel_size,
44 | ply_filename_out,
45 | offset=None,
46 | scale=None,
47 | level=0.0
48 | ):
49 | """
50 | Convert SDF samples to a .ply mesh.
51 | :param numpy_3d_sdf_tensor: a numpy array of shape (n, n, n)
52 | :param voxel_grid_origin: a list of three floats: the bottom, left, down origin of the voxel grid
53 | :param voxel_size: float, the size of the voxels
54 | :param ply_filename_out: string, path of the filename to save to
55 | This function adapted from: https://github.com/RobotLocomotion/spartan
56 | """
57 | start_time = time.time()
58 |
59 | verts, faces, normals, values = np.zeros((0, 3)), np.zeros((0, 3)), np.zeros((0, 3)), np.zeros(0)
60 | # try:
61 | verts, faces, normals, values = skimage.measure.marching_cubes(
62 | numpy_3d_sdf_tensor, level=level, spacing=[voxel_size] * 3
63 | )
64 | # except:
65 | # pass
66 |
67 | # transform from voxel coordinates to camera coordinates
68 | # note x and y are flipped in the output of marching_cubes
69 | mesh_points = np.zeros_like(verts)
70 | mesh_points[:, 0] = voxel_grid_origin[0] + verts[:, 0]
71 | mesh_points[:, 1] = voxel_grid_origin[1] + verts[:, 1]
72 | mesh_points[:, 2] = voxel_grid_origin[2] + verts[:, 2]
73 |
74 | # apply additional offset and scale
75 | if scale is not None:
76 | mesh_points = mesh_points / scale
77 | if offset is not None:
78 | mesh_points = mesh_points - offset
79 |
80 | # try writing to the ply file
81 |
82 | num_verts = verts.shape[0]
83 | num_faces = faces.shape[0]
84 |
85 | verts_tuple = np.zeros((num_verts,), dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")])
86 |
87 | for i in range(0, num_verts):
88 | verts_tuple[i] = tuple(mesh_points[i, :])
89 |
90 | faces_building = []
91 | for i in range(0, num_faces):
92 | faces_building.append((faces[i, :].tolist(),))
93 | faces_tuple = np.array(faces_building, dtype=[("vertex_indices", "i4", (3,))])
94 |
95 | el_verts = plyfile.PlyElement.describe(verts_tuple, "vertex")
96 | el_faces = plyfile.PlyElement.describe(faces_tuple, "face")
97 |
98 | ply_data = plyfile.PlyData([el_verts, el_faces])
99 | ply_data.write(ply_filename_out)
100 | print(f"wrote to {ply_filename_out}")
101 |
102 |
103 | def convert_mrc(input_filename, output_filename, isosurface_level=1):
104 | with mrcfile.open(input_filename) as mrc:
105 | convert_sdf_samples_to_ply(np.transpose(mrc.data, (2, 1, 0)), [0, 0, 0], 1, output_filename, level=isosurface_level)
106 |
107 | if __name__ == '__main__':
108 | start_time = time.time()
109 | parser = argparse.ArgumentParser()
110 | parser.add_argument('input_mrc_path')
111 | parser.add_argument('--level', type=float, default=10, help="The isosurface level for marching cubes")
112 | args = parser.parse_args()
113 |
114 | if os.path.isfile(args.input_mrc_path) and args.input_mrc_path.split('.')[-1] == 'mrc':
115 | output_obj_path = args.input_mrc_path.split('.mrc')[0] + '.ply'
116 | convert_mrc(args.input_mrc_path, output_obj_path, isosurface_level=args.level)
117 |
118 | print(f"{time.time() - start_time:.2f} s")
119 | else:
120 | assert os.path.isdir(args.input_mrc_path)
121 |
122 | for mrc_path in tqdm(glob.glob(os.path.join(args.input_mrc_path, '*.mrc'))):
123 | output_obj_path = mrc_path.split('.mrc')[0] + '.ply'
124 | convert_mrc(mrc_path, output_obj_path, isosurface_level=args.level)
--------------------------------------------------------------------------------
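A minimal usage sketch (not part of the repository) for convert_sdf_samples_to_ply, run on a synthetic signed-distance volume of a sphere; the import path `utils.shape_utils`, the grid resolution, and the output filename are illustrative assumptions.

import numpy as np
from utils.shape_utils import convert_sdf_samples_to_ply  # assumed module path

N = 64
xs = np.linspace(-1.0, 1.0, N, dtype=np.float32)
grid = np.stack(np.meshgrid(xs, xs, xs, indexing='ij'), axis=-1)   # (N, N, N, 3) sample coordinates
sdf = np.linalg.norm(grid, axis=-1) - 0.5                          # signed distance to a radius-0.5 sphere

convert_sdf_samples_to_ply(
    sdf,
    voxel_grid_origin=[-1.0, -1.0, -1.0],   # world position of voxel (0, 0, 0)
    voxel_size=2.0 / (N - 1),               # spacing between neighbouring samples
    ply_filename_out='sphere.ply',
    level=0.0,                              # the SDF zero crossing is the surface
)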