├── .gitmodules ├── LICENSE ├── Makefile ├── README.md ├── example └── main.go ├── go.mod ├── overrides ├── decoder_slover.cpp ├── decoder_slover.h ├── diffusion_slover.cpp ├── diffusion_slover.h ├── encoder_slover.cpp ├── encoder_slover.h ├── prompt_slover.cpp └── prompt_slover.h ├── stablediffusion.cpp ├── stablediffusion.go ├── stablediffusion.h └── stablediffusion.hpp /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "stable-diffusion"] 2 | path = stable-diffusion 3 | url = https://github.com/EdVince/Stable-Diffusion-NCNN 4 | [submodule "ncnn"] 5 | path = ncnn 6 | url = https://github.com/Tencent/ncnn 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2022, WuJinxuan 4 | Copyright (c) 2023, Ettore Di Giacinto (Golang bindings) 5 | 6 | All rights reserved. 7 | 8 | Redistribution and use in source and binary forms, with or without 9 | modification, are permitted provided that the following conditions are met: 10 | 11 | 1. Redistributions of source code must retain the above copyright notice, this 12 | list of conditions and the following disclaimer. 13 | 14 | 2. Redistributions in binary form must reproduce the above copyright notice, 15 | this list of conditions and the following disclaimer in the documentation 16 | and/or other materials provided with the distribution. 17 | 18 | 3. Neither the name of the copyright holder nor the names of its 19 | contributors may be used to endorse or promote products derived from 20 | this software without specific prior written permission. 
21 | 22 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 23 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 24 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 25 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 26 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 27 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 28 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 29 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 30 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 31 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 32 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | INCLUDE_PATH := $(abspath ./) 2 | LIBRARY_PATH := $(abspath ./) 3 | 4 | 5 | BUILD_TYPE?= 6 | # keep standard at C11 and C++17 (C++17 is required for std::filesystem used in overrides/) 7 | CFLAGS = -I./ncnn -I./ncnn/src -I./ncnn/build/src/ -I. -I./stable-diffusion/x86/vs2019_opencv-mobile_ncnn-dll_demo/vs2019_opencv-mobile_ncnn-dll_demo -O3 -DNDEBUG -std=c11 -fPIC 8 | CXXFLAGS = -I./ncnn -I./ncnn/src -I./ncnn/build/src/ -I. -I./stable-diffusion/x86/vs2019_opencv-mobile_ncnn-dll_demo/vs2019_opencv-mobile_ncnn-dll_demo -O3 -DNDEBUG -std=c++17 -fPIC 9 | LDFLAGS = 10 | 11 | # warnings 12 | CFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -Wno-unused-function 13 | CXXFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function 14 | # 15 | # Print build information 16 | # 17 | 18 | $(info I go-stable-diffusion build info: ) 19 | 20 | ncnn/build/src/libncnn.a: 21 | cd ncnn && mkdir -p build && cd build && cmake -DCMAKE_BUILD_TYPE=Release -DNCNN_VULKAN=OFF -DNCNN_BUILD_EXAMPLES=ON ..
&& $(MAKE) 22 | cd ncnn && cp -rf src/* ./ 23 | 24 | stablediffusion.o: ncnn/build/src/libncnn.a 25 | cp -rf overrides/* stable-diffusion/x86/vs2019_opencv-mobile_ncnn-dll_demo/vs2019_opencv-mobile_ncnn-dll_demo/ 26 | $(CXX) $(CXXFLAGS) stablediffusion.cpp -o stablediffusion.o -c $(LDFLAGS) 27 | $(CXX) $(CXXFLAGS) stablediffusion.cpp -o stablediffusion-hires.o -c $(LDFLAGS) 28 | 29 | unpack: ncnn/build/src/libncnn.a 30 | mkdir -p unpack && cd unpack && ar x ../ncnn/build/src/libncnn.a 31 | 32 | libstablediffusion.a: stablediffusion.o unpack $(EXTRA_TARGETS) 33 | ar src libstablediffusion.a stablediffusion-hires.o stablediffusion.o $(shell ls unpack/* | xargs echo) 34 | 35 | example/main: libstablediffusion.a 36 | @C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go build -x -o example/main ./example 37 | 38 | clean: 39 | rm -rf *.o 40 | rm -rf *.a 41 | rm -rf unpack 42 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This is a Go wrapper around https://github.com/EdVince/Stable-Diffusion-NCNN/ and https://github.com/fengwang/Stable-Diffusion-NCNN/ 2 | -------------------------------------------------------------------------------- /example/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | stableDiffusion "github.com/mudler/go-stable-diffusion" 5 | ) 6 | 7 | func main() { 8 | stableDiffusion.GenerateImage( 9 | 256, 10 | 256, 11 | 0, 12 | 15, 13 | 42, 14 | "a photograph of an astronaut riding a horse, highly detailed, (((masterpiece))), ((best quality)), colorful", 15 | "((part of the head)), ((((mutated hands and fingers)))), deformed, blurry, bad anatomy, disfigured, poorly drawn face, mutation, mutated, extra
limb, ugly, poorly drawn hands, missing limb, blurry, floating limbs, disconnected limbs, malformed hands, blur, out of focus, long neck, long body, Octane renderer, lowres, bad anatomy, bad hands, text", "./test.png", 16 | "", 17 | "assets", 18 | ) 19 | } 20 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/mudler/go-stable-diffusion 2 | 3 | go 1.19 4 | -------------------------------------------------------------------------------- /overrides/decoder_slover.cpp: -------------------------------------------------------------------------------- 1 | #include "decoder_slover.h" 2 | #include 3 | 4 | DecodeSlover::DecodeSlover(int h, int w, string assets_dir) 5 | { 6 | net.opt.use_vulkan_compute = false; 7 | net.opt.use_winograd_convolution = false; 8 | net.opt.use_sgemm_convolution = false; 9 | net.opt.use_fp16_packed = false; 10 | net.opt.use_fp16_storage = false; 11 | net.opt.use_fp16_arithmetic = false; 12 | net.opt.use_bf16_storage = true; 13 | net.opt.use_packing_layout = true; 14 | 15 | 16 | // Define the name of the parameter file 17 | std::string param_file; 18 | if (h == 512 && w == 512) 19 | param_file = "AutoencoderKL-512-512-fp16-opt.param"; 20 | else if (h == 256 && w == 256) 21 | param_file = "AutoencoderKL-256-256-fp16-opt.param"; 22 | else 23 | { 24 | generate_param(h, w, assets_dir); 25 | param_file = "tmp-AutoencoderKL-" + to_string(h) + "-" + to_string(w) + "-fp16.param"; 26 | } 27 | 28 | 29 | // Join the paths using std::filesystem::path::operator/() function 30 | std::filesystem::path param_path = std::filesystem::path(assets_dir) / std::filesystem::path(param_file); 31 | std::filesystem::path model_path = std::filesystem::path(assets_dir) / std::filesystem::path("AutoencoderKL-fp16.bin"); 32 | 33 | net.load_param(param_path.string().c_str()); 34 | net.load_model(model_path.string().c_str()); 35 | } 36 | 
37 | void DecodeSlover::generate_param(int height, int width, string assets_dir ) 38 | { 39 | string line; 40 | 41 | std::filesystem::path decoder = std::filesystem::path(assets_dir) / std::filesystem::path("AutoencoderKL-base-fp16.param"); 42 | std::filesystem::path decoder_out = std::filesystem::path(assets_dir) / std::filesystem::path("tmp-AutoencoderKL-" + std::to_string(height) + "-" + std::to_string(width) + "-fp16.param"); 43 | 44 | ifstream decoder_file(decoder.string().c_str()); 45 | ofstream decoder_file_new(decoder_out.string().c_str()); 46 | 47 | int cnt = 0; 48 | while (getline(decoder_file, line)) 49 | { 50 | if (line.substr(0, 7) == "Reshape") 51 | { 52 | if (cnt < 3) 53 | line = line.substr(0, line.size() - 12) + "0=" + to_string(width * height / 8 / 8) + " 1=512"; 54 | else 55 | line = line.substr(0, line.size() - 15) + "0=" + to_string(width / 8) + " 1=" + std::to_string(height / 8) + " 2=512"; 56 | cnt++; 57 | } 58 | decoder_file_new << line << endl; 59 | } 60 | decoder_file_new.close(); 61 | decoder_file.close(); 62 | } 63 | 64 | ncnn::Mat DecodeSlover::decode(ncnn::Mat sample) 65 | { 66 | ncnn::Mat x_samples_ddim; 67 | { 68 | sample.substract_mean_normalize(0, factor); 69 | 70 | { 71 | ncnn::Extractor ex = net.create_extractor(); 72 | ex.set_light_mode(true); 73 | ex.input("input.1", sample); 74 | ex.extract("815", x_samples_ddim); 75 | } 76 | 77 | x_samples_ddim.substract_mean_normalize(_mean_, _norm_); 78 | } 79 | 80 | return x_samples_ddim; 81 | } -------------------------------------------------------------------------------- /overrides/decoder_slover.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | using namespace std; 15 | 16 | class DecodeSlover 17 | { 18 | public: 19 | DecodeSlover(int h, int w, string assets_dir); 20 | 21 
| ncnn::Mat decode(ncnn::Mat sample); 22 | 23 | private: 24 | void generate_param(int height, int width, string assets_dir); 25 | 26 | const float factor[4] = { 1.0 / 0.18215f, 1.0 / 0.18215f, 1.0 / 0.18215f, 1.0 / 0.18215f }; 27 | 28 | const float _mean_[3] = { -1.0f, -1.0f, -1.0f }; 29 | const float _norm_[3] = { 127.5f, 127.5f, 127.5f }; 30 | 31 | ncnn::Net net; 32 | }; -------------------------------------------------------------------------------- /overrides/diffusion_slover.cpp: -------------------------------------------------------------------------------- 1 | #include "diffusion_slover.h" 2 | #include 3 | 4 | DiffusionSlover::DiffusionSlover(int h, int w, int mode, string assets_dir) 5 | { 6 | net.opt.use_vulkan_compute = false; 7 | net.opt.lightmode = true; 8 | if (mode == 0) 9 | { 10 | net.opt.use_winograd_convolution = false; 11 | net.opt.use_sgemm_convolution = false; 12 | } 13 | else 14 | { 15 | net.opt.use_winograd_convolution = true; 16 | net.opt.use_sgemm_convolution = true; 17 | } 18 | net.opt.use_fp16_packed = true; 19 | net.opt.use_fp16_storage = true; 20 | net.opt.use_fp16_arithmetic = true; 21 | net.opt.use_packing_layout = true; 22 | 23 | // Define the name of the parameter file 24 | std::string param_file; 25 | if (h == 512 && w == 512) 26 | param_file = "UNetModel-512-512-MHA-fp16-opt.param"; 27 | else if (h == 256 && w == 256) 28 | param_file = "UNetModel-256-256-MHA-fp16-opt.param"; 29 | else 30 | { 31 | generate_param(h, w, assets_dir); 32 | param_file = "tmp-UNetModel-" + std::to_string(h) + "-" + std::to_string(w) + "-MHA-fp16.param"; 33 | } 34 | std::filesystem::path model_path = std::filesystem::path(assets_dir) / std::filesystem::path("UNetModel-MHA-fp16.bin"); 35 | 36 | // Join the paths using std::filesystem::path::operator/() function 37 | std::filesystem::path param_path = std::filesystem::path(assets_dir) / std::filesystem::path(param_file); 38 | 39 | net.load_param(param_path.string().c_str()); 40 | 
net.load_model(model_path.string().c_str()); 41 | std::filesystem::path sigma_path = std::filesystem::path(assets_dir) / std::filesystem::path("log_sigmas.bin"); 42 | 43 | h_size = h / 8; 44 | w_size = w / 8; 45 | 46 | ifstream in(sigma_path.string(), ios::in | ios::binary); 47 | in.read((char*)&log_sigmas, sizeof log_sigmas); 48 | in.close(); 49 | } 50 | 51 | void DiffusionSlover::generate_param(int height, int width, string assets_dir) 52 | { 53 | string line; 54 | 55 | std::filesystem::path decoder = std::filesystem::path(assets_dir) / std::filesystem::path("UNetModel-base-MHA-fp16.param"); 56 | std::filesystem::path decoder_out = std::filesystem::path(assets_dir) / std::filesystem::path("tmp-UNetModel-" + std::to_string(height) + "-" + std::to_string(width) + "-MHA-fp16.param"); 57 | 58 | ifstream diffuser_file(decoder.string().c_str()); 59 | ofstream diffuser_file_new(decoder_out.string().c_str()); 60 | 61 | int cnt = 0; 62 | while (getline(diffuser_file, line)) 63 | { 64 | if (line.substr(0, 7) == "Reshape") 65 | { 66 | switch (cnt) 67 | { 68 | case 0: line = line.substr(0, line.size() - 4) + to_string(width * height / 8 / 8); break; 69 | case 1: line = line.substr(0, line.size() - 7) + to_string(width / 8) + " 2=" + std::to_string(height / 8); break; 70 | case 2: line = line.substr(0, line.size() - 4) + to_string(width * height / 8 / 8); break; 71 | case 3: line = line.substr(0, line.size() - 7) + to_string(width / 8) + " 2=" + std::to_string(height / 8); break; 72 | case 4: line = line.substr(0, line.size() - 4) + to_string(width * height / 2 / 2 / 8 / 8); break; 73 | case 5: line = line.substr(0, line.size() - 7) + to_string(width / 2 / 8) + " 2=" + std::to_string(height / 2 / 8); break; 74 | case 6: line = line.substr(0, line.size() - 4) + to_string(width * height / 2 / 2 / 8 / 8); break; 75 | case 7: line = line.substr(0, line.size() - 7) + to_string(width / 2 / 8) + " 2=" + std::to_string(height / 2 / 8); break; 76 | case 8: line = line.substr(0, 
line.size() - 3) + to_string(width * height / 4 / 4 / 8 / 8); break; 77 | case 9: line = line.substr(0, line.size() - 7) + to_string(width / 4 / 8) + " 2=" + std::to_string(height / 4 / 8); break; 78 | case 10: line = line.substr(0, line.size() - 3) + to_string(width * height / 4 / 4 / 8 / 8); break; 79 | case 11: line = line.substr(0, line.size() - 7) + to_string(width / 4 / 8) + " 2=" + std::to_string(height / 4 / 8); break; 80 | case 12: line = line.substr(0, line.size() - 2) + to_string(width * height / 8 / 8 / 8 / 8); break; 81 | case 13: line = line.substr(0, line.size() - 5) + to_string(width / 8 / 8) + " 2=" + std::to_string(height / 8 / 8); break; 82 | case 14: line = line.substr(0, line.size() - 3) + to_string(width * height / 4 / 4 / 8 / 8); break; 83 | case 15: line = line.substr(0, line.size() - 7) + to_string(width / 4 / 8) + " 2=" + std::to_string(height / 4 / 8); break; 84 | case 16: line = line.substr(0, line.size() - 3) + to_string(width * height / 4 / 4 / 8 / 8); break; 85 | case 17: line = line.substr(0, line.size() - 7) + to_string(width / 4 / 8) + " 2=" + std::to_string(height / 4 / 8); break; 86 | case 18: line = line.substr(0, line.size() - 3) + to_string(width * height / 4 / 4 / 8 / 8); break; 87 | case 19: line = line.substr(0, line.size() - 7) + to_string(width / 4 / 8) + " 2=" + std::to_string(height / 4 / 8); break; 88 | case 20: line = line.substr(0, line.size() - 4) + to_string(width * height / 2 / 2 / 8 / 8); break; 89 | case 21: line = line.substr(0, line.size() - 7) + to_string(width / 2 / 8) + " 2=" + std::to_string(height / 2 / 8); break; 90 | case 22: line = line.substr(0, line.size() - 4) + to_string(width * height / 2 / 2 / 8 / 8); break; 91 | case 23: line = line.substr(0, line.size() - 7) + to_string(width / 2 / 8) + " 2=" + std::to_string(height / 2 / 8); break; 92 | case 24: line = line.substr(0, line.size() - 4) + to_string(width * height / 2 / 2 / 8 / 8); break; 93 | case 25: line = line.substr(0, line.size() - 7) + 
to_string(width / 2 / 8) + " 2=" + std::to_string(height / 2 / 8); break; 94 | case 26: line = line.substr(0, line.size() - 4) + to_string(width * height / 8 / 8); break; 95 | case 27: line = line.substr(0, line.size() - 7) + to_string(width / 8) + " 2=" + std::to_string(height / 8); break; 96 | case 28: line = line.substr(0, line.size() - 4) + to_string(width * height / 8 / 8); break; 97 | case 29: line = line.substr(0, line.size() - 7) + to_string(width / 8) + " 2=" + std::to_string(height / 8); break; 98 | case 30: line = line.substr(0, line.size() - 4) + to_string(width * height / 8 / 8); break; 99 | case 31: line = line.substr(0, line.size() - 7) + to_string(width / 8) + " 2=" + std::to_string(height / 8); break; 100 | default: break; 101 | } 102 | 103 | cnt++; 104 | } 105 | diffuser_file_new << line << endl; 106 | } 107 | diffuser_file_new.close(); 108 | diffuser_file.close(); 109 | } 110 | 111 | ncnn::Mat DiffusionSlover::randn_4(int seed) 112 | { 113 | cv::Mat cv_x(cv::Size(w_size, h_size), CV_32FC4); 114 | cv::RNG rng(seed); 115 | rng.fill(cv_x, cv::RNG::NORMAL, 0, 1); 116 | ncnn::Mat x_mat(w_size, h_size, 4, (void*)cv_x.data); 117 | return x_mat.clone(); 118 | } 119 | 120 | ncnn::Mat DiffusionSlover::CFGDenoiser_CompVisDenoiser(ncnn::Mat& input, float sigma, ncnn::Mat cond, ncnn::Mat uncond) 121 | { 122 | // get_scalings 123 | float c_out = -1.0 * sigma; 124 | float c_in = 1.0 / sqrt(sigma * sigma + 1); 125 | 126 | // sigma_to_t 127 | float log_sigma = log(sigma); 128 | vector dists(1000); 129 | for (int i = 0; i < 1000; i++) { 130 | if (log_sigma - log_sigmas[i] >= 0) 131 | dists[i] = 1; 132 | else 133 | dists[i] = 0; 134 | if (i == 0) continue; 135 | dists[i] += dists[i - 1]; 136 | } 137 | int low_idx = min(int(max_element(dists.begin(), dists.end()) - dists.begin()), 1000 - 2); 138 | int high_idx = low_idx + 1; 139 | float low = log_sigmas[low_idx]; 140 | float high = log_sigmas[high_idx]; 141 | float w = (low - log_sigma) / (low - high); 142 | w = 
max(0.f, min(1.f, w)); 143 | float t = (1 - w) * low_idx + w * high_idx; 144 | 145 | ncnn::Mat t_mat(1); 146 | t_mat[0] = t; 147 | 148 | ncnn::Mat c_in_mat(1); 149 | c_in_mat[0] = c_in; 150 | 151 | ncnn::Mat c_out_mat(1); 152 | c_out_mat[0] = c_out; 153 | 154 | ncnn::Mat v44; 155 | ncnn::Mat v83; 156 | ncnn::Mat v116; 157 | ncnn::Mat v163; 158 | ncnn::Mat v251; 159 | ncnn::Mat v337; 160 | ncnn::Mat v425; 161 | ncnn::Mat v511; 162 | ncnn::Mat v599; 163 | ncnn::Mat v627; 164 | ncnn::Mat v711; 165 | ncnn::Mat v725; 166 | ncnn::Mat v740; 167 | ncnn::Mat v755; 168 | ncnn::Mat v772; 169 | ncnn::Mat v858; 170 | ncnn::Mat v944; 171 | ncnn::Mat v1032; 172 | ncnn::Mat v1118; 173 | ncnn::Mat v1204; 174 | ncnn::Mat v1292; 175 | ncnn::Mat v1378; 176 | ncnn::Mat v1464; 177 | 178 | ncnn::Mat denoised_cond; 179 | { 180 | ncnn::Extractor ex = net.create_extractor(); 181 | ex.set_light_mode(true); 182 | ex.input("in0", input); 183 | ex.input("in1", t_mat); 184 | ex.input("in2", cond); 185 | ex.input("c_in", c_in_mat); 186 | ex.input("c_out", c_out_mat); 187 | ex.extract("44", v44, 1); 188 | ex.extract("83", v83, 1); 189 | ex.extract("116", v116, 1); 190 | ex.extract("163", v163, 1); 191 | ex.extract("251", v251, 1); 192 | ex.extract("337", v337, 1); 193 | ex.extract("425", v425, 1); 194 | ex.extract("511", v511, 1); 195 | ex.extract("599", v599, 1); 196 | ex.extract("627", v627, 1); 197 | ex.extract("711", v711, 1); 198 | ex.extract("725", v725, 1); 199 | ex.extract("740", v740, 1); 200 | ex.extract("755", v755, 1); 201 | ex.extract("772", v772, 1); 202 | ex.extract("858", v858, 1); 203 | ex.extract("944", v944, 1); 204 | ex.extract("1032", v1032, 1); 205 | ex.extract("1118", v1118, 1); 206 | ex.extract("1204", v1204, 1); 207 | ex.extract("1292", v1292, 1); 208 | ex.extract("1378", v1378, 1); 209 | ex.extract("1464", v1464, 1); 210 | ex.extract("outout", denoised_cond); 211 | } 212 | 213 | ncnn::Mat denoised_uncond; 214 | { 215 | ncnn::Extractor ex = net.create_extractor(); 216 | 
ex.set_light_mode(true); 217 | ex.input("in0", input); 218 | ex.input("in1", t_mat); 219 | ex.input("in2", uncond); 220 | ex.input("c_in", c_in_mat); 221 | ex.input("c_out", c_out_mat); 222 | ex.input("44", v44); 223 | ex.input("83", v83); 224 | ex.input("116", v116); 225 | ex.input("163", v163); 226 | ex.input("251", v251); 227 | ex.input("337", v337); 228 | ex.input("425", v425); 229 | ex.input("511", v511); 230 | ex.input("599", v599); 231 | ex.input("627", v627); 232 | ex.input("711", v711); 233 | ex.input("725", v725); 234 | ex.input("740", v740); 235 | ex.input("755", v755); 236 | ex.input("772", v772); 237 | ex.input("858", v858); 238 | ex.input("944", v944); 239 | ex.input("1032", v1032); 240 | ex.input("1118", v1118); 241 | ex.input("1204", v1204); 242 | ex.input("1292", v1292); 243 | ex.input("1378", v1378); 244 | ex.input("1464", v1464); 245 | ex.extract("outout", denoised_uncond); 246 | } 247 | 248 | for (int c = 0; c < 4; c++) { 249 | float* u_ptr = denoised_uncond.channel(c); 250 | float* c_ptr = denoised_cond.channel(c); 251 | for (int hw = 0; hw < h_size * w_size; hw++) { 252 | (*u_ptr) = (*u_ptr) + guidance_scale * ((*c_ptr) - (*u_ptr)); 253 | u_ptr++; 254 | c_ptr++; 255 | } 256 | } 257 | 258 | return denoised_uncond; 259 | } 260 | 261 | ncnn::Mat DiffusionSlover::sampler_txt2img(int seed, int step, ncnn::Mat& c, ncnn::Mat& uc) 262 | { 263 | // t_to_sigma: map `step` evenly spaced timesteps from 999 down to 0 onto log_sigmas 264 | vector sigma(step); 265 | float delta = (step > 1) ? 0.0 - 999.0 / (step - 1) : 0.0; // step == 1 would otherwise divide by zero 266 | for (int i = 0; i < step; i++) { 267 | float t = max(0.0, 999.0 + i * delta); // float rounding can push the last t just below 0, making floor(t) == -1 and reading log_sigmas[-1] 268 | int low_idx = floor(t); 269 | int high_idx = ceil(t); 270 | float w = t - low_idx; 271 | sigma[i] = exp((1 - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx]); 272 | } 273 | sigma.push_back(0.f); 274 | 275 | // init 276 | ncnn::Mat x_mat = randn_4(seed % 1000); 277 | float _norm_[4] = { sigma[0], sigma[0], sigma[0], sigma[0] }; 278 | x_mat.substract_mean_normalize(0, _norm_); 279 | 280 | // euler ancestral 281 | { 282 | for (int i = 0; i <
sigma.size() - 1; i++) { 283 | printf("step:%2d/%d\t", i + 1, sigma.size() - 1); 284 | 285 | double t1 = ncnn::get_current_time(); 286 | ncnn::Mat denoised = CFGDenoiser_CompVisDenoiser(x_mat, sigma[i], c, uc); 287 | double t2 = ncnn::get_current_time(); 288 | printf("%.2lfms\n", t2 - t1); 289 | 290 | float sigma_up = min(sigma[i + 1], sqrt(sigma[i + 1] * sigma[i + 1] * (sigma[i] * sigma[i] - sigma[i + 1] * sigma[i + 1]) / (sigma[i] * sigma[i]))); 291 | float sigma_down = sqrt(sigma[i + 1] * sigma[i + 1] - sigma_up * sigma_up); 292 | 293 | srand(time(NULL) + i); 294 | ncnn::Mat randn = randn_4(rand() % 1000); 295 | for (int c = 0; c < 4; c++) { 296 | float* x_ptr = x_mat.channel(c); 297 | float* d_ptr = denoised.channel(c); 298 | float* r_ptr = randn.channel(c); 299 | for (int hw = 0; hw < h_size * w_size; hw++) { 300 | *x_ptr = *x_ptr + ((*x_ptr - *d_ptr) / sigma[i]) * (sigma_down - sigma[i]) + *r_ptr * sigma_up; 301 | x_ptr++; 302 | d_ptr++; 303 | r_ptr++; 304 | } 305 | } 306 | } 307 | } 308 | 309 | /* 310 | // DPM++ 2M Karras 311 | ncnn::Mat old_denoised; 312 | { 313 | for (int i = 0; i < sigma.size() - 1; i++) { 314 | cout << "step:" << i << "\t\t"; 315 | 316 | double t1 = ncnn::get_current_time(); 317 | ncnn::Mat denoised = CFGDenoiser_CompVisDenoiser(x_mat, sigma[i], c, uc); 318 | double t2 = ncnn::get_current_time(); 319 | cout << t2 - t1 << "ms" << endl; 320 | 321 | float sigma_curt = sigma[i]; 322 | float sigma_next = sigma[i + 1]; 323 | float tt = -1.0 * log(sigma_curt); 324 | float tt_next = -1.0 * log(sigma_next); 325 | float hh = tt_next - tt; 326 | if (old_denoised.empty() || sigma_next == 0) 327 | { 328 | for (int c = 0; c < 4; c++) { 329 | float* x_ptr = x_mat.channel(c); 330 | float* d_ptr = denoised.channel(c); 331 | for (int hw = 0; hw < size * size; hw++) { 332 | *x_ptr = (sigma_next / sigma_curt) * *x_ptr - (exp(-hh) - 1) * *d_ptr; 333 | x_ptr++; 334 | d_ptr++; 335 | } 336 | } 337 | } 338 | else 339 | { 340 | float hh_last = -1.0 * log(sigma[i 
- 1]); 341 | float r = hh_last / hh; 342 | for (int c = 0; c < 4; c++) { 343 | float* x_ptr = x_mat.channel(c); 344 | float* d_ptr = denoised.channel(c); 345 | float* od_ptr = old_denoised.channel(c); 346 | for (int hw = 0; hw < size * size; hw++) { 347 | *x_ptr = (sigma_next / sigma_curt) * *x_ptr - (exp(-hh) - 1) * ((1 + 1 / (2 * r)) * *d_ptr - (1 / (2 * r)) * *od_ptr); 348 | x_ptr++; 349 | d_ptr++; 350 | od_ptr++; 351 | } 352 | } 353 | } 354 | old_denoised.clone_from(denoised); 355 | } 356 | } 357 | */ 358 | 359 | ncnn::Mat result; 360 | result.clone_from(x_mat); 361 | return result; 362 | } 363 | 364 | ncnn::Mat DiffusionSlover::sampler_img2img(int seed, int step, ncnn::Mat& c, ncnn::Mat& uc, vector& init) 365 | { 366 | // t_to_sigma: map `step` evenly spaced timesteps from 999 down to 0 onto log_sigmas 367 | vector sigma(step); 368 | float delta = (step > 1) ? 0.0 - 999.0 / (step - 1) : 0.0; // step == 1 would otherwise divide by zero 369 | for (int i = 0; i < step; i++) { 370 | float t = max(0.0, 999.0 + i * delta); // float rounding can push the last t just below 0, making floor(t) == -1 and reading log_sigmas[-1] 371 | int low_idx = floor(t); 372 | int high_idx = ceil(t); 373 | float w = t - low_idx; 374 | sigma[i] = exp((1 - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx]); 375 | } 376 | sigma.push_back(0.f); 377 | 378 | // init 379 | ncnn::Mat x_mat(w_size, h_size, 4); 380 | 381 | // finish the rest of decoder 382 | { 383 | ncnn::Mat noise_mat = randn_4(seed % 1000); 384 | for (int c = 0; c < 4; c++) { 385 | float* x_ptr = x_mat.channel(c); 386 | float* noise_ptr = noise_mat.channel(c); 387 | float* mean_ptr = init[0].channel(c); 388 | float* std_ptr = init[1].channel(c); 389 | for (int hw = 0; hw < h_size * w_size; hw++) { 390 | *x_ptr = *mean_ptr + *std_ptr * *noise_ptr; 391 | x_ptr++; 392 | noise_ptr++; 393 | mean_ptr++; 394 | std_ptr++; 395 | } 396 | } 397 | x_mat.substract_mean_normalize(0, factor); 398 | } 399 | 400 | // reset scheduling 401 | int new_step = step * strength; 402 | { 403 | float _sigma_ = sigma[step - new_step]; 404 | ncnn::Mat noise_mat = randn_4(seed % 1000); 405 | for (int c = 0; c < 4; c++) { 406 | float* x_ptr =
noise_mat.channel(c); 408 | for (int hw = 0; hw < h_size * w_size; hw++) { 409 | *x_ptr = *x_ptr + *noise_ptr * _sigma_; 410 | x_ptr++; 411 | noise_ptr++; 412 | } 413 | } 414 | } 415 | vector sub_sigma(sigma.begin() + step - new_step, sigma.end()); 416 | 417 | // euler ancestral 418 | { 419 | for (int i = 0; i < (int)sub_sigma.size() - 1; i++) { 420 | printf("step:%2d/%d\t", i+1, (int)(sub_sigma.size()-1)); 421 | 422 | double t1 = ncnn::get_current_time(); 423 | ncnn::Mat denoised = CFGDenoiser_CompVisDenoiser(x_mat, sub_sigma[i], c, uc); 424 | double t2 = ncnn::get_current_time(); 425 | printf("%.2lfms\n", t2 - t1); 426 | 427 | float sigma_up = min(sub_sigma[i + 1], sqrt(sub_sigma[i + 1] * sub_sigma[i + 1] * (sub_sigma[i] * sub_sigma[i] - sub_sigma[i + 1] * sub_sigma[i + 1]) / (sub_sigma[i] * sub_sigma[i]))); 428 | float sigma_down = sqrt(sub_sigma[i + 1] * sub_sigma[i + 1] - sigma_up * sigma_up); 429 | 430 | // Per-step noise is derived from the caller's seed (not wall-clock time) so the same seed reproduces the same image. 431 | ncnn::Mat randn = randn_4((seed % 1000 + i + 1) % 1000); 432 | for (int c = 0; c < 4; c++) { 433 | float* x_ptr = x_mat.channel(c); 434 | float* d_ptr = denoised.channel(c); 435 | float* r_ptr = randn.channel(c); 436 | for (int hw = 0; hw < h_size * w_size; hw++) { 437 | *x_ptr = *x_ptr + ((*x_ptr - *d_ptr) / sub_sigma[i]) * (sigma_down - sub_sigma[i]) + *r_ptr * sigma_up; 438 | x_ptr++; 439 | d_ptr++; 440 | r_ptr++; 441 | } 442 | } 443 | } 444 | } 445 | 446 | ncnn::Mat result; 447 | result.clone_from(x_mat); 448 | return result; 449 | } 450 | -------------------------------------------------------------------------------- /overrides/diffusion_slover.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include "benchmark.h" 15 | using namespace std; 16 | 17 | class DiffusionSlover 18 | { 19 | public: 20 | DiffusionSlover(int h, int w, int mode,
string assets_dir); 21 | 22 | ncnn::Mat sampler_txt2img(int seed, int step, ncnn::Mat& c, ncnn::Mat& uc); 23 | ncnn::Mat sampler_img2img(int seed, int step, ncnn::Mat& c, ncnn::Mat& uc, vector& init); 24 | 25 | private: 26 | void generate_param(int height, int width, string assets_dir); 27 | 28 | ncnn::Mat randn_4(int seed); 29 | ncnn::Mat CFGDenoiser_CompVisDenoiser(ncnn::Mat& input, float sigma, ncnn::Mat cond, ncnn::Mat uncond); 30 | 31 | private: 32 | float log_sigmas[1000] = { 0 }; 33 | const float guidance_scale = 7.5; 34 | const float strength = 0.75; 35 | 36 | const float factor[4] = { 0.18215f, 0.18215f, 0.18215f, 0.18215f }; 37 | 38 | ncnn::Net net; 39 | 40 | int h_size, w_size; 41 | }; -------------------------------------------------------------------------------- /overrides/encoder_slover.cpp: -------------------------------------------------------------------------------- 1 | #include "encoder_slover.h" 2 | #include 3 | 4 | EncodeSlover::EncodeSlover(int h, int w, string assets_dir) 5 | { 6 | net.opt.use_vulkan_compute = false; 7 | net.opt.use_winograd_convolution = false; 8 | net.opt.use_sgemm_convolution = false; 9 | net.opt.use_fp16_packed = false; 10 | net.opt.use_fp16_storage = false; 11 | net.opt.use_fp16_arithmetic = false; 12 | net.opt.use_bf16_storage = true; 13 | net.opt.use_packing_layout = true; 14 | 15 | // Define the name of the parameter file 16 | std::string param_file; 17 | if (h == 512 && w == 512) 18 | param_file = "AutoencoderKL-encoder-512-512-fp16.param"; 19 | else 20 | { 21 | generate_param(h, w, assets_dir); 22 | param_file = "tmp-AutoencoderKL-encoder-" + to_string(h) + "-" + to_string(w) + "-fp16.param"; 23 | } 24 | 25 | // Join the paths using std::filesystem::path::operator/() function 26 | std::filesystem::path param_path = std::filesystem::path(assets_dir) / std::filesystem::path(param_file); 27 | net.load_param(param_path.string().c_str()); 28 | 29 | std::filesystem::path encoder_path = std::filesystem::path(assets_dir) 
/ std::filesystem::path("AutoencoderKL-encoder-512-512-fp16.bin"); 30 | 31 | net.load_model(encoder_path.string().c_str()); 32 | 33 | h_size = h; 34 | w_size = w; 35 | } 36 | 37 | void EncodeSlover::generate_param(int height, int width, string assets_dir) 38 | { 39 | string line; 40 | 41 | std::filesystem::path decoder = std::filesystem::path(assets_dir) / std::filesystem::path("AutoencoderKL-encoder-512-512-fp16.param"); 42 | std::filesystem::path decoder_out = std::filesystem::path(assets_dir) / std::filesystem::path("tmp-AutoencoderKL-encoder-" + std::to_string(height) + "-" + std::to_string(width) + "-fp16.param"); 43 | 44 | ifstream encoder_file(decoder.string().c_str()); 45 | ofstream encoder_file_new(decoder_out.string().c_str()); 46 | 47 | int cnt = 0; 48 | while (getline(encoder_file, line)) 49 | { 50 | if (line.substr(0, 7) == "Reshape") 51 | { 52 | switch (cnt) 53 | { 54 | case 0: line = line.substr(0, line.size() - 12) + "0=" + to_string(width * height / 8 / 8) + " 1=512"; break; 55 | case 1: line = line.substr(0, line.size() - 15) + "0=" + to_string(width / 8) + " 1=" + std::to_string(height / 8) + " 2=512"; break; 56 | default: break; 57 | } 58 | 59 | cnt++; 60 | } 61 | encoder_file_new << line << endl; 62 | } 63 | encoder_file_new.close(); 64 | encoder_file.close(); 65 | } 66 | 67 | std::vector EncodeSlover::encode(cv::Mat& bgr_image) 68 | { 69 | std::vector mean_std(2); 70 | { 71 | int ih = bgr_image.rows, iw = bgr_image.cols; 72 | ncnn::Mat in = ncnn::Mat::from_pixels_resize(bgr_image.data, ncnn::Mat::PIXEL_BGR2RGB, iw, ih, w_size, h_size); 73 | in.substract_mean_normalize(_mean_, _norm_); 74 | 75 | { 76 | ncnn::Extractor ex = net.create_extractor(); 77 | ex.set_light_mode(true); 78 | ex.input("in0", in); 79 | ex.extract("out0", mean_std[0]); 80 | ex.extract("out1", mean_std[1]); 81 | } 82 | } 83 | 84 | return mean_std; 85 | } -------------------------------------------------------------------------------- /overrides/encoder_slover.h: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | using namespace std; 15 | 16 | class EncodeSlover 17 | { 18 | public: 19 | EncodeSlover(int h, int w, string assets_dir); 20 | 21 | std::vector encode(cv::Mat& image); 22 | 23 | private: 24 | void generate_param(int height, int width, string assets_dir); 25 | 26 | const float _mean_[3] = { 127.5f, 127.5f, 127.5f }; 27 | const float _norm_[3] = { 1.0 / 127.5f, 1.0 / 127.5f, 1.0 / 127.5f }; 28 | 29 | ncnn::Net net; 30 | 31 | int h_size, w_size; 32 | }; -------------------------------------------------------------------------------- /overrides/prompt_slover.cpp: -------------------------------------------------------------------------------- 1 | #include "prompt_slover.h" 2 | #include 3 | 4 | PromptSlover::PromptSlover(string assets_dir) 5 | { 6 | // Load the CLIP text encoder (FrozenCLIPEmbedder) from assets_dir with CPU-only (no Vulkan) fp16 options 7 | net.opt.use_vulkan_compute = false; 8 | net.opt.use_winograd_convolution = false; 9 | net.opt.use_sgemm_convolution = false; 10 | net.opt.use_fp16_packed = true; 11 | net.opt.use_fp16_storage = true; 12 | net.opt.use_fp16_arithmetic = true; 13 | net.opt.use_packing_layout = true; 14 | std::filesystem::path params = std::filesystem::path(assets_dir) / std::filesystem::path("FrozenCLIPEmbedder-fp16.param"); 15 | std::filesystem::path model = std::filesystem::path(assets_dir) / std::filesystem::path("FrozenCLIPEmbedder-fp16.bin"); 16 | 17 | net.load_param(params.string().c_str()); 18 | net.load_model(model.string().c_str()); 19 | 20 | // Read the tokenizer vocabulary: one token per line, and the 0-based line index is that token's id 21 | std::ifstream infile; 22 | std::filesystem::path v = std::filesystem::path(assets_dir) / std::filesystem::path("vocab.txt"); 23 | std::string pathname = v.string(); 24 | infile.open(pathname.data()); 25 | std::string s; 26 | int idx = 0; 27 | while (getline(infile, s)) { 28 |
tokenizer_token2idx.insert(pair(s, idx)); 29 | tokenizer_idx2token.insert(pair(idx, s)); 30 | idx++; 31 | } 32 | infile.close(); 33 | } 34 | 35 | ncnn::Mat PromptSlover::get_conditioning(string& prompt) 36 | { 37 | // ��Ҫ�ȼ������ƥ�䡰()���͡�[]����Բ�����Ǽ���Ҫ�ȣ��������Ǽ���Ҫ�� 38 | vector> parsed = parse_prompt_attention(prompt); 39 | 40 | // tokenתids 41 | vector> tokenized; 42 | { 43 | for (auto p : parsed) { 44 | vector tokens = split(p.first); 45 | vector ids; 46 | for (string token : tokens) { 47 | ids.push_back(tokenizer_token2idx[token]); 48 | } 49 | tokenized.push_back(ids); 50 | } 51 | } 52 | 53 | // һЩ���� 54 | vector remade_tokens; 55 | vector multipliers; 56 | { 57 | int last_comma = -1; 58 | for (int it_tokenized = 0; it_tokenized < tokenized.size(); it_tokenized++) { 59 | vector tokens = tokenized[it_tokenized]; 60 | float weight = parsed[it_tokenized].second; 61 | 62 | int i = 0; 63 | while (i < tokens.size()) { 64 | int token = tokens[i]; 65 | if (token == 267) { 66 | last_comma = remade_tokens.size(); 67 | } 68 | else if ((max(int(remade_tokens.size()), 1) % 75 == 0) && (last_comma != -1) && (remade_tokens.size() - last_comma <= 20)) { 69 | last_comma += 1; 70 | vector reloc_tokens(remade_tokens.begin() + last_comma, remade_tokens.end()); 71 | vector reloc_mults(multipliers.begin() + last_comma, multipliers.end()); 72 | vector _remade_tokens_(remade_tokens.begin(), remade_tokens.begin() + last_comma); 73 | remade_tokens = _remade_tokens_; 74 | int length = remade_tokens.size(); 75 | int rem = ceil(length / 75.0) * 75 - length; 76 | vector tmp_token(rem, 49407); 77 | remade_tokens.insert(remade_tokens.end(), tmp_token.begin(), tmp_token.end()); 78 | remade_tokens.insert(remade_tokens.end(), reloc_tokens.begin(), reloc_tokens.end()); 79 | vector _multipliers_(multipliers.begin(), multipliers.end() + last_comma); 80 | vector tmp_multipliers(rem, 1.0f); 81 | _multipliers_.insert(_multipliers_.end(), tmp_multipliers.begin(), tmp_multipliers.end()); 82 | 
_multipliers_.insert(_multipliers_.end(), reloc_mults.begin(), reloc_mults.end()); 83 | multipliers = _multipliers_; 84 | } 85 | remade_tokens.push_back(token); 86 | multipliers.push_back(weight); 87 | i += 1; 88 | } 89 | } 90 | int prompt_target_length = ceil(max(int(remade_tokens.size()), 1) / 75.0) * 75; 91 | int tokens_to_add = prompt_target_length - remade_tokens.size(); 92 | vector tmp_token(tokens_to_add, 49407); 93 | remade_tokens.insert(remade_tokens.end(), tmp_token.begin(), tmp_token.end()); 94 | vector tmp_multipliers(tokens_to_add, 1.0f); 95 | multipliers.insert(multipliers.end(), tmp_multipliers.begin(), tmp_multipliers.end()); 96 | } 97 | 98 | // �з� 99 | ncnn::Mat conds(768, 0); 100 | { 101 | while (remade_tokens.size() > 0) { 102 | vector rem_tokens(remade_tokens.begin() + 75, remade_tokens.end()); 103 | vector rem_multipliers(multipliers.begin() + 75, multipliers.end()); 104 | 105 | vector current_tokens; 106 | vector current_multipliers; 107 | if (remade_tokens.size() > 0) { 108 | current_tokens.insert(current_tokens.end(), remade_tokens.begin(), remade_tokens.begin() + 75); 109 | current_multipliers.insert(current_multipliers.end(), multipliers.begin(), multipliers.begin() + 75); 110 | } 111 | else { 112 | vector tmp_token(75, 49407); 113 | current_tokens.insert(current_tokens.end(), tmp_token.begin(), tmp_token.end()); 114 | vector tmp_multipliers(75, 1.0f); 115 | current_multipliers.insert(current_multipliers.end(), tmp_multipliers.begin(), tmp_multipliers.end()); 116 | } 117 | 118 | { 119 | ncnn::Mat token_mat = ncnn::Mat(77); 120 | token_mat.fill(int(49406)); 121 | ncnn::Mat multiplier_mat = ncnn::Mat(77); 122 | multiplier_mat.fill(1.0f); 123 | 124 | int* token_ptr = token_mat; 125 | float* multiplier_ptr = multiplier_mat; 126 | for (int i = 0; i < 75; i++) { 127 | token_ptr[i + 1] = int(current_tokens[i]); 128 | multiplier_ptr[i + 1] = current_multipliers[i]; 129 | } 130 | 131 | ncnn::Extractor ex = net.create_extractor(); 132 | 
ex.set_light_mode(true); 133 | ex.input("token", token_mat); 134 | ex.input("multiplier", multiplier_mat); 135 | ex.input("cond", conds); 136 | ncnn::Mat new_conds; 137 | ex.extract("conds", new_conds); 138 | conds = new_conds; 139 | 140 | } 141 | 142 | remade_tokens = rem_tokens; 143 | multipliers = rem_multipliers; 144 | } 145 | } 146 | 147 | return conds; 148 | } 149 | 150 | vector> PromptSlover::parse_prompt_attention(string& texts) 151 | { 152 | vector> res; 153 | stack round_brackets; 154 | stack square_brackets; 155 | const float round_bracket_multiplier = 1.1; 156 | const float square_bracket_multiplier = 1 / 1.1; 157 | 158 | vector ms; 159 | for (char c : texts) { 160 | string s = string(1, c); 161 | if (s == "(" || s == "[" || s == ")" || s == "]") { 162 | ms.push_back(s); 163 | } 164 | else { 165 | if (ms.size() < 1) 166 | ms.push_back(""); 167 | string last = ms[ms.size() - 1]; 168 | if (last == "(" || last == "[" || last == ")" || last == "]") { 169 | ms.push_back(""); 170 | } 171 | ms[ms.size() - 1] += s; 172 | } 173 | } 174 | 175 | for (string text : ms) { 176 | if (text == "(") { 177 | round_brackets.push(res.size()); 178 | } 179 | else if (text == "[") { 180 | square_brackets.push(res.size()); 181 | } 182 | else if (text == ")" && round_brackets.size() > 0) { 183 | for (int p = round_brackets.top(); p < res.size(); p++) { 184 | res[p].second *= round_bracket_multiplier; 185 | } 186 | round_brackets.pop(); 187 | } 188 | else if (text == "]" and square_brackets.size() > 0) { 189 | for (int p = square_brackets.top(); p < res.size(); p++) { 190 | res[p].second *= square_bracket_multiplier; 191 | } 192 | square_brackets.pop(); 193 | } 194 | else { 195 | res.push_back(make_pair(text, 1.0)); 196 | } 197 | } 198 | 199 | while (!round_brackets.empty()) { 200 | for (int p = round_brackets.top(); p < res.size(); p++) { 201 | res[p].second *= round_bracket_multiplier; 202 | } 203 | round_brackets.pop(); 204 | } 205 | 206 | while (!square_brackets.empty()) { 
207 | for (int p = square_brackets.top(); p < res.size(); p++) { 208 | res[p].second *= square_bracket_multiplier; 209 | } 210 | square_brackets.pop(); 211 | } 212 | 213 | int i = 0; 214 | while (i + 1 < res.size()) { 215 | if (res[i].second == res[i + 1].second) { 216 | res[i].first += res[i + 1].first; 217 | auto it = res.begin(); 218 | res.erase(it + i + 1); 219 | } 220 | else { 221 | i += 1; 222 | } 223 | } 224 | 225 | return res; 226 | } 227 | 228 | string PromptSlover::whitespace_clean(string& text) 229 | { 230 | return regex_replace(text, regex("\\s+"), " "); 231 | } 232 | 233 | std::vector PromptSlover::split(std::string str) 234 | { 235 | std::string::size_type pos; 236 | std::vector result; 237 | str += " "; 238 | int size = str.size(); 239 | for (int i = 0; i < size; i++) 240 | { 241 | pos = min(str.find(" ", i), str.find(",", i)); 242 | if (pos < size) 243 | { 244 | std::string s = str.substr(i, pos - i); 245 | string pat = string(1, str[pos]); 246 | if (s.length() > 0) 247 | result.push_back(s + ""); 248 | if (pat != " ") 249 | result.push_back(pat + ""); 250 | i = pos; 251 | } 252 | } 253 | return result; 254 | } -------------------------------------------------------------------------------- /overrides/prompt_slover.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | using namespace std; 12 | 13 | class PromptSlover 14 | { 15 | public: 16 | PromptSlover(string assets_dir); 17 | 18 | ncnn::Mat get_conditioning(string& prompt); 19 | 20 | private: 21 | std::vector split(std::string str); 22 | string whitespace_clean(string& text); 23 | vector> parse_prompt_attention(string& texts); 24 | 25 | map tokenizer_token2idx; 26 | map tokenizer_idx2token; 27 | 28 | ncnn::Net net; 29 | }; -------------------------------------------------------------------------------- /stablediffusion.cpp: 
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include
5 | #include
6 | #include
7 | #include
8 | #include
9 | #include
10 | #include "prompt_slover.h"
11 | #include "decoder_slover.h"
12 | #include "encoder_slover.h"
13 | #include "diffusion_slover.h"
14 | #include "diffusion_slover.cpp"
15 | #include "decoder_slover.cpp"
16 | #include "encoder_slover.cpp"
17 | #include "prompt_slover.cpp"
18 | #include
19 | #include
20 | #include
21 | #include "getmem.h"
22 | #include "stablediffusion.h"
23 | #include "stablediffusion.hpp"
24 |
25 | using namespace std;
26 |
27 | int generate_image(int height, int width, int mode, int step, int seed, const char *positive_prompt, const char *negative_prompt, const char *dst, const char *init_image, const char *assets_dir)
28 | {
29 |
30 |     if (seed == 0) {
31 |         seed = (unsigned)time(NULL);
32 |     }
33 |
34 |     cout << "----------------[start stable diffusion]--------------------" << endl;
35 |
36 |     // stable diffusion
37 |     cout << "----------------[init]--------------------" << endl;
38 |     PromptSlover prompt_slover(assets_dir);
39 |     DiffusionSlover diffusion_slover(height, width, mode, assets_dir);
40 |     DecodeSlover decode_slover(height, width, assets_dir);
41 |     EncodeSlover encode_slover(height, width, assets_dir);
42 |     printf(" %.2lfG / %.2lfG\n", getCurrentRSS() / 1024.0 / 1024.0 / 1024.0, getPeakRSS() / 1024.0 / 1024.0 / 1024.0);
43 |     std::string positive_prompt_str = positive_prompt;
44 |     std::string negative_prompt_str = negative_prompt;
45 |     cout << "----------------[prompt]------------------";
46 |     ncnn::Mat cond = prompt_slover.get_conditioning(positive_prompt_str);
47 |     ncnn::Mat uncond = prompt_slover.get_conditioning(negative_prompt_str);
48 |     printf(" %.2lfG / %.2lfG\n", getCurrentRSS() / 1024.0 / 1024.0 / 1024.0, getPeakRSS() / 1024.0 / 1024.0 / 1024.0);
49 |
50 |     vector init_latents;
51 |     cv::Mat img = cv::imread(init_image);
52 |     if (!img.empty()) {
53 |         cout << "----------------[ encoder ]----------------";
54 |         init_latents = encode_slover.encode(img);
55 |         printf(" %.2lfG / %.2lfG\n", getCurrentRSS() / 1024.0 / 1024.0 / 1024.0, getPeakRSS() / 1024.0 / 1024.0 / 1024.0);
56 |     }
57 |
58 |     cout << "----------------[diffusion]---------------" << endl;
59 |     ncnn::Mat sample;
60 |     if (!img.empty()) {
61 |         sample = diffusion_slover.sampler_img2img(seed, step, cond, uncond, init_latents);
62 |     }
63 |     else {
64 |         sample = diffusion_slover.sampler_txt2img(seed, step, cond, uncond);
65 |     }
66 |     cout << "----------------[diffusion]---------------";
67 |     printf(" %.2lfG / %.2lfG\n", getCurrentRSS() / 1024.0 / 1024.0 / 1024.0, getPeakRSS() / 1024.0 / 1024.0 / 1024.0);
68 |
69 |     cout << "----------------[decode]------------------";
70 |     ncnn::Mat x_samples_ddim = decode_slover.decode(sample);
71 |     printf(" %.2lfG / %.2lfG\n", getCurrentRSS() / 1024.0 / 1024.0 / 1024.0, getPeakRSS() / 1024.0 / 1024.0 / 1024.0);
72 |
73 |     cout << "----------------[save]--------------------" << endl;
74 |     cv::Mat image(height, width, CV_8UC3);
75 |     x_samples_ddim.to_pixels(image.data, ncnn::Mat::PIXEL_RGB2BGR);
76 |     //cv::imwrite(dst, image);
77 |     cv::imwrite(dst, image);
78 |
79 |     cout << "----------------[close]-------------------" << endl;
80 |
81 |     return 0;
82 | }
83 |
84 | int generate_image_upscaled( int height, int width, int step, int seed, const char *positive_prompt, const char *negative_prompt, const char *dst, const char *assets_dir)
85 | {
86 |     std::cout << "----------------[start generation upscaled image]------------------" << std::endl;
87 |     std::cout << "positive_prompt: " << positive_prompt << std::endl;
88 |     std::cout << "output_png_path: " << dst << std::endl;
89 |     std::cout << "negative_prompt: " << negative_prompt << std::endl;
90 |     std::cout << "step: " << step << std::endl;
91 |     std::cout << "seed: " << seed << std::endl;
92 |     std::cout << "----------------[prompt]------------------" << std::endl;
93 |     auto [cond, uncond] = prompt_solver( positive_prompt, negative_prompt, assets_dir );
94 |     std::cout << "----------------[diffusion]---------------" << std::endl;
95 |     ncnn::Mat sample = diffusion_solver( seed, step, cond, uncond , assets_dir);
96 |     std::cout << "----------------[decode]------------------" << std::endl;
97 |     ncnn::Mat x_samples_ddim = decoder_solver( sample, assets_dir );
98 |     std::cout << "----------------[4x]--------------------" << std::endl;
99 |     x_samples_ddim = esr4x( x_samples_ddim, assets_dir );
100 |     std::cout << "----------------[save]--------------------" << std::endl;
101 |     {
102 |         std::vector buffer;
103 |         (void)height; (void)width; // none of the solvers above take height/width, so the output size is whatever x_samples_ddim actually is
104 |         buffer.resize( size_t( x_samples_ddim.w ) * size_t( x_samples_ddim.h ) * 3 ); // to_pixels writes w * h * 3 bytes
105 |         x_samples_ddim.to_pixels( buffer.data(), ncnn::Mat::PIXEL_RGB );
106 |         save_png( buffer.data(), x_samples_ddim.w, x_samples_ddim.h, 0, dst ); // save_png takes (img, w, h, alpha, file_name)
107 |     }
108 |     std::cout << "----------------[close]-------------------" << std::endl;
109 |     return 0;
110 | }
--------------------------------------------------------------------------------
/stablediffusion.go:
--------------------------------------------------------------------------------
1 | package stablediffusion
2 |
3 | // #cgo CXXFLAGS: -I${SRCDIR}/ -I${SRCDIR}/ncnn/src -I${SRCDIR}/ncnn -I${SRCDIR}/ncnn/build/src/ -I${SRCDIR}/stable-diffusion/x86/vs2019_opencv-mobile_ncnn-dll_demo/vs2019_opencv-mobile_ncnn-dll_demo/ -std=c++17
4 | // #cgo LDFLAGS: -L${SRCDIR}/ -lstablediffusion -lgomp -lopencv_core -lopencv_imgcodecs -lm -lstdc++
5 | // #include "stablediffusion.h"
6 | // #include <stdlib.h>
7 | import "C"
8 | import (
9 | 	"fmt"
10 | 	"unsafe"
11 | )
12 |
13 | // GenerateImage calls the native generate_image and writes the result to dst.
14 | // If init_image can be read as an image the native side runs img2img from it,
15 | // otherwise txt2img; a seed of 0 makes the native side derive one from the clock.
16 | func GenerateImage(height, width, mode, step, seed int, positive_prompt, negative_prompt, dst, init_image, asset_dir string) error {
17 | 	// C.CString copies into malloc'd memory that the Go GC never frees; release each copy after the call.
18 | 	pp := C.CString(positive_prompt)
19 | 	defer C.free(unsafe.Pointer(pp))
20 | 	np := C.CString(negative_prompt)
21 | 	defer C.free(unsafe.Pointer(np))
22 | 	ii := C.CString(init_image)
23 | 	defer C.free(unsafe.Pointer(ii))
24 | 	ad := C.CString(asset_dir)
25 | 	defer C.free(unsafe.Pointer(ad))
26 | 	destination := C.CString(dst)
27 | 	defer C.free(unsafe.Pointer(destination))
28 |
29 | 	ret := C.generate_image(C.int(height), C.int(width), C.int(mode), C.int(step), C.int(seed), pp, np, destination, ii, ad)
30 | 	if ret != 0 {
31 | 		return fmt.Errorf("generate_image returned %d", ret)
32 | 	}
33 | 	return nil
34 | }
35 |
36 | // GenerateImageUpscaled calls the native generate_image_upscaled (prompt
37 | // encoding, diffusion, decoding, 4x upscale) and writes a PNG to dst.
38 | // height and width are forwarded but do not determine the output size.
39 | func GenerateImageUpscaled(height, width, step, seed int, positive_prompt, negative_prompt, dst, asset_dir string) error {
40 | 	pp := C.CString(positive_prompt)
41 | 	defer C.free(unsafe.Pointer(pp))
42 | 	np := C.CString(negative_prompt)
43 | 	defer C.free(unsafe.Pointer(np))
44 | 	ad := C.CString(asset_dir)
45 | 	defer C.free(unsafe.Pointer(ad))
46 | 	destination := C.CString(dst)
47 | 	defer C.free(unsafe.Pointer(destination))
48 |
49 | 	ret := C.generate_image_upscaled(C.int(height), C.int(width), C.int(step), C.int(seed), pp, np, destination, ad)
50 | 	if ret != 0 {
51 | 		return fmt.Errorf("generate_image_upscaled returned %d", ret)
52 | 	}
53 | 	return nil
54 | }
55 |
--------------------------------------------------------------------------------
/stablediffusion.h:
--------------------------------------------------------------------------------
1 | #ifdef __cplusplus
2 | #include
3 | #include
4 | extern "C" {
5 | #endif
6 |
7 | #include
8 |
9 | int generate_image(int height, int width, int mode, int step, int seed, const char *positive_prompt, const char *negative_prompt, const char *dst, const char *init_image, const char *assets_dir);
10 | int generate_image_upscaled( int height, int width, int step, int seed, const char *positive_prompt, const char *negative_prompt, const char *dst, const char *assets_dir);
11 |
12 | #ifdef __cplusplus
13 | }
14 | #endif
15 |
--------------------------------------------------------------------------------
/stablediffusion.hpp:
--------------------------------------------------------------------------------
1 | // This is an adaptation of https://github.com/fengwang/Stable-Diffusion-NCNN/blob/main/stable_diffusion.hpp
2 | // Credit goes to fengwang and EdVince
3 |
4 | #include
5 | #include
6 | #include
7 | #include
8 | #include
9 | #include
10 | #include
11 | #include
12 | #include
13 | #include
14 | #include
15 | #include
16 | #include
17 | #include
18 | #include
19 | #include
20 | #include
21 | #include
22 |
23 | // adapted from
https://github.com/miloyip/svpng/blob/master/svpng.inc 24 | inline static void save_png( std::uint8_t* img, unsigned w, unsigned h, int alpha, char const* const file_name ) noexcept 25 | { 26 | constexpr unsigned t[] = { 0, 0x1db71064, 0x3b6e20c8, 0x26d930ac, 0x76dc4190, 0x6b6b51f4, 0x4db26158, 0x5005713c, 0xedb88320, 0xf00f9344, 0xd6d6a3e8, 0xcb61b38c, 0x9b64c2b0, 0x86d3d2d4, 0xa00ae278, 0xbdbdf21c }; 27 | unsigned a = 1, b = 0, c, p = w * ( alpha ? 4 : 3 ) + 1, x, y, i; 28 | FILE* fp = fopen( file_name, "wb" ); 29 | 30 | for ( i = 0; i < 8; i++ ) 31 | fputc( ( "\x89PNG\r\n\32\n" )[i], fp );; 32 | 33 | { 34 | { 35 | fputc( ( 13 ) >> 24, fp ); 36 | fputc( ( ( 13 ) >> 16 ) & 255, fp ); 37 | fputc( ( ( 13 ) >> 8 ) & 255, fp ); 38 | fputc( ( 13 ) & 255, fp ); 39 | } 40 | c = ~0U; 41 | 42 | for ( i = 0; i < 4; i++ ) 43 | { 44 | fputc( ( "IHDR" )[i], fp ); 45 | c ^= ( ( "IHDR" )[i] ); 46 | c = ( c >> 4 ) ^ t[c & 15]; 47 | c = ( c >> 4 ) ^ t[c & 15]; 48 | } 49 | } 50 | { 51 | { 52 | fputc( ( w ) >> 24, fp ); 53 | c ^= ( ( w ) >> 24 ); 54 | c = ( c >> 4 ) ^ t[c & 15]; 55 | c = ( c >> 4 ) ^ t[c & 15]; 56 | } 57 | { 58 | fputc( ( ( w ) >> 16 ) & 255, fp ); 59 | c ^= ( ( ( w ) >> 16 ) & 255 ); 60 | c = ( c >> 4 ) ^ t[c & 15]; 61 | c = ( c >> 4 ) ^ t[c & 15]; 62 | } 63 | { 64 | fputc( ( ( w ) >> 8 ) & 255, fp ); 65 | c ^= ( ( ( w ) >> 8 ) & 255 ); 66 | c = ( c >> 4 ) ^ t[c & 15]; 67 | c = ( c >> 4 ) ^ t[c & 15]; 68 | } 69 | { 70 | fputc( ( w ) & 255, fp ); 71 | c ^= ( ( w ) & 255 ); 72 | c = ( c >> 4 ) ^ t[c & 15]; 73 | c = ( c >> 4 ) ^ t[c & 15]; 74 | } 75 | } 76 | { 77 | { 78 | fputc( ( h ) >> 24, fp ); 79 | c ^= ( ( h ) >> 24 ); 80 | c = ( c >> 4 ) ^ t[c & 15]; 81 | c = ( c >> 4 ) ^ t[c & 15]; 82 | } 83 | { 84 | fputc( ( ( h ) >> 16 ) & 255, fp ); 85 | c ^= ( ( ( h ) >> 16 ) & 255 ); 86 | c = ( c >> 4 ) ^ t[c & 15]; 87 | c = ( c >> 4 ) ^ t[c & 15]; 88 | } 89 | { 90 | fputc( ( ( h ) >> 8 ) & 255, fp ); 91 | c ^= ( ( ( h ) >> 8 ) & 255 ); 92 | c = ( c >> 4 ) ^ t[c & 15]; 93 
| c = ( c >> 4 ) ^ t[c & 15]; 94 | } 95 | { 96 | fputc( ( h ) & 255, fp ); 97 | c ^= ( ( h ) & 255 ); 98 | c = ( c >> 4 ) ^ t[c & 15]; 99 | c = ( c >> 4 ) ^ t[c & 15]; 100 | } 101 | } 102 | { 103 | fputc( 8, fp ); 104 | c ^= ( 8 ); 105 | c = ( c >> 4 ) ^ t[c & 15]; 106 | c = ( c >> 4 ) ^ t[c & 15]; 107 | } 108 | { 109 | fputc( alpha ? 6 : 2, fp ); 110 | c ^= ( alpha ? 6 : 2 ); 111 | c = ( c >> 4 ) ^ t[c & 15]; 112 | c = ( c >> 4 ) ^ t[c & 15]; 113 | } 114 | 115 | for ( i = 0; i < 3; i++ ) 116 | { 117 | fputc( ( "\0\0\0" )[i], fp ); 118 | c ^= ( ( "\0\0\0" )[i] ); 119 | c = ( c >> 4 ) ^ t[c & 15]; 120 | c = ( c >> 4 ) ^ t[c & 15]; 121 | } 122 | 123 | { 124 | fputc( ( ~c ) >> 24, fp ); 125 | fputc( ( ( ~c ) >> 16 ) & 255, fp ); 126 | fputc( ( ( ~c ) >> 8 ) & 255, fp ); 127 | fputc( ( ~c ) & 255, fp ); 128 | } 129 | 130 | { 131 | { 132 | fputc( ( 2 + h * ( 5 + p ) + 4 ) >> 24, fp ); 133 | fputc( ( ( 2 + h * ( 5 + p ) + 4 ) >> 16 ) & 255, fp ); 134 | fputc( ( ( 2 + h * ( 5 + p ) + 4 ) >> 8 ) & 255, fp ); 135 | fputc( ( 2 + h * ( 5 + p ) + 4 ) & 255, fp ); 136 | } 137 | c = ~0U; 138 | 139 | for ( i = 0; i < 4; i++ ) 140 | { 141 | fputc( ( "IDAT" )[i], fp ); 142 | c ^= ( ( "IDAT" )[i] ); 143 | c = ( c >> 4 ) ^ t[c & 15]; 144 | c = ( c >> 4 ) ^ t[c & 15]; 145 | } 146 | } 147 | 148 | for ( i = 0; i < 2; i++ ) 149 | { 150 | fputc( ( "\x78\1" )[i], fp ); 151 | c ^= ( ( "\x78\1" )[i] ); 152 | c = ( c >> 4 ) ^ t[c & 15]; 153 | c = ( c >> 4 ) ^ t[c & 15]; 154 | } 155 | 156 | for ( y = 0; y < h; y++ ) 157 | { 158 | { 159 | fputc( y == h - 1, fp ); 160 | c ^= ( y == h - 1 ); 161 | c = ( c >> 4 ) ^ t[c & 15]; 162 | c = ( c >> 4 ) ^ t[c & 15]; 163 | } 164 | { 165 | { 166 | fputc( ( p ) & 255, fp ); 167 | c ^= ( ( p ) & 255 ); 168 | c = ( c >> 4 ) ^ t[c & 15]; 169 | c = ( c >> 4 ) ^ t[c & 15]; 170 | } 171 | { 172 | fputc( ( ( p ) >> 8 ) & 255, fp ); 173 | c ^= ( ( ( p ) >> 8 ) & 255 ); 174 | c = ( c >> 4 ) ^ t[c & 15]; 175 | c = ( c >> 4 ) ^ t[c & 15]; 176 | } 177 | } 178 | { 179 | 
{ 180 | fputc( ( ~p ) & 255, fp ); 181 | c ^= ( ( ~p ) & 255 ); 182 | c = ( c >> 4 ) ^ t[c & 15]; 183 | c = ( c >> 4 ) ^ t[c & 15]; 184 | } 185 | { 186 | fputc( ( ( ~p ) >> 8 ) & 255, fp ); 187 | c ^= ( ( ( ~p ) >> 8 ) & 255 ); 188 | c = ( c >> 4 ) ^ t[c & 15]; 189 | c = ( c >> 4 ) ^ t[c & 15]; 190 | } 191 | } 192 | { 193 | { 194 | fputc( 0, fp ); 195 | c ^= ( 0 ); 196 | c = ( c >> 4 ) ^ t[c & 15]; 197 | c = ( c >> 4 ) ^ t[c & 15]; 198 | } 199 | a = ( a + ( 0 ) ) % 65521; 200 | b = ( b + a ) % 65521; 201 | } 202 | 203 | for ( x = 0; x < p - 1; x++, img++ ) 204 | { 205 | { 206 | fputc( *img, fp ); 207 | c ^= ( *img ); 208 | c = ( c >> 4 ) ^ t[c & 15]; 209 | c = ( c >> 4 ) ^ t[c & 15]; 210 | } 211 | a = ( a + ( *img ) ) % 65521; 212 | b = ( b + a ) % 65521; 213 | } 214 | } 215 | 216 | { 217 | { 218 | fputc( ( ( b << 16 ) | a ) >> 24, fp ); 219 | c ^= ( ( ( b << 16 ) | a ) >> 24 ); 220 | c = ( c >> 4 ) ^ t[c & 15]; 221 | c = ( c >> 4 ) ^ t[c & 15]; 222 | } 223 | { 224 | fputc( ( ( ( b << 16 ) | a ) >> 16 ) & 255, fp ); 225 | c ^= ( ( ( ( b << 16 ) | a ) >> 16 ) & 255 ); 226 | c = ( c >> 4 ) ^ t[c & 15]; 227 | c = ( c >> 4 ) ^ t[c & 15]; 228 | } 229 | { 230 | fputc( ( ( ( b << 16 ) | a ) >> 8 ) & 255, fp ); 231 | c ^= ( ( ( ( b << 16 ) | a ) >> 8 ) & 255 ); 232 | c = ( c >> 4 ) ^ t[c & 15]; 233 | c = ( c >> 4 ) ^ t[c & 15]; 234 | } 235 | { 236 | fputc( ( ( b << 16 ) | a ) & 255, fp ); 237 | c ^= ( ( ( b << 16 ) | a ) & 255 ); 238 | c = ( c >> 4 ) ^ t[c & 15]; 239 | c = ( c >> 4 ) ^ t[c & 15]; 240 | } 241 | } 242 | 243 | { 244 | fputc( ( ~c ) >> 24, fp ); 245 | fputc( ( ( ~c ) >> 16 ) & 255, fp ); 246 | fputc( ( ( ~c ) >> 8 ) & 255, fp ); 247 | fputc( ( ~c ) & 255, fp ); 248 | } 249 | 250 | { 251 | { 252 | fputc( ( 0 ) >> 24, fp ); 253 | fputc( ( ( 0 ) >> 16 ) & 255, fp ); 254 | fputc( ( ( 0 ) >> 8 ) & 255, fp ); 255 | fputc( ( 0 ) & 255, fp ); 256 | } 257 | c = ~0U; 258 | 259 | for ( i = 0; i < 4; i++ ) 260 | { 261 | fputc( ( "IEND" )[i], fp ); 262 | c ^= ( ( "IEND" 
)[i] ); 263 | c = ( c >> 4 ) ^ t[c & 15]; 264 | c = ( c >> 4 ) ^ t[c & 15]; 265 | } 266 | } 267 | 268 | { 269 | fputc( ( ~c ) >> 24, fp ); 270 | fputc( ( ( ~c ) >> 16 ) & 255, fp ); 271 | fputc( ( ( ~c ) >> 8 ) & 255, fp ); 272 | fputc( ( ~c ) & 255, fp ); 273 | } 274 | 275 | fclose( fp ); 276 | } 277 | 278 | 279 | inline static ncnn::Mat decoder_solver( ncnn::Mat& sample, string assets_dir ) 280 | { 281 | ncnn::Net net; 282 | { 283 | net.opt.use_vulkan_compute = false; 284 | net.opt.use_winograd_convolution = false; 285 | net.opt.use_sgemm_convolution = false; 286 | net.opt.use_fp16_packed = true; 287 | net.opt.use_fp16_storage = true; 288 | net.opt.use_fp16_arithmetic = true; 289 | net.opt.use_packing_layout = true; 290 | std::filesystem::path param_path = std::filesystem::path(assets_dir) / std::filesystem::path("AutoencoderKL-fp16.param"); 291 | std::filesystem::path model_path = std::filesystem::path(assets_dir) / std::filesystem::path("AutoencoderKL-fp16.bin"); 292 | net.load_param(param_path.string().c_str()); 293 | net.load_model(model_path.string().c_str()); 294 | } 295 | ncnn::Mat x_samples_ddim; 296 | { 297 | constexpr float factor[4] = { 5.48998f, 5.48998f, 5.48998f, 5.48998f }; 298 | sample.substract_mean_normalize( 0, factor ); 299 | ncnn::Extractor ex = net.create_extractor(); 300 | ex.set_light_mode( true ); 301 | ex.input( "input.1", sample ); 302 | ex.extract( "815", x_samples_ddim ); 303 | constexpr float _mean_[3] = { -1.0f, -1.0f, -1.0f }; 304 | constexpr float _norm_[3] = { 127.5f, 127.5f, 127.5f }; 305 | x_samples_ddim.substract_mean_normalize( _mean_, _norm_ ); 306 | } 307 | return x_samples_ddim; 308 | } 309 | 310 | static inline ncnn::Mat randn_4_64_64( int seed ) 311 | { 312 | std::vector arr; 313 | { 314 | std::mt19937 gen{ static_cast( seed ) }; 315 | std::normal_distribution d{0.0f, 1.0f}; 316 | arr.resize( 64 * 64 * 4 ); 317 | std::for_each( arr.begin(), arr.end(), [&]( float & x ) 318 | { 319 | x = d( gen ); 320 | } ); 321 | } 322 | 
ncnn::Mat x_mat( 64, 64, 4, reinterpret_cast( arr.data() ) ); 323 | return x_mat.clone(); 324 | } 325 | 326 | static inline ncnn::Mat CFGDenoiser_CompVisDenoiser( ncnn::Net& net, float const* log_sigmas, ncnn::Mat& input, float sigma, ncnn::Mat cond, ncnn::Mat uncond ) 327 | { 328 | // get_scalings 329 | float c_out = -1.0 * sigma; 330 | float c_in = 1.0 / std::sqrt( sigma * sigma + 1 ); 331 | // sigma_to_t 332 | float log_sigma = std::log( sigma ); 333 | std::vector dists( 1000 ); 334 | 335 | for ( int i = 0; i < 1000; i++ ) 336 | { 337 | if ( log_sigma - log_sigmas[i] >= 0 ) 338 | dists[i] = 1; 339 | 340 | else 341 | dists[i] = 0; 342 | 343 | if ( i == 0 ) continue; 344 | 345 | dists[i] += dists[i - 1]; 346 | } 347 | 348 | int low_idx = std::min( int( std::max_element( dists.begin(), dists.end() ) - dists.begin() ), 1000 - 2 ); 349 | int high_idx = low_idx + 1; 350 | float low = log_sigmas[low_idx]; 351 | float high = log_sigmas[high_idx]; 352 | float w = ( low - log_sigma ) / ( low - high ); 353 | w = std::max( 0.f, std::min( 1.f, w ) ); 354 | float t = ( 1 - w ) * low_idx + w * high_idx; 355 | ncnn::Mat t_mat( 1 ); 356 | t_mat[0] = t; 357 | ncnn::Mat c_in_mat( 1 ); 358 | c_in_mat[0] = c_in; 359 | ncnn::Mat c_out_mat( 1 ); 360 | c_out_mat[0] = c_out; 361 | ncnn::Mat denoised_cond; 362 | { 363 | ncnn::Extractor ex = net.create_extractor(); 364 | ex.set_light_mode( true ); 365 | ex.input( "in0", input ); 366 | ex.input( "in1", t_mat ); 367 | ex.input( "in2", cond ); 368 | ex.input( "c_in", c_in_mat ); 369 | ex.input( "c_out", c_out_mat ); 370 | ex.extract( "outout", denoised_cond ); 371 | } 372 | ncnn::Mat denoised_uncond; 373 | { 374 | ncnn::Extractor ex = net.create_extractor(); 375 | ex.set_light_mode( true ); 376 | ex.input( "in0", input ); 377 | ex.input( "in1", t_mat ); 378 | ex.input( "in2", uncond ); 379 | ex.input( "c_in", c_in_mat ); 380 | ex.input( "c_out", c_out_mat ); 381 | ex.extract( "outout", denoised_uncond ); 382 | } 383 | 384 | for ( int c = 0; 
c < 4; c++ ) 385 | { 386 | float* u_ptr = denoised_uncond.channel( c ); 387 | float* c_ptr = denoised_cond.channel( c ); 388 | 389 | for ( int hw = 0; hw < 64 * 64; hw++ ) 390 | { 391 | ( *u_ptr ) = ( *u_ptr ) + 7 * ( ( *c_ptr ) - ( *u_ptr ) ); 392 | u_ptr++; 393 | c_ptr++; 394 | } 395 | } 396 | 397 | return denoised_uncond; 398 | } 399 | 400 | inline static ncnn::Mat diffusion_solver( int seed, int step, ncnn::Mat& c, ncnn::Mat& uc , string assets_dir) 401 | { 402 | ncnn::Net net; 403 | { 404 | net.opt.use_vulkan_compute = false; 405 | net.opt.use_winograd_convolution = false; 406 | net.opt.use_sgemm_convolution = false; 407 | net.opt.use_fp16_packed = true; 408 | net.opt.use_fp16_storage = true; 409 | net.opt.use_fp16_arithmetic = true; 410 | net.opt.use_packing_layout = true; 411 | std::filesystem::path param_path = std::filesystem::path(assets_dir) / std::filesystem::path("UNetModel-fp16.param"); 412 | std::filesystem::path model_path = std::filesystem::path(assets_dir) / std::filesystem::path("UNetModel-fp16.bin"); 413 | net.load_param(param_path.string().c_str()); 414 | net.load_model(model_path.string().c_str()); 415 | } 416 | float const log_sigmas[1000] = { -3.534698963f, -3.186542273f, -2.982215166f, -2.836785793f, -2.723614454f, -2.63086009f, -2.552189827f, -2.483832836f, -2.423344612f, -2.369071007f, -2.319822073f, -2.274721861f, -2.233105659f, -2.1944592f, -2.15836978f, -2.124504805f, -2.092598915f, -2.062425613f, -2.033797979f, -2.006558657f, -1.980568767f, -1.955715537f, -1.931894541f, -1.90902102f, -1.887015939f, -1.865811229f, -1.845347762f, -1.825569034f, -1.806429505f, -1.787884474f, -1.769894958f, -1.752426744f, -1.735446692f, -1.718925714f, -1.702836871f, -1.687156916f, -1.671862602f, -1.656933904f, -1.642351151f, -1.628097653f, -1.614156127f, -1.60051167f, -1.587151766f, -1.574060798f, -1.561229229f, -1.548643827f, -1.536295056f, -1.524172544f, -1.512266994f, -1.500569701f, -1.489071608f, -1.477766395f, -1.466645837f, -1.455702543f, 
-1.444930911f, -1.43432498f, -1.423877597f, -1.413584232f, -1.403439164f, -1.393437862f, -1.383575559f, -1.373847008f, -1.364248514f, -1.35477531f, -1.345424652f, -1.336191654f, -1.327073216f, -1.318066359f, -1.309167266f, -1.300373077f, -1.291680217f, -1.283086777f, -1.2745893f, -1.266185522f, -1.257872462f, -1.249648571f, -1.24151063f, -1.233456612f, -1.225485086f, -1.217592716f, -1.209778547f, -1.202040195f, -1.194375992f, -1.186783791f, -1.179262877f, -1.171809912f, -1.164424658f, -1.157105207f, -1.14985013f, -1.142657518f, -1.135526419f, -1.128455162f, -1.121442795f, -1.114487886f, -1.107589245f, -1.100745678f, -1.093955874f, -1.087218761f, -1.080533504f, -1.073898554f, -1.067313433f, -1.060776949f, -1.05428803f, -1.04784584f, -1.041449785f, -1.035098076f, -1.028790832f, -1.022526741f, -1.01630497f, -1.010124922f, -1.003985763f, -0.9978865385f, -0.9918267727f, -0.9858058691f, -0.9798227549f, -0.9738773108f, -0.9679679871f, -0.9620951414f, -0.9562574625f, -0.9504545927f, -0.9446860552f, -0.9389512539f, -0.9332492948f, -0.9275799394f, -0.9219425917f, -0.9163367748f, -0.910761714f, -0.9052170515f, -0.8997026086f, -0.8942174911f, -0.8887614012f, -0.8833341002f, -0.8779345155f, -0.8725628257f, -0.8672183752f, -0.8619008064f, -0.8566094041f, -0.851344347f, -0.8461046219f, -0.8408905864f, -0.8357009888f, -0.8305362463f, -0.8253954053f, -0.8202784657f, -0.8151849508f, -0.8101147413f, -0.8050670028f, -0.800041914f, -0.7950390577f, -0.7900577784f, -0.7850983739f, -0.7801600099f, -0.7752425075f, -0.7703458071f, -0.7654693723f, -0.7606129646f, -0.7557764053f, -0.7509595752f, -0.7461619377f, -0.7413833141f, -0.7366235256f, -0.7318821549f, -0.7271592021f, -0.7224541306f, -0.7177669406f, -0.7130974531f, -0.7084451318f, -0.7038100958f, -0.6991918087f, -0.6945903301f, -0.6900054812f, -0.6854367256f, -0.6808840632f, -0.6763471961f, -0.6718261838f, -0.6673204899f, -0.6628302932f, -0.658354938f, -0.6538946629f, -0.64944911f, -0.6450180411f, -0.6406013966f, -0.6361990571f, 
-0.6318103671f, -0.6274358034f, -0.6230751276f, -0.618727684f, -0.6143938303f, -0.61007303f, -0.6057654023f, -0.6014707088f, -0.597188592f, -0.5929191113f, -0.5886622667f, -0.5844176412f, -0.5801851153f, -0.5759648085f, -0.5717563629f, -0.5675594807f, -0.5633742213f, -0.5592005849f, -0.5550382137f, -0.5508873463f, -0.546747148f, -0.5426182747f, -0.5385001898f, -0.5343927145f, -0.5302959085f, -0.5262096524f, -0.5221338272f, -0.5180680752f, -0.5140126348f, -0.5099670291f, -0.5059314966f, -0.5019059181f, -0.4978898764f, -0.4938833714f, -0.4898864031f, -0.4858988523f, -0.4819207191f, -0.477951467f, -0.4739915431f, -0.4700405598f, -0.4660984278f, -0.4621651769f, -0.4582404494f, -0.4543244243f, -0.4504169226f, -0.4465178251f, -0.4426270425f, -0.438744545f, -0.4348701835f, -0.4310038984f, -0.4271455407f, -0.4232949913f, -0.4194523394f, -0.415617466f, -0.4117901921f, -0.4079704583f, -0.4041582644f, -0.4003533721f, -0.3965558708f, -0.3927654326f, -0.3889823556f, -0.3852062523f, -0.3814373016f, -0.3776751757f, -0.3739199638f, -0.3701713085f, -0.3664295971f, -0.3626944721f, -0.3589659631f, -0.3552440405f, -0.3515283465f, -0.3478190601f, -0.3441161811f, -0.3404195011f, -0.3367289305f, -0.3330446184f, -0.3293660879f, -0.3256936371f, -0.3220270574f, -0.3183663487f, -0.3147114515f, -0.3110622168f, -0.3074187338f, -0.3037807941f, -0.3001481593f, -0.296521306f, -0.2928997874f, -0.2892835438f, -0.2856727242f, -0.2820670605f, -0.2784664929f, -0.2748712003f, -0.271281004f, -0.2676956952f, -0.2641154826f, -0.2605400383f, -0.2569694519f, -0.2534037232f, -0.2498426586f, -0.2462864369f, -0.24273476f, -0.2391877174f, -0.2356451303f, -0.2321071476f, -0.2285735607f, -0.2250443399f, -0.2215195149f, -0.2179989815f, -0.214482531f, -0.210970372f, -0.2074623704f, -0.2039585114f, -0.2004585862f, -0.1969628185f, -0.1934709698f, -0.1899829209f, -0.1864988059f, -0.1830185503f, -0.1795420945f, -0.1760693043f, -0.172600165f, -0.1691347659f, -0.1656728834f, -0.1622146815f, -0.1587598771f, 
-0.1553086042f, -0.1518607438f, -0.1484163553f, -0.1449751854f, -0.1415374279f, -0.138102904f, -0.1346716136f, -0.1312434822f, -0.1278186142f, -0.1243967116f, -0.1209778786f, -0.1175621748f, -0.1141493395f, -0.1107395291f, -0.1073326841f, -0.1039286032f, -0.100527443f, -0.09712906927f, -0.09373350441f, -0.09034062177f, -0.08695036173f, -0.08356288075f, -0.08017785847f, -0.0767955035f, -0.07341576368f, -0.07003845274f, -0.06666365266f, -0.06329131871f, -0.05992120504f, -0.05655376986f, -0.05318845809f, -0.04982547462f, -0.04646483436f, -0.04310630262f, -0.03975001723f, -0.03639599681f, -0.03304407373f, -0.02969419584f, -0.02634644695f, -0.02300059609f, -0.01965690777f, -0.0163150914f, -0.01297534816f, -0.009637393057f, -0.006301366724f, -0.002967105946f, 0.0003651905863f, 0.00369580253f, 0.007024710067f, 0.01035177521f, 0.0136772152f, 0.01700089127f, 0.02032313682f, 0.02364357933f, 0.02696254663f, 0.03028002009f, 0.03359586f, 0.03691027686f, 0.04022336379f, 0.04353487119f, 0.04684500396f, 0.05015373603f, 0.05346116424f, 0.05676726624f, 0.060072124f, 0.06337571889f, 0.06667824835f, 0.06997924298f, 0.073279351f, 0.07657821476f, 0.07987590879f, 0.08317264169f, 0.08646827191f, 0.08976276964f, 0.0930563435f, 0.09634894878f, 0.09964046627f, 0.1029312909f, 0.1062210724f, 0.109510012f, 0.1127980649f, 0.1160853282f, 0.1193717569f, 0.1226574481f, 0.1259424686f, 0.1292266697f, 0.1325102597f, 0.1357929856f, 0.1390753537f, 0.1423569024f, 0.1456380188f, 0.1489184797f, 0.1521983445f, 0.155477792f, 0.1587566882f, 0.1620351225f, 0.1653131396f, 0.1685907096f, 0.1718678325f, 0.1751447469f, 0.1784212291f, 0.1816974431f, 0.1849732697f, 0.1882487684f, 0.1915241033f, 0.1947992444f, 0.198074162f, 0.2013489157f, 0.2046233714f, 0.2078978866f, 0.211172238f, 0.214446485f, 0.2177205831f, 0.2209947109f, 0.2242688239f, 0.2275429815f, 0.230817154f, 0.2340912819f, 0.2373655587f, 0.2406399995f, 0.2439144254f, 0.247189045f, 0.2504638135f, 0.2537388504f, 0.2570140362f, 0.2602894902f, 0.2635650337f, 
0.266841054f, 0.2701171935f, 0.27339378f, 0.2766706944f, 0.2799479663f, 0.2832255363f, 0.2865035534f, 0.2897821367f, 0.2930608988f, 0.2963403165f, 0.2996201515f, 0.3029005229f, 0.306181401f, 0.3094629347f, 0.3127449751f, 0.3160274923f, 0.3193107247f, 0.322594583f, 0.3258791566f, 0.329164356f, 0.3324502707f, 0.335736841f, 0.3390242159f, 0.3423123956f, 0.3456012905f, 0.3488909006f, 0.352181375f, 0.3554728627f, 0.3587650359f, 0.3620581031f, 0.365352124f, 0.3686470091f, 0.371942848f, 0.3752396405f, 0.3785375357f, 0.3818363845f, 0.3851362169f, 0.3884370327f, 0.3917389512f, 0.3950420022f, 0.3983460069f, 0.4016513228f, 0.4049576223f, 0.408265233f, 0.4115738571f, 0.4148837626f, 0.4181949198f, 0.4215073586f, 0.4248209298f, 0.4281358421f, 0.4314520359f, 0.434769541f, 0.4380882978f, 0.441408515f, 0.444730103f, 0.448053062f, 0.4513774216f, 0.4547032118f, 0.4580304027f, 0.461359024f, 0.4646892846f, 0.4680209458f, 0.4713541865f, 0.4746888876f, 0.4780252576f, 0.4813631475f, 0.4847026467f, 0.4880437851f, 0.4913864136f, 0.4947308302f, 0.4980769157f, 0.5014246702f, 0.5047741532f, 0.5081253052f, 0.5114781857f, 0.5148329139f, 0.5181894302f, 0.5215476751f, 0.5249077678f, 0.5282697678f, 0.5316335559f, 0.5349991322f, 0.5383667946f, 0.5417361856f, 0.5451076627f, 0.5484809279f, 0.5518562794f, 0.5552335382f, 0.5586128235f, 0.5619941354f, 0.5653774738f, 0.5687628388f, 0.5721503496f, 0.5755399466f, 0.5789316297f, 0.5823253393f, 0.585721314f, 0.5891193748f, 0.5925196409f, 0.5959220529f, 0.5993267298f, 0.6027336717f, 0.6061428785f, 0.6095542312f, 0.6129679084f, 0.6163839698f, 0.6198022962f, 0.6232229471f, 0.6266459823f, 0.6300714016f, 0.6334991455f, 0.6369293332f, 0.6403619647f, 0.6437969804f, 0.6472345591f, 0.6506744623f, 0.6541169882f, 0.6575619578f, 0.6610094905f, 0.6644595861f, 0.6679121852f, 0.6713674068f, 0.6748251915f, 0.6782855988f, 0.6817486286f, 0.6852144003f, 0.688682735f, 0.6921537519f, 0.6956274509f, 0.6991039515f, 0.7025832534f, 0.7060650587f, 0.7095498443f, 0.713037312f, 
0.7165275812f, 0.7200207114f, 0.7235167027f, 0.7270154953f, 0.730517149f, 0.7340217829f, 0.7375292182f, 0.7410396338f, 0.7445529699f, 0.7480692267f, 0.7515884042f, 0.7551106215f, 0.7586359382f, 0.7621641755f, 0.7656953931f, 0.7692299485f, 0.7727673054f, 0.7763077617f, 0.779851377f, 0.7833981514f, 0.786947906f, 0.7905010581f, 0.7940571904f, 0.7976165414f, 0.8011791706f, 0.8047449589f, 0.8083140254f, 0.8118864298f, 0.8154619932f, 0.8190407753f, 0.8226229548f, 0.8262084126f, 0.8297972679f, 0.833389461f, 0.836984992f, 0.8405839205f, 0.8441862464f, 0.8477919698f, 0.8514010906f, 0.8550137877f, 0.8586298227f, 0.8622494936f, 0.8658725023f, 0.8694992065f, 0.8731292486f, 0.8767629862f, 0.8804001808f, 0.8840410113f, 0.8876854181f, 0.8913334608f, 0.894985199f, 0.8986404538f, 0.9022994041f, 0.9059621096f, 0.9096283317f, 0.9132984877f, 0.9169722795f, 0.9206498265f, 0.924331069f, 0.9280161858f, 0.9317050576f, 0.9353976846f, 0.9390941858f, 0.9427945018f, 0.9464985728f, 0.9502066374f, 0.9539185762f, 0.957634449f, 0.9613542557f, 0.9650779963f, 0.9688056707f, 0.9725371599f, 0.9762728214f, 0.9800124764f, 0.983756125f, 0.9875037074f, 0.9912554026f, 0.9950110912f, 0.9987710118f, 1.002534866f, 1.006302953f, 1.010075092f, 1.013851404f, 1.017631888f, 1.021416545f, 1.025205374f, 1.028998375f, 1.032795668f, 1.036597133f, 1.040402889f, 1.044212818f, 1.048027158f, 1.051845789f, 1.055668592f, 1.059495926f, 1.063327432f, 1.067163467f, 1.071003675f, 1.074848413f, 1.078697562f, 1.082551122f, 1.086408973f, 1.090271473f, 1.094138384f, 1.098009825f, 1.101885676f, 1.105766058f, 1.109651089f, 1.113540411f, 1.117434502f, 1.121333122f, 1.125236511f, 1.129144192f, 1.13305676f, 1.136973858f, 1.140895605f, 1.144822001f, 1.148753166f, 1.15268898f, 1.156629443f, 1.160574675f, 1.164524794f, 1.168479443f, 1.172439098f, 1.176403403f, 1.180372596f, 1.184346557f, 1.188325405f, 1.192309022f, 1.196297526f, 1.200290918f, 1.204289317f, 1.208292484f, 1.212300658f, 1.216313839f, 1.220331907f, 1.224354982f, 1.228383064f, 
1.232415915f, 1.236454129f, 1.240497231f, 1.244545341f, 1.248598576f, 1.252656817f, 1.256720304f, 1.260788798f, 1.264862418f, 1.268941164f, 1.273025036f, 1.277114153f, 1.281208396f, 1.285307884f, 1.289412618f, 1.293522477f, 1.297637701f, 1.301758051f, 1.305883765f, 1.310014725f, 1.314151049f, 1.318292618f, 1.322439551f, 1.326591969f, 1.330749512f, 1.334912539f, 1.33908093f, 1.343254805f, 1.347433925f, 1.351618767f, 1.355808854f, 1.360004425f, 1.36420548f, 1.368412018f, 1.372624159f, 1.376841784f, 1.381064892f, 1.385293603f, 1.389527798f, 1.393767595f, 1.398013115f, 1.402264118f, 1.406520724f, 1.410783052f, 1.415050983f, 1.419324636f, 1.423603892f, 1.427888989f, 1.43217957f, 1.436476111f, 1.440778255f, 1.445086241f, 1.449399829f, 1.453719258f, 1.458044529f, 1.462375641f, 1.466712594f, 1.471055388f, 1.475403905f, 1.479758382f, 1.484118819f, 1.488484859f, 1.492857099f, 1.497235179f, 1.50161922f, 1.506009102f, 1.510405064f, 1.514806986f, 1.519214869f, 1.523628831f, 1.528048754f, 1.532474637f, 1.536906719f, 1.541344643f, 1.545788884f, 1.550239086f, 1.554695368f, 1.559157968f, 1.563626409f, 1.568101287f, 1.572582126f, 1.577069283f, 1.581562519f, 1.586061954f, 1.590567589f, 1.595079541f, 1.599597573f, 1.604121923f, 1.608652592f, 1.613189697f, 1.617732882f, 1.622282386f, 1.626838207f, 1.631400466f, 1.635969043f, 1.640543938f, 1.645125389f, 1.649713039f, 1.654307127f, 1.658907652f, 1.663514495f, 1.668127894f, 1.67274785f, 1.677374125f, 1.682006836f, 1.686646223f, 1.691291928f, 1.695944309f, 1.700603247f, 1.705268621f, 1.709940553f, 1.71461916f, 1.719304323f, 1.723996043f, 1.728694439f, 1.733399391f, 1.738111019f, 1.742829323f, 1.747554302f, 1.752285957f, 1.757024288f, 1.761769295f, 1.766520977f, 1.771279573f, 1.776044846f, 1.780816793f, 1.785595655f, 1.790381074f, 1.795173526f, 1.799972653f, 1.804778576f, 1.809591532f, 1.814411163f, 1.819237709f, 1.82407105f, 1.828911304f, 1.833758473f, 1.838612676f, 1.843473673f, 1.848341703f, 1.853216529f, 1.858098507f, 1.86298728f, 
1.867883086f, 1.872785926f, 1.877695799f, 1.882612705f, 1.887536645f, 1.892467618f, 1.897405624f, 1.902350664f, 1.907302856f, 1.912262201f, 1.91722858f, 1.92220211f, 1.927182794f, 1.93217051f, 1.937165499f, 1.94216764f, 1.947176933f, 1.952193499f, 1.957217097f, 1.962248087f, 1.967286348f, 1.972331762f, 1.977384448f, 1.982444406f, 1.987511516f, 1.992586017f, 1.997667909f, 2.002757072f, 2.007853508f, 2.012957335f, 2.018068552f, 2.023186922f, 2.028312922f, 2.033446312f, 2.038586855f, 2.043735027f, 2.048890591f, 2.054053545f, 2.05922389f, 2.064401865f, 2.069587231f, 2.074779987f, 2.079980135f, 2.08518815f, 2.090403318f, 2.095626354f, 2.100856543f, 2.106094599f, 2.111340046f, 2.116593361f, 2.121853828f, 2.127122164f, 2.132398129f, 2.137681484f, 2.142972708f, 2.148271322f, 2.153577805f, 2.158891916f, 2.164213657f, 2.169543266f, 2.174880266f, 2.180225134f, 2.185577631f, 2.190937996f, 2.19630599f, 2.201681852f, 2.207065582f, 2.212456942f, 2.21785593f, 2.223263025f, 2.22867775f, 2.234100103f, 2.239530563f, 2.244968891f, 2.250414848f, 2.255868912f, 2.261330843f, 2.266800642f, 2.27227807f, 2.277763605f, 2.283257008f, 2.288758516f, 2.294267654f, 2.299785137f, 2.305310249f, 2.310843468f, 2.316384792f, 2.321933746f, 2.327491045f, 2.333056211f, 2.338629484f, 2.344210625f, 2.34980011f, 2.355397224f, 2.361002684f, 2.366616249f, 2.372237921f, 2.37786746f, 2.383505344f, 2.389151335f, 2.394805431f, 2.400467634f, 2.406137943f, 2.411816359f, 2.41750288f, 2.423197985f, 2.428900957f, 2.434612274f, 2.440331697f, 2.446059465f, 2.45179534f, 2.457539558f, 2.463291883f, 2.469052553f, 2.474821568f, 2.480598688f, 2.486384153f, 2.492177963f, 2.497980118f, 2.503790617f, 2.509609461f, 2.515436649f, 2.521272182f, 2.527115822f, 2.532968283f, 2.53882885f, 2.544697762f, 2.550575256f, 2.556461096f, 2.56235528f, 2.568258047f, 2.574169159f, 2.580088615f, 2.586016655f, 2.591953278f, 2.597898245f, 2.603851557f, 2.60981369f, 2.615784168f, 2.621763229f, 2.627750635f, 2.633746862f, 2.639751434f, 2.645764589f, 
2.651786327f, 2.657816648f, 2.663855553f, 2.66990304f, 2.67595911f, 2.682024002f }; 417 | 418 | ncnn::Mat x_mat = randn_4_64_64( seed % 1000 ); 419 | // t_to_sigma 420 | std::vector sigma( step ); 421 | float delta = - 999.0f / ( step - 1 ); 422 | 423 | for ( int i = 0; i < step; i++ ) 424 | { 425 | float t = 999.0 + i * delta; 426 | int low_idx = std::floor( t ); 427 | int high_idx = std::ceil( t ); 428 | float w = t - low_idx; 429 | sigma[i] = std::exp( ( 1 - w ) * log_sigmas[low_idx] + w * log_sigmas[high_idx] ); 430 | } 431 | 432 | sigma.push_back( 0.f ); 433 | float _norm_[4] = { sigma[0], sigma[0], sigma[0], sigma[0] }; 434 | x_mat.substract_mean_normalize( 0, _norm_ ); 435 | // sample_euler_ancestral 436 | { 437 | for ( int i = 0; i < static_cast(sigma.size()) - 1; i++ ) 438 | { 439 | std::cout << "step:" << i << "\t\t"; 440 | double t1 = ncnn::get_current_time(); 441 | ncnn::Mat denoised = CFGDenoiser_CompVisDenoiser( net, log_sigmas, x_mat, sigma[i], c, uc ); 442 | double t2 = ncnn::get_current_time(); 443 | std::cout << t2 - t1 << "ms" << std::endl; 444 | float sigma_up = std::min( sigma[i + 1], std::sqrt( sigma[i + 1] * sigma[i + 1] * ( sigma[i] * sigma[i] - sigma[i + 1] * sigma[i + 1] ) / ( sigma[i] * sigma[i] ) ) ); 445 | float sigma_down = std::sqrt( sigma[i + 1] * sigma[i + 1] - sigma_up * sigma_up ); 446 | std::srand( std::time( NULL ) ); 447 | ncnn::Mat randn = randn_4_64_64( rand() % 1000 ); 448 | 449 | for ( int c = 0; c < 4; c++ ) 450 | { 451 | float* x_ptr = x_mat.channel( c ); 452 | float* d_ptr = denoised.channel( c ); 453 | float* r_ptr = randn.channel( c ); 454 | 455 | for ( int hw = 0; hw < 64 * 64; hw++ ) 456 | { 457 | *x_ptr = *x_ptr + ( ( *x_ptr - *d_ptr ) / sigma[i] ) * ( sigma_down - sigma[i] ) + *r_ptr * sigma_up; 458 | x_ptr++; 459 | d_ptr++; 460 | r_ptr++; 461 | } 462 | } 463 | } 464 | } 465 | ncnn::Mat fuck_x; 466 | fuck_x.clone_from( x_mat ); 467 | return fuck_x; 468 | } 469 | 470 | // esr4x: loads RealESRGAN_x4plus_anime_6B (fp32 .param/.bin) from assets_dir on every call and runs it on `input` (blob "data" -> "output"). `input` is scaled by 1/255 IN PLACE before inference and the result is scaled by 255. load_param/load_model results are not checked. 471 | inline static ncnn::Mat esr4x( ncnn::Mat& 
input, string assets_dir ) 472 | { 473 | ncnn::Net net; 474 | { 475 | net.opt.use_vulkan_compute = false; 476 | net.opt.use_winograd_convolution = false; 477 | net.opt.use_sgemm_convolution = false; 478 | net.opt.use_fp16_packed = false; 479 | net.opt.use_fp16_storage = false; 480 | net.opt.use_fp16_arithmetic = false; 481 | net.opt.use_packing_layout = true; 482 | std::filesystem::path param_path = std::filesystem::path(assets_dir) / std::filesystem::path("RealESRGAN_x4plus_anime_6B.fp32-sim-sim-opt.param"); 483 | std::filesystem::path model_path = std::filesystem::path(assets_dir) / std::filesystem::path("RealESRGAN_x4plus_anime_6B.fp32-sim-sim-opt.bin"); 484 | net.load_param(param_path.string().c_str()); 485 | net.load_model(model_path.string().c_str()); 486 | } 487 | ncnn::Extractor ex = net.create_extractor(); 488 | ex.set_light_mode( true ); 489 | { 490 | constexpr float mean[] = {0.0f, 0.0f, 0.0f}; 491 | constexpr float norm[] = {1.0f/255.0f, 1.0f/255.0f, 1.0f/255.0f}; 492 | input.substract_mean_normalize( mean, norm ); 493 | } 494 | ex.input( "data", input ); 495 | ncnn::Mat ans; 496 | ex.extract( "output", ans ); 497 | { 498 | constexpr float mean[] = {0.0f, 0.0f, 0.0f}; 499 | constexpr float norm[] = {255.0f, 255.0f, 255.0f}; 500 | ans.substract_mean_normalize( mean, norm ); 501 | } 502 | return ans; 503 | } 504 | // parse_prompt_attention: splits `texts` into (text, weight) runs starting at weight 1.0. Each enclosing "(" ... ")" multiplies a run's weight by 1.1 and each "[" ... "]" by 1/1.1. Brackets left open apply to every following run, and an unmatched closer is kept as literal text. Adjacent runs with equal weight are merged. 505 | inline static std::vector> parse_prompt_attention( std::string& texts ) 506 | { 507 | std::vector> res; 508 | std::stack round_brackets; 509 | std::stack square_brackets; 510 | const float round_bracket_multiplier = 1.1; 511 | const float square_bracket_multiplier = 1 / 1.1; 512 | std::vector ms; 513 | 514 | for ( char c : texts ) 515 | { 516 | std::string s = std::string( 1, c ); 517 | 518 | if ( s == "(" || s == "[" || s == ")" || s == "]" ) 519 | { 520 | ms.push_back( s ); 521 | } 522 | 523 | else 524 | { 525 | if ( ms.size() < 1 ) 526 | ms.push_back( "" ); 527 | 528 | std::string last = ms[ms.size() - 1]; 529 | 530 | if ( last == "(" 
|| last == "[" || last == ")" || last == "]" ) 531 | { 532 | ms.push_back( "" ); 533 | } 534 | 535 | ms[ms.size() - 1] += s; 536 | } 537 | } 538 | 539 | for ( std::string text : ms ) 540 | { 541 | if ( text == "(" ) 542 | { 543 | round_brackets.push( res.size() ); 544 | } 545 | 546 | else if ( text == "[" ) 547 | { 548 | square_brackets.push( res.size() ); 549 | } 550 | 551 | else if ( text == ")" && round_brackets.size() > 0 ) 552 | { 553 | for ( unsigned long p = round_brackets.top(); p < res.size(); p++ ) 554 | { 555 | res[p].second *= round_bracket_multiplier; 556 | } 557 | 558 | round_brackets.pop(); 559 | } 560 | 561 | else if ( text == "]" && square_brackets.size() > 0 ) 562 | { 563 | for ( unsigned long p = square_brackets.top(); p < res.size(); p++ ) 564 | { 565 | res[p].second *= square_bracket_multiplier; 566 | } 567 | 568 | square_brackets.pop(); 569 | } 570 | 571 | else 572 | { 573 | res.push_back( make_pair( text, 1.0 ) ); 574 | } 575 | } 576 | 577 | while ( !round_brackets.empty() ) 578 | { 579 | for ( unsigned long p = round_brackets.top(); p < res.size(); p++ ) 580 | { 581 | res[p].second *= round_bracket_multiplier; 582 | } 583 | 584 | round_brackets.pop(); 585 | } 586 | 587 | while ( !square_brackets.empty() ) 588 | { 589 | for ( unsigned long p = square_brackets.top(); p < res.size(); p++ ) 590 | { 591 | res[p].second *= square_bracket_multiplier; 592 | } 593 | 594 | square_brackets.pop(); 595 | } 596 | 597 | unsigned long i = 0; 598 | 599 | while ( i + 1 < res.size() ) 600 | { 601 | if ( res[i].second == res[i + 1].second ) 602 | { 603 | res[i].first += res[i + 1].first; 604 | auto it = res.begin(); 605 | res.erase( it + i + 1 ); 606 | } 607 | 608 | else 609 | { 610 | i += 1; 611 | } 612 | } 613 | 614 | return res; 615 | } 616 | 617 | // split: splits `str` at spaces and commas. Non-empty pieces are returned in order; spaces are dropped and every comma is emitted as its own "," element. 618 | inline static std::vector split( std::string str ) 619 | { 620 | std::string::size_type pos; 621 | std::vector result; 622 | str += " "; 623 | int size = str.size(); 624 | 625 | for ( int i = 0; i < size; 
i++ ) 626 | { 627 | pos = std::min( str.find( " ", i ), str.find( ",", i ) ); 628 | 629 | if ( pos < str.size() ) 630 | { 631 | std::string s = str.substr( i, pos - i ); 632 | std::string pat = std::string( 1, str[pos] ); 633 | 634 | if ( s.length() > 0 ) 635 | result.push_back( s + "" ); 636 | 637 | if ( pat != " " ) 638 | result.push_back( pat + "" ); 639 | 640 | i = pos; 641 | } 642 | } 643 | 644 | return result; 645 | } 646 | // prompt_solve: encodes `prompt` with the CLIP `net`. Tokens are looked up with tokenizer_token2idx[] (operator[], so an unknown token gets id 0 and is inserted into the map). The token list is regrouped so that a 75-token boundary landing within 20 tokens of the last comma (id 267, presumably "," in vocab.txt — confirm) moves the text after that comma into the next chunk, then padded with 49407 to a multiple of 75. Each 75-token chunk is run through `net`, whose "conds" output is fed back as the next chunk's "cond" input. 647 | inline static ncnn::Mat prompt_solve( std::unordered_map& tokenizer_token2idx, ncnn::Net& net, std::string prompt ) 648 | { 649 | 650 | // Attention weighting understands "()" and "[]": round brackets raise a run's weight, square brackets lower it 651 | std::vector> parsed = parse_prompt_attention( prompt ); 652 | // Convert tokens to ids 653 | std::vector> tokenized; 654 | { 655 | for ( auto p : parsed ) 656 | { 657 | std::vector tokens = split( p.first ); 658 | std::vector ids; 659 | 660 | for ( std::string token : tokens ) 661 | { 662 | ids.push_back( tokenizer_token2idx[token] ); 663 | } 664 | 665 | tokenized.push_back( ids ); 666 | } 667 | } 668 | 669 | // Comma-aware regrouping into 75-token chunks, then padding to a multiple of 75 670 | std::vector remade_tokens; 671 | std::vector multipliers; 672 | { 673 | int last_comma = -1; 674 | 675 | for ( unsigned long it_tokenized = 0; it_tokenized < tokenized.size(); it_tokenized++ ) 676 | { 677 | std::vector tokens = tokenized[it_tokenized]; 678 | float weight = parsed[it_tokenized].second; 679 | unsigned long i = 0; 680 | 681 | while ( i < tokens.size() ) 682 | { 683 | int token = tokens[i]; 684 | 685 | if ( token == 267 ) 686 | { 687 | last_comma = remade_tokens.size(); 688 | } 689 | 690 | else if ( ( std::max( int( remade_tokens.size() ), 1 ) % 75 == 0 ) && ( last_comma != -1 ) && ( remade_tokens.size() - last_comma <= 20 ) ) 691 | { 692 | last_comma += 1; 693 | std::vector reloc_tokens( remade_tokens.begin() + last_comma, remade_tokens.end() ); 694 | std::vector reloc_mults( multipliers.begin() + last_comma, multipliers.end() ); 695 | std::vector _remade_tokens_( remade_tokens.begin(), remade_tokens.begin() + last_comma ); 696 | 
remade_tokens = _remade_tokens_; 697 | int length = remade_tokens.size(); 698 | int rem = std::ceil( length / 75.0 ) * 75 - length; 699 | std::vector tmp_token( rem, 49407 ); 700 | remade_tokens.insert( remade_tokens.end(), tmp_token.begin(), tmp_token.end() ); 701 | remade_tokens.insert( remade_tokens.end(), reloc_tokens.begin(), reloc_tokens.end() ); 702 | std::vector _multipliers_( multipliers.begin(), multipliers.begin() + last_comma ); // multipliers[0, last_comma): must stay aligned with the truncated remade_tokens above (end() + last_comma would be an out-of-range iterator) 703 | std::vector tmp_multipliers( rem, 1.0f ); 704 | _multipliers_.insert( _multipliers_.end(), tmp_multipliers.begin(), tmp_multipliers.end() ); 705 | _multipliers_.insert( _multipliers_.end(), reloc_mults.begin(), reloc_mults.end() ); 706 | multipliers = _multipliers_; 707 | } 708 | 709 | remade_tokens.push_back( token ); 710 | multipliers.push_back( weight ); 711 | i += 1; 712 | } 713 | } 714 | 715 | int prompt_target_length = std::ceil( std::max( int( remade_tokens.size() ), 1 ) / 75.0 ) * 75; 716 | int tokens_to_add = prompt_target_length - remade_tokens.size(); 717 | std::vector tmp_token( tokens_to_add, 49407 ); 718 | remade_tokens.insert( remade_tokens.end(), tmp_token.begin(), tmp_token.end() ); 719 | std::vector tmp_multipliers( tokens_to_add, 1.0f ); 720 | multipliers.insert( multipliers.end(), tmp_multipliers.begin(), tmp_multipliers.end() ); 721 | } 722 | // Split into 75-token chunks and encode each one, chaining the "conds" output into the next "cond" input 723 | ncnn::Mat conds( 768, 0 ); 724 | { 725 | while ( remade_tokens.size() > 0 ) 726 | { 727 | std::vector rem_tokens( remade_tokens.begin() + 75, remade_tokens.end() ); 728 | std::vector rem_multipliers( multipliers.begin() + 75, multipliers.end() ); 729 | std::vector current_tokens; 730 | std::vector current_multipliers; 731 | 732 | if ( remade_tokens.size() > 0 ) 733 | { 734 | current_tokens.insert( current_tokens.end(), remade_tokens.begin(), remade_tokens.begin() + 75 ); 735 | current_multipliers.insert( current_multipliers.end(), multipliers.begin(), multipliers.begin() + 75 ); 736 | } 737 | 738 | else 739 | { 740 | std::vector tmp_token( 75, 49407 ); 
741 | current_tokens.insert( current_tokens.end(), tmp_token.begin(), tmp_token.end() ); 742 | std::vector tmp_multipliers( 75, 1.0f ); 743 | current_multipliers.insert( current_multipliers.end(), tmp_multipliers.begin(), tmp_multipliers.end() ); 744 | } 745 | 746 | { 747 | ncnn::Mat token_mat = ncnn::Mat( 77 ); 748 | token_mat.fill( int( 49406 ) ); // slots 1..75 are overwritten below, so slots 0 and 76 keep 49406 — NOTE(review): confirm the exported CLIP model expects 49406 (not the 49407 padding id) in slot 76 749 | ncnn::Mat multiplier_mat = ncnn::Mat( 77 ); 750 | multiplier_mat.fill( 1.0f ); 751 | int* token_ptr = token_mat; 752 | float* multiplier_ptr = multiplier_mat; 753 | 754 | for ( int i = 0; i < 75; i++ ) 755 | { 756 | token_ptr[i + 1] = int( current_tokens[i] ); 757 | multiplier_ptr[i + 1] = current_multipliers[i]; 758 | } 759 | 760 | ncnn::Extractor ex = net.create_extractor(); 761 | ex.set_light_mode( true ); 762 | ex.input( "token", token_mat ); 763 | ex.input( "multiplier", multiplier_mat ); 764 | ex.input( "cond", conds ); 765 | ncnn::Mat new_conds; 766 | ex.extract( "conds", new_conds ); 767 | conds = new_conds; 768 | } 769 | 770 | remade_tokens = rem_tokens; 771 | multipliers = rem_multipliers; 772 | } 773 | } 774 | return conds; 775 | } 776 | // prompt_solver: loads FrozenCLIPEmbedder-fp16.param/.bin and vocab.txt (one token per line, id = 0-based line number) from assets_dir, then returns (prompt_solve(positive), prompt_solve(negative)). NOTE(review): load_param/load_model results and infile.open are not checked, so a missing asset is not reported here. 777 | inline static std::pair prompt_solver( std::string const& prompt_positive, std::string const& prompt_negative, string assets_dir ) 778 | { 779 | std::unordered_map tokenizer_token2idx; 780 | ncnn::Net net; 781 | { 782 | // Load the CLIP text-encoder model 783 | net.opt.use_vulkan_compute = false; 784 | net.opt.use_winograd_convolution = false; 785 | net.opt.use_sgemm_convolution = false; 786 | net.opt.use_fp16_packed = true; 787 | net.opt.use_fp16_storage = true; 788 | net.opt.use_fp16_arithmetic = true; 789 | net.opt.use_packing_layout = true; 790 | 791 | std::filesystem::path param_path = std::filesystem::path(assets_dir) / std::filesystem::path("FrozenCLIPEmbedder-fp16.param"); 792 | std::filesystem::path model_path = std::filesystem::path(assets_dir) / std::filesystem::path("FrozenCLIPEmbedder-fp16.bin"); 793 | net.load_param(param_path.string().c_str()); 794 | 
net.load_model(model_path.string().c_str()); 795 | // Read the tokenizer vocabulary 796 | std::ifstream infile; 797 | std::filesystem::path vocab_path = std::filesystem::path(assets_dir) / std::filesystem::path("vocab.txt"); 798 | std::string pathname = vocab_path.string(); 799 | infile.open( pathname.data() ); 800 | std::string s; 801 | int idx = 0; 802 | 803 | while ( getline( infile, s ) ) 804 | { 805 | tokenizer_token2idx.insert( std::pair( s, idx ) ); 806 | idx++; 807 | } 808 | infile.close(); 809 | } 810 | 811 | return std::make_pair( prompt_solve( tokenizer_token2idx, net, prompt_positive ), prompt_solve( tokenizer_token2idx, net, prompt_negative ) ); 812 | } 813 | --------------------------------------------------------------------------------