├── .gitignore ├── output.0.png ├── requirements.txt ├── cog.yaml ├── .dockerignore ├── feature-extractor └── preprocessor_config.json ├── README.md └── predict.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .cog 3 | Hyper-FLUX.1-dev 4 | safety-cache 5 | -------------------------------------------------------------------------------- /output.0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucataco/cog-hyper-flux-8step/HEAD/output.0.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.3.0 2 | torchvision 3 | diffusers==0.30.1 4 | transformers==4.43.3 5 | accelerate==0.33.0 6 | sentencepiece==0.2.0 7 | protobuf==5.27.3 8 | numpy==1.26.3 9 | pillow==11.0.0 10 | peft -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://cog.run/yaml 3 | 4 | build: 5 | gpu: true 6 | cuda: "12.1" 7 | python_version: "3.11" 8 | python_requirements: requirements.txt 9 | 10 | predict: "predict.py:Predictor" -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | # The .dockerignore file excludes files from the container build process. 
2 | # 3 | # https://docs.docker.com/engine/reference/builder/#dockerignore-file 4 | 5 | # Exclude Git files 6 | .git 7 | .github 8 | .gitignore 9 | 10 | # Exclude Python cache files 11 | __pycache__ 12 | .mypy_cache 13 | .pytest_cache 14 | .ruff_cache 15 | 16 | # Exclude Python virtual environment 17 | /venv 18 | -------------------------------------------------------------------------------- /feature-extractor/preprocessor_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "crop_size": 224, 3 | "do_center_crop": true, 4 | "do_convert_rgb": true, 5 | "do_normalize": true, 6 | "do_resize": true, 7 | "feature_extractor_type": "CLIPFeatureExtractor", 8 | "image_mean": [ 9 | 0.48145466, 10 | 0.4578275, 11 | 0.40821073 12 | ], 13 | "image_std": [ 14 | 0.26862954, 15 | 0.26130258, 16 | 0.27577711 17 | ], 18 | "resample": 3, 19 | "size": 224 20 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ByteDance/Hyper-SD FLUX.1-dev 8-step Cog Model 2 | 3 | This is an implementation of [ByteDance/Hyper-SD FLUX.1-dev 8-step](https://huggingface.co/ByteDance/Hyper-SD) as a [Cog](https://github.com/replicate/cog) model. 4 | 5 | - [x] Cog Fast Push Compatible 6 | 7 | ## Development 8 | 9 | Follow the [model pushing guide](https://replicate.com/docs/guides/push-a-model) to push your own model to [Replicate](https://replicate.com). 10 | 11 | ## How to use 12 | 13 | Make sure you have [cog](https://github.com/replicate/cog) installed. 14 | 15 | To run a prediction: 16 | 17 | cog predict -i prompt="a dog smiling and looking directly at the camera, wearing a white t-shirt with the word 'HYPER' printed on it." 
# (README.md, continued)
# ![Output](output.0.png)
# --------------------------------------------------------------------------------
# /predict.py:
# --------------------------------------------------------------------------------
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md

from cog import BasePredictor, Input, Path
import os
import time
import torch
import subprocess
import numpy as np
from typing import List
from diffusers import FluxPipeline
from transformers import CLIPImageProcessor
from diffusers.pipelines.stable_diffusion.safety_checker import (
    StableDiffusionSafetyChecker
)

# Local directory / remote tarball for the Flux pipeline weights.
MODEL_CACHE = "Hyper-FLUX.1-dev"
MODEL_URL = "https://weights.replicate.delivery/default/ByteDance/Hyper-FLUX.1-dev-8steps/model.tar"
# Local directory / remote tarball for the safety checker weights.
SAFETY_CACHE = "safety-cache"
FEATURE_EXTRACTOR = "/src/feature-extractor"
SAFETY_URL = "https://weights.replicate.delivery/default/sdxl/safety-1.0.tar"

# Preset (width, height) pairs selectable through the `aspect_ratio` input.
ASPECT_RATIOS = {
    "1:1": (1024, 1024),
    "16:9": (1344, 768),
    "21:9": (1536, 640),
    "3:2": (1216, 832),
    "2:3": (832, 1216),
    "4:5": (896, 1088),
    "5:4": (1088, 896),
    "3:4": (896, 1152),
    "4:3": (1152, 896),
    "9:16": (768, 1344),
    "9:21": (640, 1536),
}


def download_weights(url, dest):
    """Fetch and extract the tarball at `url` into `dest` using `pget`.

    Raises `subprocess.CalledProcessError` if `pget` exits non-zero.
    """
    start = time.time()
    print("downloading url: ", url)
    print("downloading to: ", dest)
    subprocess.check_call(["pget", "-xf", url, dest], close_fds=False)
    print("downloading took: ", time.time() - start)


def make_multiple_of_16(x):
    """Round `x` UP to the next multiple of 16 (multiples are unchanged)."""
    return (x + 15) // 16 * 16


