├── .gitignore
├── README.md
├── data
│   ├── print_examples
│   │   ├── 1.jpg
│   │   ├── 2.jpg
│   │   ├── 3.jpg
│   │   ├── 4.jpg
│   │   └── 5.jpg
│   └── texture_examples
│       ├── 1.jpg
│       ├── 2.jpg
│       ├── 3.jpg
│       ├── 4.jpg
│       └── 5.jpg
├── docs
│   └── teaser.jpg
├── environment.yml
├── inference_print.py
├── inference_texture.py
└── pipeline.py
/.gitignore:
--------------------------------------------------------------------------------
1 | outputs/
2 | models/
3 | __pycache__/
4 | upload_model.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FabricDiffusion
2 | 
3 | [![paper](https://img.shields.io/badge/paper-SIGGRAPHAsia-cyan)](https://dl.acm.org/doi/10.1145/3680528.3687637)
4 | [![arXiv](https://img.shields.io/badge/arxiv-2410.01801-red)](https://arxiv.org/abs/2410.01801)
5 | [![webpage](https://img.shields.io/badge/webpage-green)](https://humansensinglab.github.io/fabric-diffusion/)
6 | [![video](https://img.shields.io/badge/video-orange)](https://youtu.be/xYiyjwldtWc)
7 | 
8 | 
9 | ## Overview
10 | 
11 | 

12 | <img src="docs/teaser.jpg" alt="FabricDiffusion teaser">
13 | 

14 | 
15 | > **FabricDiffusion: High-Fidelity Texture Transfer for 3D Garments Generation from In-The-Wild Images**
16 | > [Cheng Zhang*](https://czhang0528.github.io/),
17 | [Yuanhao Wang*](https://harrywang355.github.io/),
18 | [Francisco Vicente Carrasco](https://www.linkedin.com/in/francisco-vicente-carrasco-32a508144/),
19 | [Chenglei Wu](https://sites.google.com/view/chengleiwu/),
20 | [Jinlong Yang](https://is.mpg.de/~jyang),
21 | [Thabo Beeler](https://thabobeeler.com/),
22 | [Fernando De la Torre](https://www.cs.cmu.edu/~ftorre/)
23 | (* indicates equal contribution)
24 | > **SIGGRAPH Asia 2024**
25 | 
26 | 
31 | 
32 | 
33 | 
34 | ## Updates
35 | 
36 | 
37 | **[Jan 2 2025]** Inference code released.
38 | 
39 | **[Oct 2 2024]** Paper released on [arXiv](https://arxiv.org/pdf/2410.01801).
40 | 
41 | 
42 | 
54 | 
55 | 
56 | 
57 | 
58 | ## Installation
59 | Running the codebase only requires recent versions of [PyTorch](https://pytorch.org/get-started/locally/), [Diffusers](https://pypi.org/project/diffusers/), and [Transformers](https://pypi.org/project/transformers/):
60 | 
61 | ```shell
62 | git clone https://github.com/humansensinglab/fabric-diffusion.git
63 | cd fabric-diffusion
64 | conda env create --file=environment.yml
65 | conda activate fabric-diff
66 | ```
67 | If you prefer not to use conda, a pip-only setup is sketched at the end of this README.
68 | 
69 | ## Usage
70 | 
71 | **1. Texture Normalization**
72 | 
73 | Run the following command to normalize texture patches cropped from in-the-wild images:
74 | 
78 | 
79 | ```shell
80 | python inference_texture.py \
81 |     --texture_checkpoint='Yuanhao-Harry-Wang/fabric-diffusion-texture' \
82 |     --src_dir='data/texture_examples' \
83 |     --save_dir='outputs/texture' \
84 |     --n_samples=3
85 | ```
86 | - `--texture_checkpoint`: path or Hugging Face model ID of the pre-trained texture model checkpoint.
87 | - `--src_dir`: directory containing the input images.
88 | - `--save_dir`: directory where the outputs are written.
89 | - `--n_samples`: number of generated samples per input.
90 | 
91 | **2. Print Normalization**
92 | 
93 | Similarly, run the following command to normalize print patches cropped from in-the-wild images:
94 | 
95 | 
99 | 
100 | ```shell
101 | python inference_print.py \
102 |     --print_checkpoint='Yuanhao-Harry-Wang/fabric-diffusion-print' \
103 |     --src_dir='data/print_examples' \
104 |     --save_dir='outputs/print' \
105 |     --n_samples=3
106 | ```
107 | For both commands, results are saved to `--save_dir` as `<input name>_gen_<i>.png`.
108 | The model checkpoints are hosted on Hugging Face ([texture](https://huggingface.co/Yuanhao-Harry-Wang/fabric-diffusion-texture/tree/main), [print](https://huggingface.co/Yuanhao-Harry-Wang/fabric-diffusion-logo)).
109 | If you prefer to call the models from Python rather than through these scripts, a minimal sketch is included at the end of this README.
110 | We are actively adding more features to this repo. Please stay tuned!
111 | 
112 | 
113 | ## Acknowledgements
114 | - Models
115 |   - [Stable Diffusion](https://github.com/CompVis/stable-diffusion)
116 |   - [InstructPix2Pix](https://github.com/timothybrooks/instruct-pix2pix)
117 |   - [MatFusion](https://github.com/samsartor/matfusion)
118 | 
119 | ## Citation
120 | If you find this repo useful, please cite:
121 | ```
122 | @inproceedings{zhang2024fabricdiffusion,
123 |     title = {{FabricDiffusion}: High-Fidelity Texture Transfer for 3D Garments Generation from In-The-Wild Images},
124 |     author = {Zhang, Cheng and Wang, Yuanhao and Vicente Carrasco, Francisco and Wu, Chenglei and
125 |               Yang, Jinlong and Beeler, Thabo and De la Torre, Fernando},
126 |     booktitle = {ACM SIGGRAPH Asia},
127 |     year = {2024},
128 | }
129 | ```
130 | 
131 | ## License
132 | We use the X11 License. It is identical to the MIT License,
133 | but adds a sentence that prohibits using the copyright holders' names (Carnegie Mellon University and Google, in our case) for
134 | advertising or promotional purposes without written permission.
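## Appendix A: pip-only installation (sketch)

The conda environment above only pins pip packages (see `environment.yml`), so an equivalent setup should be possible with plain `pip`. The commands below are a minimal sketch, assuming Python >= 3.9 and a CUDA-capable GPU; the virtual-environment name `.venv` is arbitrary, and you may need to pick the `torch` wheel that matches your CUDA version.

```shell
# Minimal pip-only setup mirroring the pinned versions in environment.yml.
# Assumes Python >= 3.9 is available as `python`.
git clone https://github.com/humansensinglab/fabric-diffusion.git
cd fabric-diffusion
python -m venv .venv
source .venv/bin/activate
pip install torch==2.5.1 diffusers==0.32.1 transformers==4.47.1
```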
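## Appendix B: calling the pipeline from Python (sketch)

The inference scripts are thin wrappers around `FabricDiffusionPipeline` in `pipeline.py`, so the models can also be driven directly from Python. The snippet below is a minimal sketch mirroring `inference_texture.py`; the checkpoint ID and example image are the ones used in the Usage section, and the output directory is arbitrary.

```python
import os

from pipeline import FabricDiffusionPipeline

# Load the texture model only; pass a print checkpoint instead (or as well)
# to use flatten_print() in the same way.
pipe = FabricDiffusionPipeline(
    device="cuda:0",
    texture_checkpoint="Yuanhao-Harry-Wang/fabric-diffusion-texture",
    print_checkpoint=None,
)

# Read a cropped texture patch (converted to RGB and resized to 256x256 internally).
patch = pipe.load_patch_data("data/texture_examples/1.jpg")

# Generate normalized texture samples as a list of PIL images and save them.
samples = pipe.flatten_texture(patch, n_samples=3)
os.makedirs("outputs/texture", exist_ok=True)
for i, img in enumerate(samples):
    img.save(os.path.join("outputs/texture", f"1_gen_{i}.png"))
```

`flatten_print` works the same way and returns RGBA images, where the alpha channel is an estimated mask of the print against its background.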
135 | -------------------------------------------------------------------------------- /data/print_examples/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/humansensinglab/fabric-diffusion/4ae7e87171fad0de6b92533d0187d6dc88ef6c09/data/print_examples/1.jpg -------------------------------------------------------------------------------- /data/print_examples/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/humansensinglab/fabric-diffusion/4ae7e87171fad0de6b92533d0187d6dc88ef6c09/data/print_examples/2.jpg -------------------------------------------------------------------------------- /data/print_examples/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/humansensinglab/fabric-diffusion/4ae7e87171fad0de6b92533d0187d6dc88ef6c09/data/print_examples/3.jpg -------------------------------------------------------------------------------- /data/print_examples/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/humansensinglab/fabric-diffusion/4ae7e87171fad0de6b92533d0187d6dc88ef6c09/data/print_examples/4.jpg -------------------------------------------------------------------------------- /data/print_examples/5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/humansensinglab/fabric-diffusion/4ae7e87171fad0de6b92533d0187d6dc88ef6c09/data/print_examples/5.jpg -------------------------------------------------------------------------------- /data/texture_examples/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/humansensinglab/fabric-diffusion/4ae7e87171fad0de6b92533d0187d6dc88ef6c09/data/texture_examples/1.jpg -------------------------------------------------------------------------------- /data/texture_examples/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/humansensinglab/fabric-diffusion/4ae7e87171fad0de6b92533d0187d6dc88ef6c09/data/texture_examples/2.jpg -------------------------------------------------------------------------------- /data/texture_examples/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/humansensinglab/fabric-diffusion/4ae7e87171fad0de6b92533d0187d6dc88ef6c09/data/texture_examples/3.jpg -------------------------------------------------------------------------------- /data/texture_examples/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/humansensinglab/fabric-diffusion/4ae7e87171fad0de6b92533d0187d6dc88ef6c09/data/texture_examples/4.jpg -------------------------------------------------------------------------------- /data/texture_examples/5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/humansensinglab/fabric-diffusion/4ae7e87171fad0de6b92533d0187d6dc88ef6c09/data/texture_examples/5.jpg -------------------------------------------------------------------------------- /docs/teaser.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/humansensinglab/fabric-diffusion/4ae7e87171fad0de6b92533d0187d6dc88ef6c09/docs/teaser.jpg -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: fabric-diff 2 | channels: 3 | - defaults 4 | dependencies: 5 | - pip: 6 | - diffusers==0.32.1 7 | - torch==2.5.1 8 | - transformers==4.47.1 9 | -------------------------------------------------------------------------------- /inference_print.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pipeline import FabricDiffusionPipeline 3 | import argparse 4 | 5 | 6 | def run_flatten_print(pipeline, warp_dataset_path, output_path=None, n_samples=3): 7 | os.makedirs(os.path.join(output_path), exist_ok=True) 8 | all_image_names = os.listdir(warp_dataset_path) 9 | for image_name in all_image_names: 10 | texture_name = image_name.split('.')[0] 11 | texture_patch = pipeline.load_patch_data(os.path.join(warp_dataset_path, image_name)) 12 | gen_imgs = pipeline.flatten_print(texture_patch, n_samples=n_samples) 13 | for i, gen_img in enumerate(gen_imgs): 14 | gen_img.save(os.path.join(output_path, f'{texture_name}_gen_{i}.png')) 15 | 16 | 17 | def get_args(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument( 20 | "--device", type=str, default="cuda:0", help="Device to run the model" 21 | ) 22 | parser.add_argument( 23 | "--texture_checkpoint", default=None, type=str, help="Path to the texture model checkpoint" 24 | ) 25 | parser.add_argument( 26 | "--print_checkpoint", default=None, type=str, help="Path to the logo model checkpoint" 27 | ) 28 | parser.add_argument( 29 | "--src_dir", default='./data/print_examples', type=str, help="Path to the input image directory" 30 | ) 31 | parser.add_argument( 32 | "--save_dir", type=str, default='./outputs/print', help="Directory to save the output" 33 | ) 34 | parser.add_argument( 35 | "--n_samples", type=int, default=3, help="Number of generated images per input" 36 | ) 37 | return parser.parse_args() 38 | 39 | 40 | if __name__ == "__main__": 41 | args = get_args() 42 | device = args.device 43 | texture_checkpoint = args.texture_checkpoint 44 | print_checkpoint = args.print_checkpoint 45 | src_dir = args.src_dir 46 | save_dir = args.save_dir 47 | 48 | pipeline = FabricDiffusionPipeline(device, texture_checkpoint, print_checkpoint=print_checkpoint) 49 | 50 | os.makedirs(save_dir, exist_ok=True) 51 | run_flatten_print(pipeline, src_dir, output_path=save_dir, n_samples=args.n_samples) 52 | -------------------------------------------------------------------------------- /inference_texture.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pipeline import FabricDiffusionPipeline 3 | import argparse 4 | 5 | 6 | def run_flatten_texture(pipeline, warp_dataset_path, output_path=None, n_samples=3): 7 | os.makedirs(os.path.join(output_path), exist_ok=True) 8 | all_image_names = os.listdir(warp_dataset_path) 9 | for image_name in all_image_names: 10 | texture_name = image_name.split('.')[0] 11 | texture_patch = pipeline.load_patch_data(os.path.join(warp_dataset_path, image_name)) 12 | gen_imgs = pipeline.flatten_texture(texture_patch, n_samples=n_samples) 13 | for i, gen_img in enumerate(gen_imgs): 14 | gen_img.save(os.path.join(output_path, f'{texture_name}_gen_{i}.png')) 15 | 16 | 17 | def get_args(): 18 | parser = 
argparse.ArgumentParser() 19 | parser.add_argument( 20 | "--device", type=str, default="cuda:0", help="Device to run the model" 21 | ) 22 | parser.add_argument( 23 | "--texture_checkpoint", default=None, type=str, help="Path to the texture model checkpoint" 24 | ) 25 | parser.add_argument( 26 | "--print_checkpoint", default=None, type=str, help="Path to the logo model checkpoint" 27 | ) 28 | parser.add_argument( 29 | "--src_dir", default='./data/texture_examples', type=str, help="Path to the input image directory" 30 | ) 31 | parser.add_argument( 32 | "--save_dir", type=str, default='./outputs/texture', help="Directory to save the output" 33 | ) 34 | parser.add_argument( 35 | "--n_samples", type=int, default=3, help="Number of generated images per input" 36 | ) 37 | return parser.parse_args() 38 | 39 | 40 | if __name__ == "__main__": 41 | args = get_args() 42 | device = args.device 43 | texture_checkpoint = args.texture_checkpoint 44 | print_checkpoint = args.print_checkpoint 45 | src_dir = args.src_dir 46 | save_dir = args.save_dir 47 | 48 | pipeline = FabricDiffusionPipeline(device, texture_checkpoint, print_checkpoint=print_checkpoint) 49 | 50 | os.makedirs(save_dir, exist_ok=True) 51 | run_flatten_texture(pipeline, src_dir, output_path=save_dir, n_samples=args.n_samples) 52 | -------------------------------------------------------------------------------- /pipeline.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import os 5 | from diffusers import StableDiffusionInstructPix2PixPipeline 6 | from PIL import Image 7 | 8 | 9 | class FabricDiffusionPipeline(): 10 | def __init__(self, device, texture_checkpoint, print_checkpoint): 11 | self.device = device 12 | self.texture_checkpoint = texture_checkpoint 13 | self.print_base_model = print_checkpoint 14 | 15 | if texture_checkpoint: 16 | self.texture_model = StableDiffusionInstructPix2PixPipeline.from_pretrained( 17 | texture_checkpoint, 18 | torch_dtype=torch.float16, 19 | safety_checker=None 20 | ) 21 | # with open(os.path.join(texture_checkpoint, "unet", "diffusion_pytorch_model.safetensors"), "rb") as f: 22 | # data = f.read() 23 | # loaded = load(data) 24 | # self.texture_pipeline.unet.load_state_dict(loaded) 25 | self.texture_model = self.texture_model.to(device) 26 | else: 27 | self.texture_model = None 28 | 29 | # set circular convolution for the texture model 30 | if self.texture_model: 31 | for a, b in self.texture_model.unet.named_modules(): 32 | if isinstance(b, nn.Conv2d): 33 | setattr(b, 'padding_mode', 'circular') 34 | for a, b in self.texture_model.vae.named_modules(): 35 | if isinstance(b, nn.Conv2d): 36 | setattr(b, 'padding_mode', 'circular') 37 | 38 | if print_checkpoint: 39 | self.print_model = StableDiffusionInstructPix2PixPipeline.from_pretrained( 40 | print_checkpoint, 41 | torch_dtype=torch.float16, 42 | safety_checker=None 43 | ) 44 | self.print_model = self.print_model.to(device) 45 | else: 46 | self.print_model = None 47 | 48 | def load_real_data_with_mask(self, dataset_path, image_name): 49 | image = np.array(Image.open(os.path.join(dataset_path, 'images', image_name)).convert('RGB')) 50 | seg_mask = np.array(Image.open(os.path.join(dataset_path, 'seg_mask', image_name)).convert('L'))[..., None] 51 | texture_mask = np.array(Image.open(os.path.join(dataset_path, 'texture_mask', image_name)).convert('L'))[..., None] 52 | # crop the image based on texture_mask 53 | x1, y1, x2, y2 = np.where(texture_mask > 
0)[1].min(), np.where(texture_mask > 0)[0].min(), np.where(texture_mask > 0)[1].max(), np.where(texture_mask > 0)[0].max() 54 | texture_patch = image[y1:y2, x1:x2] 55 | # resize the texture_patch to 256x256 56 | texture_patch = Image.fromarray(texture_patch.astype(np.uint8)).resize((256, 256)) 57 | 58 | return image, seg_mask, texture_patch 59 | 60 | def load_patch_data(self, patch_path): 61 | texture_patch = Image.open(patch_path).convert('RGB').resize((256, 256)) 62 | return texture_patch 63 | 64 | def flatten_texture(self, texture_patch, n_samples=3, use_inversion=True): 65 | num_inference_steps = 20 66 | self.texture_model.scheduler.set_timesteps(num_inference_steps) 67 | timesteps = self.texture_model.scheduler.timesteps 68 | 69 | # convert image to latent using vae 70 | image = self.texture_model.image_processor.preprocess(texture_patch) 71 | if use_inversion: 72 | image_latents = self.texture_model.prepare_image_latents(image, batch_size=1, 73 | num_images_per_prompt=1, 74 | device=self.device, 75 | dtype=torch.float16, 76 | do_classifier_free_guidance=False) 77 | 78 | image_latents = (image_latents - torch.mean(image_latents)) / torch.std(image_latents) 79 | 80 | # forward noising process 81 | noise = torch.randn_like(image_latents) 82 | noisy_image_latents = self.texture_model.scheduler.add_noise(image_latents, noise, timesteps[0:1]) 83 | noisy_image_latents /= self.texture_model.scheduler.init_noise_sigma 84 | noisy_image_latents = torch.tile(noisy_image_latents, (n_samples, 1, 1, 1)) 85 | else: 86 | noisy_image_latents = None 87 | 88 | image = torch.tile(image, (n_samples, 1, 1, 1)) 89 | gen_imgs = self.texture_model( 90 | "", 91 | image=image, 92 | num_inference_steps=20, 93 | image_guidance_scale=1.5, 94 | guidance_scale=7., 95 | latents=noisy_image_latents, 96 | num_images_per_prompt=n_samples, 97 | ).images 98 | 99 | return gen_imgs 100 | 101 | def flatten_print(self, print_patch, n_samples=3): 102 | image = self.print_model.image_processor.preprocess(print_patch) 103 | gen_imgs = [] 104 | for i in range(n_samples): 105 | gen_img = self.print_model( 106 | "", 107 | image=image, 108 | num_inference_steps=20, 109 | image_guidance_scale=1.5, 110 | guidance_scale=7., 111 | ).images[0] 112 | gen_img = np.asarray(gen_img) / 255. 113 | alpha_map = np.clip(gen_img / 0.1 * 1.2 - 0.2, 0., 1).mean(axis=-1, keepdims=True) 114 | gen_img = np.clip((gen_img - 0.1) / 0.9, 0., 1.) 115 | gen_img = np.concatenate([gen_img, alpha_map], axis=-1) 116 | gen_img = (gen_img * 255).astype(np.uint8) 117 | gen_img = Image.fromarray(gen_img) 118 | gen_imgs.append(gen_img) 119 | 120 | return gen_imgs 121 | 122 | --------------------------------------------------------------------------------