class Predictor(BasePredictor):
    """Cog predictor that generates images with the Hyper-FLUX.1-dev pipeline."""

    test_inputs = {
        "prompt": "a dog smiling and looking directly at the camera, wearing a white t-shirt with the word \"HYPER\" printed on it.",
        "width": 848,
        "height": 848,
        "seed": 0,
    }

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        start = time.time()
        os.environ["TRANSFORMERS_OFFLINE"] = "1"

        print("Loading safety checker...")
        # Weights are only downloaded when the local cache directory is missing.
        if not os.path.exists(SAFETY_CACHE):
            download_weights(SAFETY_URL, SAFETY_CACHE)
        self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
            SAFETY_CACHE, torch_dtype=torch.float16
        ).to("cuda")
        self.feature_extractor = CLIPImageProcessor.from_pretrained(FEATURE_EXTRACTOR)

        print("Loading Flux txt2img Pipeline")
        if not os.path.exists(MODEL_CACHE):
            download_weights(MODEL_URL, MODEL_CACHE)
        self.txt2img_pipe = FluxPipeline.from_pretrained(
            MODEL_CACHE,
            torch_dtype=torch.bfloat16
        ).to("cuda")
        print("setup took: ", time.time() - start)

    @torch.amp.autocast('cuda')
    def run_safety_checker(self, image):
        """Run the safety checker over a list of PIL images.

        Returns the `(images, has_nsfw_concept)` pair produced by the safety
        checker; `has_nsfw_concept` is indexed per input image by the caller.
        """
        safety_checker_input = self.feature_extractor(image, return_tensors="pt").to("cuda")
        np_image = [np.array(val) for val in image]
        image, has_nsfw_concept = self.safety_checker(
            images=np_image,
            # The checker was loaded in float16, so feed it float16 pixels.
            clip_input=safety_checker_input.pixel_values.to(torch.float16),
        )
        return image, has_nsfw_concept

    def aspect_ratio_to_width_height(self, aspect_ratio: str) -> tuple[int, int]:
        """Return the preset `(width, height)` for `aspect_ratio`.

        Raises `KeyError` for a ratio that is not in `ASPECT_RATIOS`.
        """
        return ASPECT_RATIOS[aspect_ratio]

    @torch.inference_mode()
    def predict(
        self,
        prompt: str = Input(description="Prompt for generated image"),
        aspect_ratio: str = Input(
            description="Aspect ratio for the generated image. The size will always be 1 megapixel, i.e. 1024x1024 if aspect ratio is 1:1. To use arbitrary width and height, set aspect ratio to 'custom'.",
            choices=list(ASPECT_RATIOS.keys()) + ["custom"],
            default="1:1",
        ),
        width: int = Input(
            description="Width of the generated image. Optional, only used when aspect_ratio=custom. Must be a multiple of 16 (if it's not, it will be rounded up to the next multiple of 16)",
            ge=256,
            le=1440,
            default=848,
        ),
        height: int = Input(
            description="Height of the generated image. Optional, only used when aspect_ratio=custom. Must be a multiple of 16 (if it's not, it will be rounded up to the next multiple of 16)",
            ge=256,
            le=1440,
            default=848,
        ),
        num_outputs: int = Input(
            description="Number of images to output.",
            ge=1,
            le=4,
            default=1,
        ),
        num_inference_steps: int = Input(
            description="Number of inference steps",
            ge=1,
            le=30,
            default=8,
        ),
        guidance_scale: float = Input(
            description="Guidance scale for the diffusion process",
            ge=0,
            le=10,
            default=3.5,
        ),
        seed: int = Input(description="Random seed. Set for reproducible generation", default=0),
        output_format: str = Input(
            description="Format of the output images",
            choices=["webp", "jpg", "png"],
            default="webp",
        ),
        output_quality: int = Input(
            description="Quality when saving the output images, from 0 to 100. 100 is best quality, 0 is lowest quality. Not relevant for .png outputs",
            default=80,
            ge=0,
            le=100,
        ),
        disable_safety_checker: bool = Input(
            description="Disable safety checker for generated images. This feature is only available through the API. See [https://replicate.com/docs/how-does-replicate-work#safety](https://replicate.com/docs/how-does-replicate-work#safety)",
            default=False,
        ),
    ) -> List[Path]:
        """Run a single prediction on the model

        Generates `num_outputs` images for `prompt`, drops any image the
        safety checker flags (unless `disable_safety_checker` is set), and
        writes the rest to `/tmp/out-<i>.<output_format>`.

        Returns the list of written image paths.
        Raises `ValueError` when `aspect_ratio` is 'custom' but width/height
        are missing, and `Exception` when every image was flagged as NSFW.
        """
        # A missing or zero seed means "pick one at random". Four random bytes
        # give a 32-bit seed space (two bytes allowed only 65,536 outcomes).
        if seed is None or seed == 0:
            seed = int.from_bytes(os.urandom(4), "big")
        print(f"Using seed: {seed}")

        if aspect_ratio == "custom":
            if width is None or height is None:
                raise ValueError(
                    "width and height must be defined if aspect ratio is 'custom'"
                )
            width = make_multiple_of_16(width)
            height = make_multiple_of_16(height)
        else:
            # Presets override any width/height the caller supplied.
            width, height = self.aspect_ratio_to_width_height(aspect_ratio)
        max_sequence_length = 512

        print(f"Prompt: {prompt}")
        print("txt2img mode")

        generator = torch.Generator("cuda").manual_seed(seed)

        output = self.txt2img_pipe(
            # One copy of the prompt per requested image.
            prompt=[prompt] * num_outputs,
            guidance_scale=guidance_scale,
            generator=generator,
            num_inference_steps=num_inference_steps,
            max_sequence_length=max_sequence_length,
            output_type="pil",
            width=width,
            height=height,
        )

        # None means "checker disabled": nothing is filtered below.
        has_nsfw_content = None
        if not disable_safety_checker:
            _, has_nsfw_content = self.run_safety_checker(output.images)

        output_paths = []
        for i, image in enumerate(output.images):
            if has_nsfw_content is not None and has_nsfw_content[i]:
                print(f"NSFW content detected in image {i}")
                continue
            output_path = f"/tmp/out-{i}.{output_format}"
            if output_format != "png":
                # Quality only applies to the lossy formats (webp / jpg).
                image.save(output_path, quality=output_quality, optimize=True)
            else:
                image.save(output_path)
            output_paths.append(Path(output_path))

        if not output_paths:
            raise Exception("NSFW content detected. Try running it again, or try a different prompt.")

        return output_paths
# --------------------------------------------------------------------------------