├── CODEOWNERS ├── .gitignore ├── discord_bot ├── interface.go └── discord_bot.go ├── composite_renderer ├── interface.go └── renderer.go ├── entities ├── default_settings.go └── image_generation.go ├── stable_diffusion_api ├── interface.go └── stable_diffusion.go ├── clock └── clock.go ├── repositories ├── default_settings │ ├── interface.go │ └── sqlite.go ├── errors.go └── image_generations │ ├── interface.go │ └── sqlite.go ├── .github └── workflows │ ├── golangci-lint.yaml │ └── release.yml ├── imagine_queue ├── interface.go └── queue.go ├── LICENSE ├── go.mod ├── main.go ├── README.md ├── databases └── sqlite │ └── sqlite.go └── go.sum /CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @AndBobsYourUncle -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | stable_diffusion_bot 2 | .idea/ 3 | sd_discord_bot.sqlite 4 | -------------------------------------------------------------------------------- /discord_bot/interface.go: -------------------------------------------------------------------------------- 1 | package discord_bot 2 | 3 | type Bot interface { 4 | Start() 5 | } 6 | -------------------------------------------------------------------------------- /composite_renderer/interface.go: -------------------------------------------------------------------------------- 1 | package composite_renderer 2 | 3 | import "bytes" 4 | 5 | type Renderer interface { 6 | TileImages(imageBufs []*bytes.Buffer) (*bytes.Buffer, error) 7 | } 8 | -------------------------------------------------------------------------------- /entities/default_settings.go: -------------------------------------------------------------------------------- 1 | package entities 2 | 3 | type DefaultSettings struct { 4 | MemberID string `json:"member_id"` 5 | Width int `json:"width"` 6 | Height int `json:"height"` 7 | BatchCount
int `json:"batch_count"` 8 | BatchSize int `json:"batch_size"` 9 | } 10 | -------------------------------------------------------------------------------- /stable_diffusion_api/interface.go: -------------------------------------------------------------------------------- 1 | package stable_diffusion_api 2 | 3 | type StableDiffusionAPI interface { 4 | TextToImage(req *TextToImageRequest) (*TextToImageResponse, error) 5 | UpscaleImage(upscaleReq *UpscaleRequest) (*UpscaleResponse, error) 6 | GetCurrentProgress() (*ProgressResponse, error) 7 | } 8 | -------------------------------------------------------------------------------- /clock/clock.go: -------------------------------------------------------------------------------- 1 | package clock 2 | 3 | import "time" 4 | 5 | //go:generate mockgen -destination=mock/mock.go -package=mock_clock -source=clock.go 6 | 7 | type Clock interface { 8 | Now() time.Time 9 | } 10 | 11 | type realClock struct{} 12 | 13 | func (realClock) Now() time.Time { 14 | return time.Now() 15 | } 16 | 17 | func NewClock() Clock { 18 | return &realClock{} 19 | } 20 | -------------------------------------------------------------------------------- /repositories/default_settings/interface.go: -------------------------------------------------------------------------------- 1 | package default_settings 2 | 3 | import ( 4 | "context" 5 | "stable_diffusion_bot/entities" 6 | ) 7 | 8 | type Repository interface { 9 | Upsert(ctx context.Context, setting *entities.DefaultSettings) (*entities.DefaultSettings, error) 10 | GetByMemberID(ctx context.Context, memberID string) (*entities.DefaultSettings, error) 11 | } 12 | -------------------------------------------------------------------------------- /repositories/errors.go: -------------------------------------------------------------------------------- 1 | package repositories 2 | 3 | import "fmt" 4 | // NotFoundError reports that a requested entity could not be found. 5 | type NotFoundError struct { 6 | entityName string 7 | } 8 | // NewNotFoundError returns a NotFoundError whose message reads "<entityName> not found". 9 | func NewNotFoundError(entityName string)
*NotFoundError { 10 | return &NotFoundError{entityName: entityName} 11 | } 12 | // Error implements the error interface. 13 | func (e *NotFoundError) Error() string { 14 | return fmt.Sprintf("%s not found", e.entityName) 15 | } 16 | // Is reports whether the target is any *NotFoundError (entity name ignored), so errors.Is matches on type, not message. 17 | func (e *NotFoundError) Is(err error) bool { 18 | _, ok := err.(*NotFoundError) 19 | return ok 20 | } 21 | -------------------------------------------------------------------------------- /.github/workflows/golangci-lint.yaml: -------------------------------------------------------------------------------- 1 | name: golangci-lint 2 | on: 3 | pull_request: 4 | permissions: 5 | contents: read 6 | pull-requests: read 7 | jobs: 8 | golangci: 9 | name: lint 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/setup-go@v3 13 | with: 14 | go-version: 1.19 15 | - uses: actions/checkout@v3 16 | - name: golangci-lint 17 | uses: golangci/golangci-lint-action@v3 18 | with: 19 | version: latest -------------------------------------------------------------------------------- /repositories/image_generations/interface.go: -------------------------------------------------------------------------------- 1 | package image_generations 2 | 3 | import ( 4 | "context" 5 | "stable_diffusion_bot/entities" 6 | ) 7 | 8 | type Repository interface { 9 | Create(ctx context.Context, generation *entities.ImageGeneration) (*entities.ImageGeneration, error) 10 | GetByMessage(ctx context.Context, messageID string) (*entities.ImageGeneration, error) 11 | GetByMessageAndSort(ctx context.Context, messageID string, sortOrder int) (*entities.ImageGeneration, error) 12 | } 13 | -------------------------------------------------------------------------------- /imagine_queue/interface.go: -------------------------------------------------------------------------------- 1 | package imagine_queue 2 | 3 | import ( 4 | "stable_diffusion_bot/entities" 5 | 6 | "github.com/bwmarrin/discordgo" 7 | ) 8 | 9 | type Queue interface { 10 | AddImagine(item *QueueItem) (int, error) 11 | StartPolling(botSession
*discordgo.Session) 12 | GetBotDefaultSettings() (*entities.DefaultSettings, error) 13 | UpdateDefaultDimensions(width, height int) (*entities.DefaultSettings, error) 14 | UpdateDefaultBatch(batchCount, batchSize int) (*entities.DefaultSettings, error) 15 | } 16 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | # .github/workflows/release.yaml 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | jobs: 8 | releases-matrix: 9 | name: Release Go Binary 10 | runs-on: ubuntu-latest 11 | strategy: 12 | matrix: 13 | # build and publish in parallel: linux/amd64, linux/arm64, windows/amd64, darwin/amd64, darwin/arm64 14 | goos: [linux, windows, darwin] 15 | goarch: [amd64, arm64] 16 | exclude: 17 | - goarch: arm64 18 | goos: windows 19 | steps: 20 | - uses: actions/checkout@v3 21 | - uses: wangyoucao577/go-release-action@v1.34 22 | with: 23 | github_token: ${{ secrets.GITHUB_TOKEN }} 24 | goos: ${{ matrix.goos }} 25 | goarch: ${{ matrix.goarch }} 26 | goversion: "https://dl.google.com/go/go1.19.1.linux-amd64.tar.gz" 27 | binary_name: "stable_diffusion_bot" 28 | extra_files: LICENSE README.md 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Nicholas Page 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this
permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module stable_diffusion_bot 2 | 3 | go 1.19 4 | 5 | require ( 6 | github.com/bwmarrin/discordgo v0.26.1 7 | modernc.org/sqlite v1.20.1 8 | ) 9 | 10 | require ( 11 | github.com/dustin/go-humanize v1.0.0 // indirect 12 | github.com/google/uuid v1.3.0 // indirect 13 | github.com/gorilla/websocket v1.4.2 // indirect 14 | github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect 15 | github.com/mattn/go-isatty v0.0.16 // indirect 16 | github.com/mattn/go-sqlite3 v1.14.16 // indirect 17 | github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 // indirect 18 | golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b // indirect 19 | golang.org/x/mod v0.3.0 // indirect 20 | golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab // indirect 21 | golang.org/x/tools v0.0.0-20201124115921-2c860bdd6e78 // indirect 22 | golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect 23 | lukechampine.com/uint128 v1.2.0 // indirect 24 | modernc.org/cc/v3 v3.40.0 // indirect 25 | modernc.org/ccgo/v3 v3.16.13 // indirect 26 | modernc.org/libc v1.22.2 // indirect 27 | modernc.org/mathutil v1.5.0 // indirect 28 | modernc.org/memory v1.4.0 // indirect 29 | modernc.org/opt v0.1.3 // 
indirect 30 | modernc.org/strutil v1.1.3 // indirect 31 | modernc.org/token v1.0.1 // indirect 32 | ) 33 | -------------------------------------------------------------------------------- /entities/image_generation.go: -------------------------------------------------------------------------------- 1 | package entities 2 | 3 | import "time" 4 | 5 | type ImageGeneration struct { 6 | ID int64 `json:"id"` 7 | InteractionID string `json:"interaction_id"` 8 | MessageID string `json:"message_id"` 9 | MemberID string `json:"member_id"` 10 | SortOrder int `json:"sort_order"` 11 | Prompt string `json:"prompt"` 12 | NegativePrompt string `json:"negative_prompt"` 13 | Width int `json:"width"` 14 | Height int `json:"height"` 15 | RestoreFaces bool `json:"restore_faces"` 16 | EnableHR bool `json:"enable_hr"` 17 | HiresWidth int `json:"hires_width"` 18 | HiresHeight int `json:"hires_height"` 19 | DenoisingStrength float64 `json:"denoising_strength"` 20 | BatchCount int `json:"batch_count"` 21 | BatchSize int `json:"batch_size"` 22 | Seed int `json:"seed"` 23 | Subseed int `json:"subseed"` 24 | SubseedStrength float64 `json:"subseed_strength"` 25 | SamplerName string `json:"sampler_name"` 26 | CfgScale float64 `json:"cfg_scale"` 27 | Steps int `json:"steps"` 28 | Processed bool `json:"processed"` 29 | CreatedAt time.Time `json:"created_at"` 30 | } 31 | -------------------------------------------------------------------------------- /composite_renderer/renderer.go: -------------------------------------------------------------------------------- 1 | package composite_renderer 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "image" 7 | "image/draw" 8 | "image/png" 9 | ) 10 | 11 | type rendererImpl struct{} 12 | 13 | type Config struct{} 14 | 15 | func New(cfg Config) (Renderer, error) { 16 | return &rendererImpl{}, nil 17 | } 18 | // TileImages decodes exactly four images with identical bounds and composites them into one PNG laid out as a 2x2 grid: images[0] top-left, [1] top-right, [2] bottom-left, [3] bottom-right. 19 | func (r *rendererImpl) TileImages(imageBufs []*bytes.Buffer) (*bytes.Buffer, error) { 20 | if len(imageBufs) != 4 { 21 | return nil,
errors.New("invalid number of images") 22 | } 23 | 24 | images := make([]image.Image, 4) 25 | 26 | for i, buf := range imageBufs { 27 | img, _, err := image.Decode(buf) 28 | if err != nil { 29 | return nil, err 30 | } 31 | 32 | images[i] = img 33 | } 34 | 35 | firstBounds := images[0].Bounds() 36 | 37 | for _, img := range images { 38 | if img.Bounds() != firstBounds { 39 | return nil, errors.New("images are not the same size") 40 | } 41 | } 42 | 43 | tile := firstBounds.Sub(firstBounds.Min) // one tile's rectangle moved to (0,0); Max.X/Max.Y are only the size when Min is the origin 44 | retImage := image.NewRGBA(image.Rect(0, 0, tile.Dx()*2, tile.Dy()*2)) 45 | draw.Draw(retImage, tile, images[0], firstBounds.Min, draw.Over) 46 | draw.Draw(retImage, tile.Add(image.Pt(tile.Dx(), 0)), images[1], firstBounds.Min, draw.Over) 47 | draw.Draw(retImage, tile.Add(image.Pt(0, tile.Dy())), images[2], firstBounds.Min, draw.Over) 48 | draw.Draw(retImage, tile.Add(image.Pt(tile.Dx(), tile.Dy())), images[3], firstBounds.Min, draw.Over) 49 | 50 | imageBuf := new(bytes.Buffer) 51 | 52 | err := png.Encode(imageBuf, retImage) 53 | if err != nil { 54 | return nil, err 55 | } 56 | 57 | return imageBuf, nil 58 | } 59 | -------------------------------------------------------------------------------- /repositories/default_settings/sqlite.go: -------------------------------------------------------------------------------- 1 | package default_settings 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "errors" 7 | "fmt" 8 | "stable_diffusion_bot/clock" 9 | "stable_diffusion_bot/entities" 10 | "stable_diffusion_bot/repositories" 11 | ) 12 | 13 | const upsertSetting string = ` 14 | INSERT OR REPLACE INTO default_settings (member_id, width, height, batch_count, batch_size) VALUES (?, ?, ?, ?, ?); 15 | ` 16 | 17 | const getSettingByMemberID string = ` 18 | SELECT member_id, width, height, batch_count, batch_size FROM default_settings WHERE member_id = ?; 19 | ` 20 | 21 | type sqliteRepo struct { 22 | dbConn
*sql.DB 23 | clock clock.Clock 24 | } 25 | 26 | type Config struct { 27 | DB *sql.DB 28 | } 29 | 30 | func NewRepository(cfg *Config) (Repository, error) { 31 | if cfg.DB == nil { 32 | return nil, errors.New("missing DB parameter") 33 | } 34 | 35 | newRepo := &sqliteRepo{ 36 | dbConn: cfg.DB, 37 | clock: clock.NewClock(), 38 | } 39 | 40 | return newRepo, nil 41 | } 42 | 43 | func (repo *sqliteRepo) Upsert(ctx context.Context, setting *entities.DefaultSettings) (*entities.DefaultSettings, error) { 44 | _, err := repo.dbConn.ExecContext(ctx, upsertSetting, 45 | setting.MemberID, setting.Width, setting.Height, setting.BatchCount, setting.BatchSize) 46 | if err != nil { 47 | return nil, err 48 | } 49 | 50 | return setting, nil 51 | } 52 | 53 | func (repo *sqliteRepo) GetByMemberID(ctx context.Context, memberID string) (*entities.DefaultSettings, error) { 54 | var setting entities.DefaultSettings 55 | 56 | err := repo.dbConn.QueryRowContext(ctx, getSettingByMemberID, memberID).Scan( 57 | &setting.MemberID, &setting.Width, &setting.Height, &setting.BatchCount, &setting.BatchSize) 58 | if err != nil { 59 | if errors.Is(err, sql.ErrNoRows) { 60 | return nil, repositories.NewNotFoundError(fmt.Sprintf("default setting for member ID %s", memberID)) 61 | } 62 | 63 | return nil, err 64 | } 65 | 66 | return &setting, nil 67 | } 68 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "flag" 6 | "log" 7 | "stable_diffusion_bot/databases/sqlite" 8 | "stable_diffusion_bot/discord_bot" 9 | "stable_diffusion_bot/imagine_queue" 10 | "stable_diffusion_bot/repositories/default_settings" 11 | "stable_diffusion_bot/repositories/image_generations" 12 | "stable_diffusion_bot/stable_diffusion_api" 13 | ) 14 | 15 | // Bot parameters 16 | var ( 17 | guildID = flag.String("guild", "", "Guild ID (required)")
18 | botToken = flag.String("token", "", "Bot access token") 19 | apiHost = flag.String("host", "", "Host for the Automatic1111 API") 20 | imagineCommand = flag.String("imagine", "imagine", "Imagine command name. Default is \"imagine\"") 21 | removeCommandsFlag = flag.Bool("remove", false, "Delete all commands when bot exits") 22 | devModeFlag = flag.Bool("dev", false, "Start in development mode, using \"dev_\" prefixed commands instead") 23 | ) 24 | 25 | func main() { 26 | flag.Parse() 27 | 28 | if guildID == nil || *guildID == "" { 29 | log.Fatalf("Guild ID flag is required") 30 | } 31 | 32 | if botToken == nil || *botToken == "" { 33 | log.Fatalf("Bot token flag is required") 34 | } 35 | 36 | if apiHost == nil || *apiHost == "" { 37 | log.Fatalf("API host flag is required") 38 | } 39 | 40 | if imagineCommand == nil || *imagineCommand == "" { 41 | log.Fatalf("Imagine command flag is required") 42 | } 43 | 44 | devMode := false 45 | 46 | if devModeFlag != nil && *devModeFlag { 47 | devMode = *devModeFlag 48 | 49 | log.Printf("Starting in development mode..
all commands prefixed with \"dev_\"") 50 | } 51 | 52 | removeCommands := false 53 | 54 | if removeCommandsFlag != nil && *removeCommandsFlag { 55 | removeCommands = *removeCommandsFlag 56 | } 57 | 58 | stableDiffusionAPI, err := stable_diffusion_api.New(stable_diffusion_api.Config{ 59 | Host: *apiHost, 60 | }) 61 | if err != nil { 62 | log.Fatalf("Failed to create Stable Diffusion API: %v", err) 63 | } 64 | 65 | ctx := context.Background() 66 | 67 | sqliteDB, err := sqlite.New(ctx) 68 | if err != nil { 69 | log.Fatalf("Failed to create sqlite database: %v", err) 70 | } 71 | 72 | generationRepo, err := image_generations.NewRepository(&image_generations.Config{DB: sqliteDB}) 73 | if err != nil { 74 | log.Fatalf("Failed to create image generation repository: %v", err) 75 | } 76 | 77 | defaultSettingsRepo, err := default_settings.NewRepository(&default_settings.Config{DB: sqliteDB}) 78 | if err != nil { 79 | log.Fatalf("Failed to create default settings repository: %v", err) 80 | } 81 | 82 | imagineQueue, err := imagine_queue.New(imagine_queue.Config{ 83 | StableDiffusionAPI: stableDiffusionAPI, 84 | ImageGenerationRepo: generationRepo, 85 | DefaultSettingsRepo: defaultSettingsRepo, 86 | }) 87 | if err != nil { 88 | log.Fatalf("Failed to create imagine queue: %v", err) 89 | } 90 | 91 | bot, err := discord_bot.New(discord_bot.Config{ 92 | DevelopmentMode: devMode, 93 | BotToken: *botToken, 94 | GuildID: *guildID, 95 | ImagineQueue: imagineQueue, 96 | ImagineCommand: *imagineCommand, 97 | RemoveCommands: removeCommands, 98 | }) 99 | if err != nil { 100 | log.Fatalf("Error creating Discord bot: %v", err) 101 | } 102 | // NOTE(review): Start appears to block until shutdown (the next statement logs "Gracefully shutting down."); confirm in discord_bot. 103 | bot.Start() 104 | 105 | log.Println("Gracefully shutting down.") 106 | } 107 | -------------------------------------------------------------------------------- /repositories/image_generations/sqlite.go: -------------------------------------------------------------------------------- 1 | package image_generations 2 | 3 | import ( 4 | "context" 5 |
"database/sql" 6 | "errors" 7 | "stable_diffusion_bot/clock" 8 | "stable_diffusion_bot/entities" 9 | ) 10 | 11 | const insertGenerationQuery string = ` 12 | INSERT INTO image_generations (interaction_id, message_id, member_id, sort_order, prompt, negative_prompt, width, height, restore_faces, enable_hr, hires_width, hires_height, denoising_strength, batch_count, batch_size, seed, subseed, subseed_strength, sampler_name, cfg_scale, steps, processed, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?); 13 | ` 14 | // getGenerationByMessageID has no ORDER BY; NOTE(review): if a message has several rows (presumably one per sort_order), which row QueryRow returns is unspecified — confirm callers only read fields shared by all rows. 15 | const getGenerationByMessageID string = ` 16 | SELECT id, interaction_id, message_id, member_id, sort_order, prompt, negative_prompt, width, height, restore_faces, enable_hr, hires_width, hires_height, denoising_strength, batch_count, batch_size, seed, subseed, subseed_strength, sampler_name, cfg_scale, steps, processed, created_at FROM image_generations WHERE message_id = ?; 17 | ` 18 | 19 | const getGenerationByMessageIDAndSortOrder string = ` 20 | SELECT id, interaction_id, message_id, member_id, sort_order, prompt, negative_prompt, width, height, restore_faces, enable_hr, hires_width, hires_height, denoising_strength, batch_count, batch_size, seed, subseed, subseed_strength, sampler_name, cfg_scale, steps, processed, created_at FROM image_generations WHERE message_id = ?
AND sort_order = ?; 21 | ` 22 | 23 | type sqliteRepo struct { 24 | dbConn *sql.DB 25 | clock clock.Clock 26 | } 27 | 28 | type Config struct { 29 | DB *sql.DB 30 | } 31 | 32 | func NewRepository(cfg *Config) (Repository, error) { 33 | if cfg.DB == nil { 34 | return nil, errors.New("missing DB parameter") 35 | } 36 | 37 | newRepo := &sqliteRepo{ 38 | dbConn: cfg.DB, 39 | clock: clock.NewClock(), 40 | } 41 | 42 | return newRepo, nil 43 | } 44 | // Create stamps generation.CreatedAt from the repository clock, inserts the row, and returns the same pointer with ID set to the new row ID. 45 | func (repo *sqliteRepo) Create(ctx context.Context, generation *entities.ImageGeneration) (*entities.ImageGeneration, error) { 46 | generation.CreatedAt = repo.clock.Now() 47 | 48 | res, err := repo.dbConn.ExecContext(ctx, insertGenerationQuery, 49 | generation.InteractionID, generation.MessageID, generation.MemberID, generation.SortOrder, generation.Prompt, 50 | generation.NegativePrompt, generation.Width, generation.Height, generation.RestoreFaces, 51 | generation.EnableHR, generation.HiresWidth, generation.HiresHeight, generation.DenoisingStrength, 52 | generation.BatchCount, generation.BatchSize, generation.Seed, generation.Subseed, 53 | generation.SubseedStrength, generation.SamplerName, generation.CfgScale, generation.Steps, generation.Processed, generation.CreatedAt) 54 | if err != nil { 55 | return nil, err 56 | } 57 | 58 | lastID, err := res.LastInsertId() 59 | if err != nil { 60 | return nil, err 61 | } 62 | 63 | generation.ID = lastID 64 | 65 | return generation, nil 66 | } 67 | // GetByMessage loads one generation row for messageID. A missing row is returned as sql.ErrNoRows, not as repositories.NotFoundError like the default_settings repository does. 68 | func (repo *sqliteRepo) GetByMessage(ctx context.Context, messageID string) (*entities.ImageGeneration, error) { 69 | var generation entities.ImageGeneration 70 | 71 | err := repo.dbConn.QueryRowContext(ctx, getGenerationByMessageID, messageID).Scan( 72 | &generation.ID, &generation.InteractionID, &generation.MessageID, &generation.MemberID, &generation.SortOrder, &generation.Prompt, 73 | &generation.NegativePrompt, &generation.Width, &generation.Height, &generation.RestoreFaces, 74 | &generation.EnableHR, &generation.HiresWidth,
&generation.HiresHeight, &generation.DenoisingStrength, 75 | &generation.BatchCount, &generation.BatchSize, &generation.Seed, &generation.Subseed, 76 | &generation.SubseedStrength, &generation.SamplerName, &generation.CfgScale, &generation.Steps, &generation.Processed, &generation.CreatedAt) 77 | if err != nil { 78 | return nil, err 79 | } 80 | 81 | return &generation, nil 82 | } 83 | // GetByMessageAndSort loads the generation row for messageID with the given sortOrder; a missing row is returned as sql.ErrNoRows. 84 | func (repo *sqliteRepo) GetByMessageAndSort(ctx context.Context, messageID string, sortOrder int) (*entities.ImageGeneration, error) { 85 | var generation entities.ImageGeneration 86 | 87 | err := repo.dbConn.QueryRowContext(ctx, getGenerationByMessageIDAndSortOrder, messageID, sortOrder).Scan( 88 | &generation.ID, &generation.InteractionID, &generation.MessageID, &generation.MemberID, &generation.SortOrder, &generation.Prompt, 89 | &generation.NegativePrompt, &generation.Width, &generation.Height, &generation.RestoreFaces, 90 | &generation.EnableHR, &generation.HiresWidth, &generation.HiresHeight, &generation.DenoisingStrength, 91 | &generation.BatchCount, &generation.BatchSize, &generation.Seed, &generation.Subseed, 92 | &generation.SubseedStrength, &generation.SamplerName, &generation.CfgScale, &generation.Steps, &generation.Processed, &generation.CreatedAt) 93 | if err != nil { 94 | return nil, err 95 | } 96 | 97 | return &generation, nil 98 | } 99 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Stable Diffusion Discord Bot 2 | 3 | This is a Discord bot that interfaces with the Automatic1111 API, from this project: https://github.com/AUTOMATIC1111/stable-diffusion-webui 4 | 5 | Video showing off the current features: 6 | https://www.youtube.com/watch?v=of5MBh3ueMk 7 | 8 | ## Installation 9 | 10 | 1.
Download the appropriate version for your system from the releases page: https://github.com/AndBobsYourUncle/stable-diffusion-discord-bot/releases 11 | 1. Windows users will need to use the windows-amd64 version 12 | 2. Intel Macs will need to use the darwin-amd64 version 13 | 3. M1 Macs will need to use the darwin-arm64 version 14 | 4. Devices like a Raspberry Pi will need to use the linux-arm64 version 15 | 5. Most other Linux devices will need to use the linux-amd64 version 16 | 2. Extract the archive to a location of your choice 17 | 18 | ## Building (optional, only if you want to build from source) 19 | 20 | 1. Clone this repository 21 | 2. Install Go 22 | * This varies with your operating system, but the easiest way is to use the official installer: https://golang.org/dl/ 23 | 3. Build the bot with `go build` 24 | 25 | ## Usage 26 | 27 | 1. Create a Discord bot and get the token 28 | 2. Add the Discord bot to your Discord server. It needs permissions to post messages, use slash commands, mention anyone, and upload files. 29 | 3. Ensure that the Automatic1111 WebUI is running with `--api` (and also `--listen` if it is running on a different computer than the bot). 30 | 4. Run the bot with `./stable_diffusion_bot -token <bot token> -guild <guild ID> -host http://127.0.0.1:7860` 31 | * It's important that the `-host` parameter matches the IP address where A1111 is running. If the bot is on the same computer, `127.0.0.1` will work. 32 | * There must be no trailing slash after the port number (which is `7860` in this example). So, instead of `http://127.0.0.1:7860/`, it should be `http://127.0.0.1:7860`. 33 | 5. The first run will generate a new SQLite DB file in the current working directory. 34 | 35 | The `-imagine <command name>` flag can be used to have the bot use a different command name, so that it doesn't collide with a Midjourney bot running on the same Discord server.
36 | 37 | ## Commands 38 | 39 | ### `/imagine_settings` 40 | 41 | Responds with a message that has buttons for updating the default settings of the `/imagine` command. 42 | 43 | By default, the size is 512x512. However, if you are running the Stable Diffusion 2.0 768 model, you might want to change this to 768x768. 44 | 45 | Choosing an option will cause the bot to update the setting and edit the message in place, allowing further edits. 46 | 47 | Screenshot 2023-01-06 at 10 41 36 AM 48 | 49 | ### `/imagine` 50 | 51 | Creates an image from a text prompt (e.g. `/imagine cute kitten riding a skateboard`). 52 | 53 | Available options: 54 | - Aspect Ratio 55 | - `--ar <width>:<height>` (e.g. `/imagine cute kitten riding a skateboard --ar 16:9`) 56 | - Uses the default width or height, and calculates the final value for the other based on the aspect ratio. It then rounds that value up to the nearest multiple of `8`, to match the expectations of the underlying neural model and SD API. 57 | - Under the hood, it will use the "Hires fix" option in the API, which will generate an image with the bot's default width/height, and then resize it to the desired aspect ratio. 58 | 59 | ## How it Works 60 | 61 | The bot implements a FIFO queue (first in, first out). When a user issues the `/imagine` command (or uses an interaction button), their request is added to the end of the queue. 62 | 63 | The bot then checks the queue every second. If the queue is not empty, and there is nothing currently being processed, it will send the interaction at the front of the queue to the Automatic1111 WebUI API, and then remove it from the queue. 64 | 65 | After Automatic1111 has finished processing the interaction, the bot will then update the reply message with the finished result. 66 | 67 | Buttons are added to the Discord response message for interactions like re-roll, variations, and up-scaling.
68 | 69 | All image generations are saved into a local SQLite database, so that the parameters of the image can be retrieved later for variations or up-scaling. 70 | 71 | Screenshot 2022-12-22 at 4 25 03 PM 72 | 73 | Screenshot 2022-12-22 at 4 25 18 PM 74 | 75 | Options like aspect ratio are extracted and sanitized from the text prompt, and then the resulting options are stored in the database record for the image generation (for further variations or upscaling): 76 | 77 | Screenshot 2022-12-28 at 4 30 43 PM 78 | 79 | ## Contributing 80 | 81 | Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change. 82 | 83 | There are lots more features that could be added to this bot, such as: 84 | 85 | - [x] Moving defaults to the database 86 | - [ ] Per-user defaults/settings, as well as enforcing limits on a user's usage of the bot 87 | - [x] Ability to easily re-roll an image 88 | - [x] Generating multiple images at once 89 | - [x] Ability to upscale the resulting images 90 | - [x] Ability to generate variations on a grid image 91 | - [ ] Ability to tweak more settings when issuing the `/imagine` command (like aspect ratio) 92 | - [ ] Image to image processing 93 | 94 | I'll probably be adding a few of these over time, but any contributions are also welcome. 95 | 96 | ## Why Go? 97 | 98 | I like Go a lot better than Python, and for me it's a lot easier to maintain dependencies with Go modules versus running a bunch of different Anaconda environments. 99 | 100 | It's also able to be cross-compiled to a wide range of platforms, which is nice. 
101 | -------------------------------------------------------------------------------- /databases/sqlite/sqlite.go: -------------------------------------------------------------------------------- 1 | package sqlite 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "log" 7 | "os" 8 | "strconv" 9 | "strings" 10 | 11 | _ "modernc.org/sqlite" 12 | ) 13 | 14 | const dbFile string = "sd_discord_bot.sqlite" 15 | 16 | const getCurrentMigration string = `PRAGMA user_version;` 17 | const setCurrentMigration string = `PRAGMA user_version = ?;` 18 | 19 | const createGenerationTableIfNotExistsQuery string = ` 20 | CREATE TABLE IF NOT EXISTS image_generations ( 21 | id INTEGER NOT NULL PRIMARY KEY, 22 | interaction_id TEXT NOT NULL, 23 | message_id TEXT NOT NULL, 24 | member_id TEXT NOT NULL, 25 | sort_order INTEGER NOT NULL, 26 | prompt TEXT NOT NULL, 27 | negative_prompt TEXT NOT NULL, 28 | width INTEGER NOT NULL, 29 | height INTEGER NOT NULL, 30 | restore_faces INTEGER NOT NULL, 31 | enable_hr INTEGER NOT NULL, 32 | denoising_strength REAL NOT NULL, 33 | batch_size INTEGER NOT NULL, 34 | seed INTEGER NOT NULL, 35 | subseed INTEGER NOT NULL, 36 | subseed_strength REAL NOT NULL, 37 | sampler_name TEXT NOT NULL, 38 | cfg_scale REAL NOT NULL, 39 | steps INTEGER NOT NULL, 40 | processed INTEGER NOT NULL, 41 | created_at DATETIME NOT NULL 42 | );` 43 | 44 | const createInteractionIndexIfNotExistsQuery string = ` 45 | CREATE INDEX IF NOT EXISTS generation_interaction_index 46 | ON image_generations(interaction_id); 47 | ` 48 | // createMessageIndexIfNotExistsQuery needs its own index name: SQLite index names are schema-wide, so reusing generation_interaction_index made IF NOT EXISTS silently skip creating this index. 49 | const createMessageIndexIfNotExistsQuery string = ` 50 | CREATE INDEX IF NOT EXISTS generation_message_index 51 | ON image_generations(message_id); 52 | ` 53 | 54 | const addHiresFirstPassDimensionColumnsQuery string = ` 55 | ALTER TABLE image_generations ADD COLUMN firstphase_width INTEGER NOT NULL DEFAULT 0; 56 | ALTER TABLE image_generations ADD COLUMN firstphase_height INTEGER NOT NULL DEFAULT 0; 57 | ` 58 | 59 | const
dropHiresFirstPassDimensionColumnsQuery string = ` 60 | ALTER TABLE image_generations DROP COLUMN firstphase_width; 61 | ALTER TABLE image_generations DROP COLUMN firstphase_height; 62 | ` 63 | 64 | const addHiresResizeColumnsQuery string = ` 65 | ALTER TABLE image_generations ADD COLUMN hires_width INTEGER NOT NULL DEFAULT 0; 66 | ALTER TABLE image_generations ADD COLUMN hires_height INTEGER NOT NULL DEFAULT 0; 67 | ` 68 | 69 | const createDefaultSettingsTableIfNotExistsQuery string = ` 70 | CREATE TABLE IF NOT EXISTS default_settings ( 71 | member_id TEXT NOT NULL PRIMARY KEY, 72 | width INTEGER NOT NULL, 73 | height INTEGER NOT NULL 74 | );` 75 | 76 | const addSettingsBatchColumnsQuery string = ` 77 | ALTER TABLE default_settings ADD COLUMN batch_count INTEGER NOT NULL DEFAULT 0; 78 | ALTER TABLE default_settings ADD COLUMN batch_size INTEGER NOT NULL DEFAULT 0; 79 | ` 80 | // addGenerationBatchCountColumnQuery adds the batch_count column to image_generations. 81 | const addGenerationBatchCountColumnQuery string = ` 82 | ALTER TABLE image_generations ADD COLUMN batch_count INTEGER NOT NULL DEFAULT 0; 83 | ` 84 | 85 | type migration struct { 86 | migrationName string 87 | migrationQuery string 88 | } 89 | // migrations run in order and PRAGMA user_version records how many have been applied, so new entries must only ever be appended. 90 | var migrations = []migration{ 91 | {migrationName: "create generation table", migrationQuery: createGenerationTableIfNotExistsQuery}, 92 | {migrationName: "add generation interaction index", migrationQuery: createInteractionIndexIfNotExistsQuery}, 93 | {migrationName: "add generation message index", migrationQuery: createMessageIndexIfNotExistsQuery}, 94 | {migrationName: "add hires firstpass columns", migrationQuery: addHiresFirstPassDimensionColumnsQuery}, 95 | {migrationName: "drop hires firstpass columns", migrationQuery: dropHiresFirstPassDimensionColumnsQuery}, 96 | {migrationName: "add hires resize columns", migrationQuery: addHiresResizeColumnsQuery}, 97 | {migrationName: "create default settings table", migrationQuery: createDefaultSettingsTableIfNotExistsQuery}, 98 | {migrationName: "add settings batch columns", migrationQuery:
addSettingsBatchColumnsQuery}, 99 | {migrationName: "add generation batch count column", migrationQuery: addGenerationBatchCountColumnQuery}, 100 | {migrationName: "recreate generation message index with a unique name", migrationQuery: createMessageIndexIfNotExistsQuery}, // databases migrated before the index-name fix never got the message_id index; IF NOT EXISTS makes this a no-op for fresh databases 101 | } 102 | func New(ctx context.Context) (*sql.DB, error) { 103 | filename, err := DBFilename() 104 | if err != nil { 105 | return nil, err 106 | } 107 | 108 | err = touchDBFile(filename) 109 | if err != nil { 110 | return nil, err 111 | } 112 | 113 | db, err := sql.Open("sqlite", filename) 114 | if err != nil { 115 | return nil, err 116 | } 117 | 118 | err = migrate(ctx, db) 119 | if err != nil { 120 | return nil, err 121 | } 122 | 123 | return db, nil 124 | } 125 | // migrate applies, one at a time via execMigration, every migration after the version stored in PRAGMA user_version. 126 | func migrate(ctx context.Context, db *sql.DB) error { 127 | var currentMigration int 128 | 129 | row := db.QueryRowContext(ctx, getCurrentMigration) 130 | 131 | err := row.Scan(&currentMigration) 132 | if err != nil { 133 | return err 134 | } 135 | 136 | requiredMigration := len(migrations) 137 | 138 | log.Printf("Current DB version: %v, required DB version: %v\n", currentMigration, requiredMigration) 139 | 140 | if currentMigration < requiredMigration { 141 | for migrationNum := currentMigration + 1; migrationNum <= requiredMigration; migrationNum++ { 142 | err = execMigration(ctx, db, migrationNum) 143 | if err != nil { 144 | log.Printf("Error running migration %v '%v'\n", migrationNum, migrations[migrationNum-1].migrationName) 145 | 146 | return err 147 | } 148 | } 149 | } 150 | 151 | return nil 152 | } 153 | 154 | func execMigration(ctx context.Context, db *sql.DB, migrationNum int) error { 155 | log.Printf("Running migration %v '%v'\n", migrationNum, migrations[migrationNum-1].migrationName) 156 | 157 | tx, err := db.BeginTx(ctx, nil) 158 | if err != nil { 159 | return err 160 | } 161 | 162 | //nolint 163 | defer tx.Rollback() 164 | 165 | _, err = tx.ExecContext(ctx, migrations[migrationNum-1].migrationQuery) 166 | if err != nil { 167 | return err 168 | } 169 | 170 | setQuery := strings.Replace(setCurrentMigration, "?", strconv.Itoa(migrationNum), 1)
171 | 172 | _, err = tx.ExecContext(ctx, setQuery) 173 | if err != nil { 174 | return err 175 | } 176 | 177 | err = tx.Commit() 178 | if err != nil { 179 | return err 180 | } 181 | 182 | return nil 183 | } 184 | 185 | func DBFilename() (string, error) { 186 | dir, err := os.Getwd() 187 | if err != nil { 188 | return "", err 189 | } 190 | 191 | return dir + "/" + dbFile, nil 192 | } 193 | 194 | func touchDBFile(filename string) error { 195 | _, err := os.Stat(filename) 196 | if os.IsNotExist(err) { 197 | file, createErr := os.Create(filename) 198 | if createErr != nil { 199 | return createErr 200 | } 201 | 202 | closeErr := file.Close() 203 | if closeErr != nil { 204 | return closeErr 205 | } 206 | } 207 | 208 | return nil 209 | } 210 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/bwmarrin/discordgo v0.26.1 h1:AIrM+g3cl+iYBr4yBxCBp9tD9jR3K7upEjl0d89FRkE= 2 | github.com/bwmarrin/discordgo v0.26.1/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY= 3 | github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= 4 | github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= 5 | github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= 6 | github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ= 7 | github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= 8 | github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 9 | github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= 10 | github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= 11 | github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= 12 | github.com/kballard/go-shellquote 
v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= 13 | github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ= 14 | github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= 15 | github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= 16 | github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= 17 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 18 | github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 h1:OdAsTTz6OkFY5QxjkYwrChwuRruF69c169dPK26NUlk= 19 | github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= 20 | github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= 21 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 22 | golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= 23 | golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= 24 | golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b h1:7mWr3k41Qtv8XlltBkDkl8LoP3mpSgBW8BUoxtEdbXg= 25 | golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= 26 | golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4= 27 | golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= 28 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 29 | golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 30 | golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= 31 | golang.org/x/net 
v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= 32 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 33 | golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 34 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 35 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 36 | golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 37 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 38 | golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab h1:2QkjZIsXupsJbJIdSjjUOgWK3aEtzyuh2mPt3l/CkeU= 39 | golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 40 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= 41 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 42 | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 43 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 44 | golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= 45 | golang.org/x/tools v0.0.0-20201124115921-2c860bdd6e78 h1:M8tBwCtWD/cZV9DZpFYRUgaymAYAr+aIUTWzDaM3uPs= 46 | golang.org/x/tools v0.0.0-20201124115921-2c860bdd6e78/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= 47 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 48 | golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 49 | golang.org/x/xerrors 
v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= 50 | golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 51 | lukechampine.com/uint128 v1.2.0 h1:mBi/5l91vocEN8otkC5bDLhi2KdCticRiwbdB0O+rjI= 52 | lukechampine.com/uint128 v1.2.0/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk= 53 | modernc.org/cc/v3 v3.40.0 h1:P3g79IUS/93SYhtoeaHW+kRCIrYaxJ27MFPv+7kaTOw= 54 | modernc.org/cc/v3 v3.40.0/go.mod h1:/bTg4dnWkSXowUO6ssQKnOV0yMVxDYNIsIrzqTFDGH0= 55 | modernc.org/ccgo/v3 v3.16.13 h1:Mkgdzl46i5F/CNR/Kj80Ri59hC8TKAhZrYSaqvkwzUw= 56 | modernc.org/ccgo/v3 v3.16.13/go.mod h1:2Quk+5YgpImhPjv2Qsob1DnZ/4som1lJTodubIcoUkY= 57 | modernc.org/ccorpus v1.11.6 h1:J16RXiiqiCgua6+ZvQot4yUuUy8zxgqbqEEUuGPlISk= 58 | modernc.org/httpfs v1.0.6 h1:AAgIpFZRXuYnkjftxTAZwMIiwEqAfk8aVB2/oA6nAeM= 59 | modernc.org/libc v1.22.2 h1:4U7v51GyhlWqQmwCHj28Rdq2Yzwk55ovjFrdPjs8Hb0= 60 | modernc.org/libc v1.22.2/go.mod h1:uvQavJ1pZ0hIoC/jfqNoMLURIMhKzINIWypNM17puug= 61 | modernc.org/mathutil v1.5.0 h1:rV0Ko/6SfM+8G+yKiyI830l3Wuz1zRutdslNoQ0kfiQ= 62 | modernc.org/mathutil v1.5.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E= 63 | modernc.org/memory v1.4.0 h1:crykUfNSnMAXaOJnnxcSzbUGMqkLWjklJKkBK2nwZwk= 64 | modernc.org/memory v1.4.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU= 65 | modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4= 66 | modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0= 67 | modernc.org/sqlite v1.20.1 h1:z6qRLw72B0VfRrJjs3l6hWkzYDx1bo0WGVrBGP4ohhM= 68 | modernc.org/sqlite v1.20.1/go.mod h1:fODt+bFmc/j8LcoCbMSkAuKuGmhxjG45KGc25N2705M= 69 | modernc.org/strutil v1.1.3 h1:fNMm+oJklMGYfU9Ylcywl0CO5O6nTfaowNsh2wpPjzY= 70 | modernc.org/strutil v1.1.3/go.mod h1:MEHNA7PdEnEwLvspRMtWTNnp2nnyvMfkimT1NKNAGbw= 71 | modernc.org/tcl v1.15.0 h1:oY+JeD11qVVSgVvodMJsu7Edf8tr5E/7tuhF5cNYz34= 72 | modernc.org/token v1.0.1 
h1:A3qvTqOwexpfZZeyI0FeGPDlSWX5pjZu9hF4lU+EKWg= 73 | modernc.org/token v1.0.1/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= 74 | modernc.org/z v1.7.0 h1:xkDw/KepgEjeizO2sNco+hqYkU12taxQFqPEmgm1GWE= 75 | -------------------------------------------------------------------------------- /stable_diffusion_api/stable_diffusion.go: -------------------------------------------------------------------------------- 1 | package stable_diffusion_api 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "errors" 7 | "io" 8 | "log" 9 | "net/http" 10 | ) 11 | 12 | type apiImpl struct { 13 | host string 14 | } 15 | 16 | type Config struct { 17 | Host string 18 | } 19 | 20 | func New(cfg Config) (StableDiffusionAPI, error) { 21 | if cfg.Host == "" { 22 | return nil, errors.New("missing host") 23 | } 24 | 25 | return &apiImpl{ 26 | host: cfg.Host, 27 | }, nil 28 | } 29 | 30 | type jsonTextToImageResponse struct { 31 | Images []string `json:"images"` 32 | Info string `json:"info"` 33 | } 34 | 35 | type jsonInfoResponse struct { 36 | Seed int `json:"seed"` 37 | AllSeeds []int `json:"all_seeds"` 38 | AllSubseeds []int `json:"all_subseeds"` 39 | } 40 | 41 | type TextToImageResponse struct { 42 | Images []string `json:"images"` 43 | Seeds []int `json:"seeds"` 44 | Subseeds []int `json:"subseeds"` 45 | } 46 | 47 | type TextToImageRequest struct { 48 | Prompt string `json:"prompt"` 49 | NegativePrompt string `json:"negative_prompt"` 50 | Width int `json:"width"` 51 | Height int `json:"height"` 52 | RestoreFaces bool `json:"restore_faces"` 53 | EnableHR bool `json:"enable_hr"` 54 | HRResizeX int `json:"hr_resize_x"` 55 | HRResizeY int `json:"hr_resize_y"` 56 | DenoisingStrength float64 `json:"denoising_strength"` 57 | BatchSize int `json:"batch_size"` 58 | Seed int `json:"seed"` 59 | Subseed int `json:"subseed"` 60 | SubseedStrength float64 `json:"subseed_strength"` 61 | SamplerName string `json:"sampler_name"` 62 | CfgScale float64 `json:"cfg_scale"` 63 | Steps int `json:"steps"` 64 
| NIter int `json:"n_iter"` 65 | } 66 | 67 | func (api *apiImpl) TextToImage(req *TextToImageRequest) (*TextToImageResponse, error) { 68 | if req == nil { 69 | return nil, errors.New("missing request") 70 | } 71 | 72 | postURL := api.host + "/sdapi/v1/txt2img" 73 | 74 | jsonData, err := json.Marshal(req) 75 | if err != nil { 76 | return nil, err 77 | } 78 | 79 | request, err := http.NewRequest("POST", postURL, bytes.NewBuffer(jsonData)) 80 | if err != nil { 81 | return nil, err 82 | } 83 | 84 | request.Header.Set("Content-Type", "application/json; charset=UTF-8") 85 | 86 | client := &http.Client{} 87 | 88 | response, err := client.Do(request) 89 | if err != nil { 90 | log.Printf("API URL: %s", postURL) 91 | log.Printf("Error with API Request: %s", string(jsonData)) 92 | 93 | return nil, err 94 | } 95 | 96 | defer response.Body.Close() 97 | 98 | body, _ := io.ReadAll(response.Body) 99 | 100 | respStruct := &jsonTextToImageResponse{} 101 | 102 | err = json.Unmarshal(body, respStruct) 103 | if err != nil { 104 | log.Printf("API URL: %s", postURL) 105 | log.Printf("Unexpected API response: %s", string(body)) 106 | 107 | return nil, err 108 | } 109 | 110 | infoStruct := &jsonInfoResponse{} 111 | 112 | err = json.Unmarshal([]byte(respStruct.Info), infoStruct) 113 | if err != nil { 114 | log.Printf("API URL: %s", postURL) 115 | log.Printf("Unexpected API response: %s", string(body)) 116 | 117 | return nil, err 118 | } 119 | 120 | return &TextToImageResponse{ 121 | Images: respStruct.Images, 122 | Seeds: infoStruct.AllSeeds, 123 | Subseeds: infoStruct.AllSubseeds, 124 | }, nil 125 | } 126 | 127 | type UpscaleRequest struct { 128 | ResizeMode int `json:"resize_mode"` 129 | UpscalingResize int `json:"upscaling_resize"` 130 | Upscaler1 string `json:"upscaler1"` 131 | TextToImageRequest *TextToImageRequest `json:"text_to_image_request"` 132 | } 133 | 134 | type upscaleJSONRequest struct { 135 | ResizeMode int `json:"resize_mode"` 136 | UpscalingResize int 
`json:"upscaling_resize"` 137 | Upscaler1 string `json:"upscaler1"` 138 | Image string `json:"image"` 139 | } 140 | 141 | type UpscaleResponse struct { 142 | Image string `json:"image"` 143 | } 144 | 145 | func (api *apiImpl) UpscaleImage(upscaleReq *UpscaleRequest) (*UpscaleResponse, error) { 146 | if upscaleReq == nil { 147 | return nil, errors.New("missing request") 148 | } 149 | 150 | textToImageReq := upscaleReq.TextToImageRequest 151 | 152 | if textToImageReq == nil { 153 | return nil, errors.New("missing text to image request") 154 | } 155 | 156 | textToImageReq.NIter = 1 157 | 158 | regeneratedImage, err := api.TextToImage(textToImageReq) 159 | if err != nil { 160 | return nil, err 161 | } 162 | 163 | jsonReq := &upscaleJSONRequest{ 164 | ResizeMode: upscaleReq.ResizeMode, 165 | UpscalingResize: upscaleReq.UpscalingResize, 166 | Upscaler1: upscaleReq.Upscaler1, 167 | Image: regeneratedImage.Images[0], 168 | } 169 | 170 | postURL := api.host + "/sdapi/v1/extra-single-image" 171 | 172 | jsonData, err := json.Marshal(jsonReq) 173 | if err != nil { 174 | return nil, err 175 | } 176 | 177 | request, err := http.NewRequest("POST", postURL, bytes.NewBuffer(jsonData)) 178 | if err != nil { 179 | return nil, err 180 | } 181 | 182 | request.Header.Set("Content-Type", "application/json; charset=UTF-8") 183 | 184 | client := &http.Client{} 185 | 186 | response, err := client.Do(request) 187 | if err != nil { 188 | log.Printf("API URL: %s", postURL) 189 | log.Printf("Error with API Request: %s", string(jsonData)) 190 | 191 | return nil, err 192 | } 193 | 194 | defer response.Body.Close() 195 | 196 | body, _ := io.ReadAll(response.Body) 197 | 198 | respStruct := &UpscaleResponse{} 199 | 200 | err = json.Unmarshal(body, respStruct) 201 | if err != nil { 202 | log.Printf("API URL: %s", postURL) 203 | log.Printf("Unexpected API response: %s", string(body)) 204 | 205 | return nil, err 206 | } 207 | 208 | return respStruct, nil 209 | } 210 | 211 | type ProgressResponse struct 
{ 212 | Progress float64 `json:"progress"` 213 | EtaRelative float64 `json:"eta_relative"` 214 | } 215 | 216 | func (api *apiImpl) GetCurrentProgress() (*ProgressResponse, error) { 217 | getURL := api.host + "/sdapi/v1/progress" 218 | 219 | request, err := http.NewRequest("GET", getURL, bytes.NewBuffer([]byte{})) 220 | if err != nil { 221 | return nil, err 222 | } 223 | 224 | client := &http.Client{} 225 | 226 | response, err := client.Do(request) 227 | if err != nil { 228 | log.Printf("API URL: %s", getURL) 229 | log.Printf("Error with API Request: %v", err) 230 | 231 | return nil, err 232 | } 233 | 234 | defer response.Body.Close() 235 | 236 | body, _ := io.ReadAll(response.Body) 237 | 238 | respStruct := &ProgressResponse{} 239 | 240 | err = json.Unmarshal(body, respStruct) 241 | if err != nil { 242 | log.Printf("API URL: %s", getURL) 243 | log.Printf("Unexpected API response: %s", string(body)) 244 | 245 | return nil, err 246 | } 247 | 248 | return respStruct, nil 249 | } 250 | -------------------------------------------------------------------------------- /discord_bot/discord_bot.go: -------------------------------------------------------------------------------- 1 | package discord_bot 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "log" 7 | "stable_diffusion_bot/entities" 8 | "stable_diffusion_bot/imagine_queue" 9 | "strconv" 10 | "strings" 11 | 12 | "github.com/bwmarrin/discordgo" 13 | ) 14 | 15 | type botImpl struct { 16 | developmentMode bool 17 | botSession *discordgo.Session 18 | guildID string 19 | imagineQueue imagine_queue.Queue 20 | registeredCommands []*discordgo.ApplicationCommand 21 | imagineCommand string 22 | removeCommands bool 23 | } 24 | 25 | type Config struct { 26 | DevelopmentMode bool 27 | BotToken string 28 | GuildID string 29 | ImagineQueue imagine_queue.Queue 30 | ImagineCommand string 31 | RemoveCommands bool 32 | } 33 | 34 | func (b *botImpl) imagineCommandString() string { 35 | if b.developmentMode { 36 | return "dev_" + 
b.imagineCommand 37 | } 38 | 39 | return b.imagineCommand 40 | } 41 | 42 | func (b *botImpl) imagineSettingsCommandString() string { 43 | if b.developmentMode { 44 | return "dev_" + b.imagineCommand + "_settings" 45 | } 46 | 47 | return b.imagineCommand + "_settings" 48 | } 49 | 50 | func New(cfg Config) (Bot, error) { 51 | if cfg.BotToken == "" { 52 | return nil, errors.New("missing bot token") 53 | } 54 | 55 | if cfg.GuildID == "" { 56 | return nil, errors.New("missing guild ID") 57 | } 58 | 59 | if cfg.ImagineQueue == nil { 60 | return nil, errors.New("missing imagine queue") 61 | } 62 | 63 | if cfg.ImagineCommand == "" { 64 | return nil, errors.New("missing imagine command") 65 | } 66 | 67 | botSession, err := discordgo.New("Bot " + cfg.BotToken) 68 | if err != nil { 69 | return nil, err 70 | } 71 | 72 | botSession.AddHandler(func(s *discordgo.Session, r *discordgo.Ready) { 73 | log.Printf("Logged in as: %v#%v", s.State.User.Username, s.State.User.Discriminator) 74 | }) 75 | err = botSession.Open() 76 | if err != nil { 77 | return nil, err 78 | } 79 | 80 | bot := &botImpl{ 81 | developmentMode: cfg.DevelopmentMode, 82 | botSession: botSession, 83 | imagineQueue: cfg.ImagineQueue, 84 | registeredCommands: make([]*discordgo.ApplicationCommand, 0), 85 | imagineCommand: cfg.ImagineCommand, 86 | removeCommands: cfg.RemoveCommands, 87 | } 88 | 89 | err = bot.addImagineCommand() 90 | if err != nil { 91 | return nil, err 92 | } 93 | 94 | err = bot.addImagineSettingsCommand() 95 | if err != nil { 96 | return nil, err 97 | } 98 | 99 | botSession.AddHandler(func(s *discordgo.Session, i *discordgo.InteractionCreate) { 100 | switch i.Type { 101 | case discordgo.InteractionApplicationCommand: 102 | switch i.ApplicationCommandData().Name { 103 | case bot.imagineCommandString(): 104 | bot.processImagineCommand(s, i) 105 | case bot.imagineSettingsCommandString(): 106 | bot.processImagineSettingsCommand(s, i) 107 | default: 108 | log.Printf("Unknown command '%v'", 
i.ApplicationCommandData().Name) 109 | } 110 | case discordgo.InteractionMessageComponent: 111 | switch customID := i.MessageComponentData().CustomID; { 112 | case customID == "imagine_reroll": 113 | bot.processImagineReroll(s, i) 114 | case strings.HasPrefix(customID, "imagine_upscale_"): 115 | interactionIndex := strings.TrimPrefix(customID, "imagine_upscale_") 116 | 117 | interactionIndexInt, intErr := strconv.Atoi(interactionIndex) 118 | if intErr != nil { 119 | log.Printf("Error parsing interaction index: %v", err) 120 | 121 | return 122 | } 123 | 124 | bot.processImagineUpscale(s, i, interactionIndexInt) 125 | case strings.HasPrefix(customID, "imagine_variation_"): 126 | interactionIndex := strings.TrimPrefix(customID, "imagine_variation_") 127 | 128 | interactionIndexInt, intErr := strconv.Atoi(interactionIndex) 129 | if intErr != nil { 130 | log.Printf("Error parsing interaction index: %v", err) 131 | 132 | return 133 | } 134 | 135 | bot.processImagineVariation(s, i, interactionIndexInt) 136 | case customID == "imagine_dimension_setting_menu": 137 | if len(i.MessageComponentData().Values) == 0 { 138 | log.Printf("No values for imagine dimension setting menu") 139 | 140 | return 141 | } 142 | 143 | sizes := strings.Split(i.MessageComponentData().Values[0], "_") 144 | 145 | width := sizes[0] 146 | height := sizes[1] 147 | 148 | widthInt, intErr := strconv.Atoi(width) 149 | if intErr != nil { 150 | log.Printf("Error parsing width: %v", err) 151 | 152 | return 153 | } 154 | 155 | heightInt, intErr := strconv.Atoi(height) 156 | if intErr != nil { 157 | log.Printf("Error parsing height: %v", err) 158 | 159 | return 160 | } 161 | 162 | bot.processImagineDimensionSetting(s, i, widthInt, heightInt) 163 | case customID == "imagine_batch_count_setting_menu": 164 | if len(i.MessageComponentData().Values) == 0 { 165 | log.Printf("No values for imagine batch count setting menu") 166 | 167 | return 168 | } 169 | 170 | batchCount := i.MessageComponentData().Values[0] 171 | 
172 | batchCountInt, intErr := strconv.Atoi(batchCount) 173 | if intErr != nil { 174 | log.Printf("Error parsing batch count: %v", err) 175 | 176 | return 177 | } 178 | 179 | var batchSizeInt int 180 | 181 | // calculate the corresponding batch size 182 | switch batchCountInt { 183 | case 1: 184 | batchSizeInt = 4 185 | case 2: 186 | batchSizeInt = 2 187 | case 4: 188 | batchSizeInt = 1 189 | default: 190 | log.Printf("Unknown batch count: %v", batchCountInt) 191 | 192 | return 193 | } 194 | 195 | bot.processImagineBatchSetting(s, i, batchCountInt, batchSizeInt) 196 | case customID == "imagine_batch_size_setting_menu": 197 | if len(i.MessageComponentData().Values) == 0 { 198 | log.Printf("No values for imagine batch count setting menu") 199 | 200 | return 201 | } 202 | 203 | batchSize := i.MessageComponentData().Values[0] 204 | 205 | batchSizeInt, intErr := strconv.Atoi(batchSize) 206 | if intErr != nil { 207 | log.Printf("Error parsing batch count: %v", err) 208 | 209 | return 210 | } 211 | 212 | var batchCountInt int 213 | 214 | // calculate the corresponding batch count 215 | switch batchSizeInt { 216 | case 1: 217 | batchCountInt = 4 218 | case 2: 219 | batchCountInt = 2 220 | case 4: 221 | batchCountInt = 1 222 | default: 223 | log.Printf("Unknown batch size: %v", batchSizeInt) 224 | 225 | return 226 | } 227 | 228 | bot.processImagineBatchSetting(s, i, batchCountInt, batchSizeInt) 229 | default: 230 | log.Printf("Unknown message component '%v'", i.MessageComponentData().CustomID) 231 | } 232 | } 233 | }) 234 | 235 | return bot, nil 236 | } 237 | 238 | func (b *botImpl) Start() { 239 | b.imagineQueue.StartPolling(b.botSession) 240 | 241 | err := b.teardown() 242 | if err != nil { 243 | log.Printf("Error tearing down bot: %v", err) 244 | } 245 | } 246 | 247 | func (b *botImpl) teardown() error { 248 | // Delete all commands added by the bot 249 | if b.removeCommands { 250 | log.Printf("Removing all commands added by bot...") 251 | 252 | for _, v := range 
b.registeredCommands { 253 | log.Printf("Removing command '%v'...", v.Name) 254 | 255 | err := b.botSession.ApplicationCommandDelete(b.botSession.State.User.ID, b.guildID, v.ID) 256 | if err != nil { 257 | log.Panicf("Cannot delete '%v' command: %v", v.Name, err) 258 | } 259 | } 260 | } 261 | 262 | return b.botSession.Close() 263 | } 264 | 265 | func (b *botImpl) addImagineCommand() error { 266 | log.Printf("Adding command '%s'...", b.imagineCommandString()) 267 | 268 | cmd, err := b.botSession.ApplicationCommandCreate(b.botSession.State.User.ID, b.guildID, &discordgo.ApplicationCommand{ 269 | Name: b.imagineCommandString(), 270 | Description: "Ask the bot to imagine something", 271 | Options: []*discordgo.ApplicationCommandOption{ 272 | { 273 | Type: discordgo.ApplicationCommandOptionString, 274 | Name: "prompt", 275 | Description: "The text prompt to imagine", 276 | Required: true, 277 | }, 278 | }, 279 | }) 280 | if err != nil { 281 | log.Printf("Error creating '%s' command: %v", b.imagineCommandString(), err) 282 | 283 | return err 284 | } 285 | 286 | b.registeredCommands = append(b.registeredCommands, cmd) 287 | 288 | return nil 289 | } 290 | 291 | func (b *botImpl) addImagineSettingsCommand() error { 292 | log.Printf("Adding command '%s'...", b.imagineSettingsCommandString()) 293 | 294 | cmd, err := b.botSession.ApplicationCommandCreate(b.botSession.State.User.ID, b.guildID, &discordgo.ApplicationCommand{ 295 | Name: b.imagineSettingsCommandString(), 296 | Description: "Change the default settings for the imagine command", 297 | }) 298 | if err != nil { 299 | log.Printf("Error creating '%s' command: %v", b.imagineSettingsCommandString(), err) 300 | 301 | return err 302 | } 303 | 304 | b.registeredCommands = append(b.registeredCommands, cmd) 305 | 306 | return nil 307 | } 308 | 309 | func (b *botImpl) processImagineReroll(s *discordgo.Session, i *discordgo.InteractionCreate) { 310 | position, queueError := b.imagineQueue.AddImagine(&imagine_queue.QueueItem{ 
311 | Type: imagine_queue.ItemTypeReroll, 312 | DiscordInteraction: i.Interaction, 313 | }) 314 | if queueError != nil { 315 | log.Printf("Error adding imagine to queue: %v\n", queueError) 316 | } 317 | 318 | err := s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{ 319 | Type: discordgo.InteractionResponseChannelMessageWithSource, 320 | Data: &discordgo.InteractionResponseData{ 321 | Content: fmt.Sprintf("I'm reimagining that for you... You are currently #%d in line.", position), 322 | }, 323 | }) 324 | if err != nil { 325 | log.Printf("Error responding to interaction: %v", err) 326 | } 327 | } 328 | 329 | func (b *botImpl) processImagineUpscale(s *discordgo.Session, i *discordgo.InteractionCreate, upscaleIndex int) { 330 | position, queueError := b.imagineQueue.AddImagine(&imagine_queue.QueueItem{ 331 | Type: imagine_queue.ItemTypeUpscale, 332 | InteractionIndex: upscaleIndex, 333 | DiscordInteraction: i.Interaction, 334 | }) 335 | if queueError != nil { 336 | log.Printf("Error adding imagine to queue: %v\n", queueError) 337 | } 338 | 339 | err := s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{ 340 | Type: discordgo.InteractionResponseChannelMessageWithSource, 341 | Data: &discordgo.InteractionResponseData{ 342 | Content: fmt.Sprintf("I'm upscaling that for you... 
You are currently #%d in line.", position), 343 | }, 344 | }) 345 | if err != nil { 346 | log.Printf("Error responding to interaction: %v", err) 347 | } 348 | } 349 | 350 | func (b *botImpl) processImagineVariation(s *discordgo.Session, i *discordgo.InteractionCreate, variationIndex int) { 351 | position, queueError := b.imagineQueue.AddImagine(&imagine_queue.QueueItem{ 352 | Type: imagine_queue.ItemTypeVariation, 353 | InteractionIndex: variationIndex, 354 | DiscordInteraction: i.Interaction, 355 | }) 356 | if queueError != nil { 357 | log.Printf("Error adding imagine to queue: %v\n", queueError) 358 | } 359 | 360 | err := s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{ 361 | Type: discordgo.InteractionResponseChannelMessageWithSource, 362 | Data: &discordgo.InteractionResponseData{ 363 | Content: fmt.Sprintf("I'm imagining more variations for you... You are currently #%d in line.", position), 364 | }, 365 | }) 366 | if err != nil { 367 | log.Printf("Error responding to interaction: %v", err) 368 | } 369 | } 370 | 371 | func (b *botImpl) processImagineCommand(s *discordgo.Session, i *discordgo.InteractionCreate) { 372 | options := i.ApplicationCommandData().Options 373 | 374 | optionMap := make(map[string]*discordgo.ApplicationCommandInteractionDataOption, len(options)) 375 | for _, opt := range options { 376 | optionMap[opt.Name] = opt 377 | } 378 | 379 | var position int 380 | var queueError error 381 | var prompt string 382 | 383 | if option, ok := optionMap["prompt"]; ok { 384 | prompt = option.StringValue() 385 | 386 | position, queueError = b.imagineQueue.AddImagine(&imagine_queue.QueueItem{ 387 | Prompt: prompt, 388 | Type: imagine_queue.ItemTypeImagine, 389 | DiscordInteraction: i.Interaction, 390 | }) 391 | if queueError != nil { 392 | log.Printf("Error adding imagine to queue: %v\n", queueError) 393 | } 394 | } 395 | 396 | err := s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{ 397 | Type: 
discordgo.InteractionResponseChannelMessageWithSource, 398 | Data: &discordgo.InteractionResponseData{ 399 | Content: fmt.Sprintf( 400 | "I'm dreaming something up for you. You are currently #%d in line.\n<@%s> asked me to imagine \"%s\".", 401 | position, 402 | i.Member.User.ID, 403 | prompt), 404 | }, 405 | }) 406 | if err != nil { 407 | log.Printf("Error responding to interaction: %v", err) 408 | } 409 | } 410 | 411 | func settingsMessageComponents(settings *entities.DefaultSettings) []discordgo.MessageComponent { 412 | minValues := 1 413 | 414 | return []discordgo.MessageComponent{ 415 | discordgo.ActionsRow{ 416 | Components: []discordgo.MessageComponent{ 417 | discordgo.SelectMenu{ 418 | CustomID: "imagine_dimension_setting_menu", 419 | MinValues: &minValues, 420 | MaxValues: 1, 421 | Options: []discordgo.SelectMenuOption{ 422 | { 423 | Label: "Size: 512x512", 424 | Value: "512_512", 425 | Default: settings.Width == 512 && settings.Height == 512, 426 | }, 427 | { 428 | Label: "Size: 768x768", 429 | Value: "768_768", 430 | Default: settings.Width == 768 && settings.Height == 768, 431 | }, 432 | }, 433 | }, 434 | }, 435 | }, 436 | discordgo.ActionsRow{ 437 | Components: []discordgo.MessageComponent{ 438 | discordgo.SelectMenu{ 439 | CustomID: "imagine_batch_count_setting_menu", 440 | MinValues: &minValues, 441 | MaxValues: 1, 442 | Options: []discordgo.SelectMenuOption{ 443 | { 444 | Label: "Batch count: 1", 445 | Value: "1", 446 | Default: settings.BatchCount == 1, 447 | }, 448 | { 449 | Label: "Batch count: 2", 450 | Value: "2", 451 | Default: settings.BatchCount == 2, 452 | }, 453 | { 454 | Label: "Batch count: 4", 455 | Value: "4", 456 | Default: settings.BatchCount == 4, 457 | }, 458 | }, 459 | }, 460 | }, 461 | }, 462 | discordgo.ActionsRow{ 463 | Components: []discordgo.MessageComponent{ 464 | discordgo.SelectMenu{ 465 | CustomID: "imagine_batch_size_setting_menu", 466 | MinValues: &minValues, 467 | MaxValues: 1, 468 | Options: 
[]discordgo.SelectMenuOption{ 469 | { 470 | Label: "Batch size: 1", 471 | Value: "1", 472 | Default: settings.BatchSize == 1, 473 | }, 474 | { 475 | Label: "Batch size: 2", 476 | Value: "2", 477 | Default: settings.BatchSize == 2, 478 | }, 479 | { 480 | Label: "Batch size: 4", 481 | Value: "4", 482 | Default: settings.BatchSize == 4, 483 | }, 484 | }, 485 | }, 486 | }, 487 | }, 488 | } 489 | } 490 | 
// processImagineSettingsCommand answers the settings slash command with a new
// message whose components come from settingsMessageComponents for the current
// bot defaults. If the defaults cannot be loaded it still answers the
// interaction, with an error message, instead of leaving it unanswered.
func (b *botImpl) processImagineSettingsCommand(s *discordgo.Session, i *discordgo.InteractionCreate) {
	botSettings, err := b.imagineQueue.GetBotDefaultSettings()
	if err != nil {
		log.Printf("error getting default settings for settings command: %v", err)

		// An interaction that is never answered is shown to the user as a
		// generic failure, so report the problem explicitly.
		respondErr := s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
			Type: discordgo.InteractionResponseChannelMessageWithSource,
			Data: &discordgo.InteractionResponseData{
				Content: "Error getting default settings...",
			},
		})
		if respondErr != nil {
			log.Printf("Error responding to interaction: %v", respondErr)
		}

		return
	}

	err = s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
		Type: discordgo.InteractionResponseChannelMessageWithSource,
		Data: &discordgo.InteractionResponseData{
			Title:      "Settings",
			Content:    "Choose defaults settings for the imagine command:",
			Components: settingsMessageComponents(botSettings),
		},
	})
	if err != nil {
		log.Printf("Error responding to interaction: %v", err)
	}
}

// processImagineDimensionSetting stores new default dimensions through the
// imagine queue and re-renders the settings message in place. Note the
// parameter order is (height, width) while UpdateDefaultDimensions takes
// (width, height).
func (b *botImpl) processImagineDimensionSetting(s *discordgo.Session, i *discordgo.InteractionCreate, height, width int) {
	botSettings, err := b.imagineQueue.UpdateDefaultDimensions(width, height)
	if err != nil {
		log.Printf("error updating default dimensions: %v", err)

		b.updateSettingsMessage(s, i, "Error updating default dimensions...", nil)

		return
	}

	b.updateSettingsMessage(s, i, "Choose defaults settings for the imagine command:", settingsMessageComponents(botSettings))
}

// processImagineBatchSetting stores a new default batch count/size through the
// imagine queue and re-renders the settings message in place.
func (b *botImpl) processImagineBatchSetting(s *discordgo.Session, i *discordgo.InteractionCreate, batchCount, batchSize int) {
	botSettings, err := b.imagineQueue.UpdateDefaultBatch(batchCount, batchSize)
	if err != nil {
		log.Printf("error updating batch settings: %v", err)

		b.updateSettingsMessage(s, i, "Error updating batch settings...", nil)

		return
	}

	b.updateSettingsMessage(s, i, "Choose defaults settings for the imagine command:", settingsMessageComponents(botSettings))
}

// updateSettingsMessage answers a component interaction by replacing the
// content and components of the message it came from (components may be nil).
// A failure to respond is logged rather than returned.
func (b *botImpl) updateSettingsMessage(s *discordgo.Session, i *discordgo.InteractionCreate, content string, components []discordgo.MessageComponent) {
	err := s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
		Type: discordgo.InteractionResponseUpdateMessage,
		Data: &discordgo.InteractionResponseData{
			Content:    content,
			Components: components,
		},
	})
	if err != nil {
		log.Printf("Error responding to interaction: %v", err)
	}
}
577 | -------------------------------------------------------------------------------- /imagine_queue/queue.go: -------------------------------------------------------------------------------- 1 | package imagine_queue 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/base64" 7 | "errors" 8 | "fmt" 9 | "log" 10 | "os" 11 | "os/signal" 12 | "regexp" 13 | "stable_diffusion_bot/composite_renderer" 14 | 
"stable_diffusion_bot/entities" 15 | "stable_diffusion_bot/repositories" 16 | "stable_diffusion_bot/repositories/default_settings" 17 | "stable_diffusion_bot/repositories/image_generations" 18 | "stable_diffusion_bot/stable_diffusion_api" 19 | "strconv" 20 | "strings" 21 | "sync" 22 | "time" 23 | 24 | "github.com/bwmarrin/discordgo" 25 | ) 26 | 27 | const ( 28 | botID = "bot" 29 | 30 | initializedWidth = 512 31 | initializedHeight = 512 32 | initializedBatchCount = 4 33 | initializedBatchSize = 1 34 | ) 35 | 36 | type queueImpl struct { 37 | botSession *discordgo.Session 38 | stableDiffusionAPI stable_diffusion_api.StableDiffusionAPI 39 | queue chan *QueueItem 40 | currentImagine *QueueItem 41 | mu sync.Mutex 42 | imageGenerationRepo image_generations.Repository 43 | compositeRenderer composite_renderer.Renderer 44 | defaultSettingsRepo default_settings.Repository 45 | botDefaultSettings *entities.DefaultSettings 46 | } 47 | 48 | type Config struct { 49 | StableDiffusionAPI stable_diffusion_api.StableDiffusionAPI 50 | ImageGenerationRepo image_generations.Repository 51 | DefaultSettingsRepo default_settings.Repository 52 | } 53 | 54 | func New(cfg Config) (Queue, error) { 55 | if cfg.StableDiffusionAPI == nil { 56 | return nil, errors.New("missing stable diffusion API") 57 | } 58 | 59 | if cfg.ImageGenerationRepo == nil { 60 | return nil, errors.New("missing image generation repository") 61 | } 62 | 63 | if cfg.DefaultSettingsRepo == nil { 64 | return nil, errors.New("missing default settings repository") 65 | } 66 | 67 | compositeRenderer, err := composite_renderer.New(composite_renderer.Config{}) 68 | if err != nil { 69 | return nil, err 70 | } 71 | 72 | return &queueImpl{ 73 | stableDiffusionAPI: cfg.StableDiffusionAPI, 74 | imageGenerationRepo: cfg.ImageGenerationRepo, 75 | queue: make(chan *QueueItem, 100), 76 | compositeRenderer: compositeRenderer, 77 | defaultSettingsRepo: cfg.DefaultSettingsRepo, 78 | }, nil 79 | } 80 | 81 | type ItemType int 82 | 83 | 
const ( 84 | ItemTypeImagine ItemType = iota 85 | ItemTypeReroll 86 | ItemTypeUpscale 87 | ItemTypeVariation 88 | ) 89 | 90 | type QueueItem struct { 91 | Prompt string 92 | Type ItemType 93 | InteractionIndex int 94 | DiscordInteraction *discordgo.Interaction 95 | } 96 | 97 | func (q *queueImpl) AddImagine(item *QueueItem) (int, error) { 98 | q.queue <- item 99 | 100 | linePosition := len(q.queue) 101 | 102 | return linePosition, nil 103 | } 104 | 105 | func (q *queueImpl) StartPolling(botSession *discordgo.Session) { 106 | q.botSession = botSession 107 | 108 | botDefaultSettings, err := q.initializeOrGetBotDefaults() 109 | if err != nil { 110 | log.Printf("Error getting/initializing bot default settings: %v", err) 111 | 112 | return 113 | } 114 | 115 | q.botDefaultSettings = botDefaultSettings 116 | 117 | log.Println("Press Ctrl+C to exit") 118 | 119 | stop := make(chan os.Signal, 1) 120 | signal.Notify(stop, os.Interrupt) 121 | 122 | stopPolling := false 123 | 124 | for { 125 | select { 126 | case <-stop: 127 | stopPolling = true 128 | case <-time.After(1 * time.Second): 129 | if q.currentImagine == nil { 130 | q.pullNextInQueue() 131 | } 132 | } 133 | 134 | if stopPolling { 135 | break 136 | } 137 | } 138 | 139 | log.Printf("Polling stopped...\n") 140 | } 141 | 142 | func (q *queueImpl) pullNextInQueue() { 143 | if len(q.queue) > 0 { 144 | element := <-q.queue 145 | 146 | q.mu.Lock() 147 | defer q.mu.Unlock() 148 | 149 | q.currentImagine = element 150 | 151 | q.processCurrentImagine() 152 | } 153 | } 154 | 155 | func (q *queueImpl) fillInBotDefaults(settings *entities.DefaultSettings) (*entities.DefaultSettings, bool) { 156 | updated := false 157 | 158 | if settings == nil { 159 | settings = &entities.DefaultSettings{ 160 | MemberID: botID, 161 | } 162 | } 163 | 164 | if settings.Width == 0 { 165 | settings.Width = initializedWidth 166 | updated = true 167 | } 168 | 169 | if settings.Height == 0 { 170 | settings.Height = initializedHeight 171 | updated = true 
172 | } 173 | 174 | if settings.BatchCount == 0 { 175 | settings.BatchCount = initializedBatchCount 176 | updated = true 177 | } 178 | 179 | if settings.BatchSize == 0 { 180 | settings.BatchSize = initializedBatchSize 181 | updated = true 182 | } 183 | 184 | return settings, updated 185 | } 186 | 187 | func (q *queueImpl) initializeOrGetBotDefaults() (*entities.DefaultSettings, error) { 188 | botDefaultSettings, err := q.GetBotDefaultSettings() 189 | if err != nil && !errors.Is(err, &repositories.NotFoundError{}) { 190 | return nil, err 191 | } 192 | 193 | botDefaultSettings, updated := q.fillInBotDefaults(botDefaultSettings) 194 | if updated { 195 | botDefaultSettings, err = q.defaultSettingsRepo.Upsert(context.Background(), botDefaultSettings) 196 | if err != nil { 197 | return nil, err 198 | } 199 | 200 | log.Printf("Initialized bot default settings: %+v\n", botDefaultSettings) 201 | } else { 202 | log.Printf("Retrieved bot default settings: %+v\n", botDefaultSettings) 203 | } 204 | 205 | return botDefaultSettings, nil 206 | } 207 | 208 | func (q *queueImpl) GetBotDefaultSettings() (*entities.DefaultSettings, error) { 209 | if q.botDefaultSettings != nil { 210 | return q.botDefaultSettings, nil 211 | } 212 | 213 | defaultSettings, err := q.defaultSettingsRepo.GetByMemberID(context.Background(), botID) 214 | if err != nil { 215 | return nil, err 216 | } 217 | 218 | q.botDefaultSettings = defaultSettings 219 | 220 | return defaultSettings, nil 221 | } 222 | 223 | func (q *queueImpl) defaultWidth() (int, error) { 224 | defaultSettings, err := q.GetBotDefaultSettings() 225 | if err != nil { 226 | return 0, err 227 | } 228 | 229 | return defaultSettings.Width, nil 230 | } 231 | 232 | func (q *queueImpl) defaultHeight() (int, error) { 233 | defaultSettings, err := q.GetBotDefaultSettings() 234 | if err != nil { 235 | return 0, err 236 | } 237 | 238 | return defaultSettings.Height, nil 239 | } 240 | 241 | func (q *queueImpl) defaultBatchCount() (int, error) { 242 | 
defaultSettings, err := q.GetBotDefaultSettings() 243 | if err != nil { 244 | return 0, err 245 | } 246 | 247 | return defaultSettings.BatchCount, nil 248 | } 249 | 250 | func (q *queueImpl) defaultBatchSize() (int, error) { 251 | defaultSettings, err := q.GetBotDefaultSettings() 252 | if err != nil { 253 | return 0, err 254 | } 255 | 256 | return defaultSettings.BatchSize, nil 257 | } 258 | 259 | func (q *queueImpl) UpdateDefaultDimensions(width, height int) (*entities.DefaultSettings, error) { 260 | defaultSettings, err := q.GetBotDefaultSettings() 261 | if err != nil { 262 | return nil, err 263 | } 264 | 265 | defaultSettings.Width = width 266 | defaultSettings.Height = height 267 | 268 | newDefaultSettings, err := q.defaultSettingsRepo.Upsert(context.Background(), defaultSettings) 269 | if err != nil { 270 | return nil, err 271 | } 272 | 273 | q.botDefaultSettings = newDefaultSettings 274 | 275 | log.Printf("Updated default dimensions to: %dx%d\n", width, height) 276 | 277 | return newDefaultSettings, nil 278 | } 279 | 280 | func (q *queueImpl) UpdateDefaultBatch(batchCount, batchSize int) (*entities.DefaultSettings, error) { 281 | defaultSettings, err := q.GetBotDefaultSettings() 282 | if err != nil { 283 | return nil, err 284 | } 285 | 286 | defaultSettings.BatchCount = batchCount 287 | defaultSettings.BatchSize = batchSize 288 | 289 | newDefaultSettings, err := q.defaultSettingsRepo.Upsert(context.Background(), defaultSettings) 290 | if err != nil { 291 | return nil, err 292 | } 293 | 294 | q.botDefaultSettings = newDefaultSettings 295 | 296 | log.Printf("Updated default batch count/size to: %d/%d\n", batchCount, batchSize) 297 | 298 | return newDefaultSettings, nil 299 | } 300 | 301 | type dimensionsResult struct { 302 | SanitizedPrompt string 303 | Width int 304 | Height int 305 | } 306 | 307 | const ( 308 | emdash = '\u2014' 309 | hyphen = '\u002D' 310 | ) 311 | 312 | func fixEmDash(prompt string) string { 313 | return strings.ReplaceAll(prompt, 
string(emdash), string(hyphen)+string(hyphen)) 314 | } 315 | 316 | var arRegex = regexp.MustCompile(`\s?--ar ([\d]*):([\d]*)\s?`) 317 | 318 | func extractDimensionsFromPrompt(prompt string, width, height int) (*dimensionsResult, error) { 319 | // Sanitize em dashes. Some phones will autocorrect to em dashes 320 | prompt = fixEmDash(prompt) 321 | 322 | arMatches := arRegex.FindStringSubmatch(prompt) 323 | 324 | if len(arMatches) == 3 { 325 | log.Printf("Aspect ratio overwrite: %#v", arMatches) 326 | 327 | prompt = arRegex.ReplaceAllString(prompt, "") 328 | 329 | firstDimension, err := strconv.Atoi(arMatches[1]) 330 | if err != nil { 331 | return nil, err 332 | } 333 | 334 | secondDimension, err := strconv.Atoi(arMatches[2]) 335 | if err != nil { 336 | return nil, err 337 | } 338 | 339 | if firstDimension > secondDimension { 340 | scaledWidth := float64(height) * (float64(firstDimension) / float64(secondDimension)) 341 | 342 | // Round up to the nearest 8 343 | width = (int(scaledWidth) + 7) & (-8) 344 | } else if secondDimension > firstDimension { 345 | scaledHeight := float64(width) * (float64(secondDimension) / float64(firstDimension)) 346 | 347 | // Round up to the nearest 8 348 | height = (int(scaledHeight) + 7) & (-8) 349 | } 350 | 351 | log.Printf("New dimensions: width: %v, height: %v", width, height) 352 | } 353 | 354 | return &dimensionsResult{ 355 | SanitizedPrompt: prompt, 356 | Width: width, 357 | Height: height, 358 | }, nil 359 | } 360 | 361 | func (q *queueImpl) processCurrentImagine() { 362 | go func() { 363 | defer func() { 364 | q.mu.Lock() 365 | defer q.mu.Unlock() 366 | 367 | q.currentImagine = nil 368 | }() 369 | 370 | if q.currentImagine.Type == ItemTypeUpscale { 371 | q.processUpscaleImagine(q.currentImagine) 372 | 373 | return 374 | } 375 | 376 | defaultWidth, err := q.defaultWidth() 377 | if err != nil { 378 | log.Printf("Error getting default width: %v", err) 379 | 380 | return 381 | } 382 | 383 | defaultHeight, err := q.defaultHeight() 384 
| if err != nil { 385 | log.Printf("Error getting default height: %v", err) 386 | 387 | return 388 | } 389 | 390 | promptRes, err := extractDimensionsFromPrompt(q.currentImagine.Prompt, defaultWidth, defaultHeight) 391 | if err != nil { 392 | log.Printf("Error extracting dimensions from prompt: %v", err) 393 | 394 | return 395 | } 396 | 397 | enableHR := false 398 | hiresWidth := 0 399 | hiresHeight := 0 400 | 401 | if promptRes.Width > defaultWidth || promptRes.Height > defaultHeight { 402 | enableHR = true 403 | hiresWidth = promptRes.Width 404 | hiresHeight = promptRes.Height 405 | } 406 | 407 | // new generation with defaults 408 | newGeneration := &entities.ImageGeneration{ 409 | Prompt: promptRes.SanitizedPrompt, 410 | NegativePrompt: "ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, " + 411 | "mutation, mutated, extra limbs, extra legs, extra arms, disfigured, deformed, cross-eye, " + 412 | "body out of frame, blurry, bad art, bad anatomy, blurred, text, watermark, grainy", 413 | Width: defaultWidth, 414 | Height: defaultHeight, 415 | RestoreFaces: true, 416 | EnableHR: enableHR, 417 | HiresWidth: hiresWidth, 418 | HiresHeight: hiresHeight, 419 | DenoisingStrength: 0.7, 420 | Seed: -1, 421 | Subseed: -1, 422 | SubseedStrength: 0, 423 | SamplerName: "Euler a", 424 | CfgScale: 9, 425 | Steps: 20, 426 | Processed: false, 427 | } 428 | 429 | if q.currentImagine.Type == ItemTypeReroll || q.currentImagine.Type == ItemTypeVariation { 430 | foundGeneration, err := q.getPreviousGeneration(q.currentImagine, q.currentImagine.InteractionIndex) 431 | if err != nil { 432 | log.Printf("Error getting prompt for reroll: %v", err) 433 | 434 | return 435 | } 436 | 437 | // if we are rerolling, or generating variations, we simply replace some defaults 438 | newGeneration = foundGeneration 439 | 440 | // for variations, we need random subseeds 441 | newGeneration.Subseed = -1 442 | 443 | // for variations, the subseed strength determines how 
much variation we get 444 | if q.currentImagine.Type == ItemTypeVariation { 445 | newGeneration.SubseedStrength = 0.15 446 | } 447 | } 448 | 449 | err = q.processImagineGrid(newGeneration, q.currentImagine) 450 | if err != nil { 451 | log.Printf("Error processing imagine grid: %v", err) 452 | 453 | return 454 | } 455 | }() 456 | } 457 | 458 | func (q *queueImpl) getPreviousGeneration(imagine *QueueItem, sortOrder int) (*entities.ImageGeneration, error) { 459 | interactionID := imagine.DiscordInteraction.ID 460 | messageID := "" 461 | 462 | if imagine.DiscordInteraction.Message != nil { 463 | messageID = imagine.DiscordInteraction.Message.ID 464 | } 465 | 466 | log.Printf("Reimagining interaction: %v, Message: %v", interactionID, messageID) 467 | 468 | generation, err := q.imageGenerationRepo.GetByMessageAndSort(context.Background(), messageID, sortOrder) 469 | if err != nil { 470 | log.Printf("Error getting image generation: %v", err) 471 | 472 | return nil, err 473 | } 474 | 475 | log.Printf("Found generation: %v", generation) 476 | 477 | return generation, nil 478 | } 479 | 480 | func imagineMessageContent(generation *entities.ImageGeneration, user *discordgo.User, progress float64) string { 481 | if progress >= 0 && progress < 1 { 482 | return fmt.Sprintf("<@%s> asked me to imagine \"%s\". Currently dreaming it up for them. 
Progress: %.0f%%", 483 | user.ID, generation.Prompt, progress*100) 484 | } else { 485 | return fmt.Sprintf("<@%s> asked me to imagine \"%s\", here is what I imagined for them.", 486 | user.ID, 487 | generation.Prompt, 488 | ) 489 | } 490 | } 491 | 492 | func (q *queueImpl) processImagineGrid(newGeneration *entities.ImageGeneration, imagine *QueueItem) error { 493 | log.Printf("Processing imagine #%s: %v\n", imagine.DiscordInteraction.ID, newGeneration.Prompt) 494 | 495 | newContent := imagineMessageContent(newGeneration, imagine.DiscordInteraction.Member.User, 0) 496 | 497 | message, err := q.botSession.InteractionResponseEdit(imagine.DiscordInteraction, &discordgo.WebhookEdit{ 498 | Content: &newContent, 499 | }) 500 | if err != nil { 501 | log.Printf("Error editing interaction: %v", err) 502 | } 503 | 504 | defaultBatchCount, err := q.defaultBatchCount() 505 | if err != nil { 506 | log.Printf("Error getting default batch count: %v", err) 507 | 508 | return err 509 | } 510 | 511 | defaultBatchSize, err := q.defaultBatchSize() 512 | if err != nil { 513 | log.Printf("Error getting default batch size: %v", err) 514 | 515 | return err 516 | } 517 | 518 | newGeneration.InteractionID = imagine.DiscordInteraction.ID 519 | newGeneration.MessageID = message.ID 520 | newGeneration.MemberID = imagine.DiscordInteraction.Member.User.ID 521 | newGeneration.SortOrder = 0 522 | newGeneration.BatchCount = defaultBatchCount 523 | newGeneration.BatchSize = defaultBatchSize 524 | newGeneration.Processed = true 525 | 526 | _, err = q.imageGenerationRepo.Create(context.Background(), newGeneration) 527 | if err != nil { 528 | log.Printf("Error creating image generation record: %v\n", err) 529 | } 530 | 531 | generationDone := make(chan bool) 532 | 533 | go func() { 534 | for { 535 | select { 536 | case <-generationDone: 537 | return 538 | case <-time.After(1 * time.Second): 539 | progress, progressErr := q.stableDiffusionAPI.GetCurrentProgress() 540 | if progressErr != nil { 541 | 
log.Printf("Error getting current progress: %v", progressErr) 542 | 543 | return 544 | } 545 | 546 | if progress.Progress == 0 { 547 | continue 548 | } 549 | 550 | progressContent := imagineMessageContent(newGeneration, imagine.DiscordInteraction.Member.User, progress.Progress) 551 | 552 | _, progressErr = q.botSession.InteractionResponseEdit(imagine.DiscordInteraction, &discordgo.WebhookEdit{ 553 | Content: &progressContent, 554 | }) 555 | if progressErr != nil { 556 | log.Printf("Error editing interaction: %v", err) 557 | } 558 | } 559 | } 560 | }() 561 | 562 | resp, err := q.stableDiffusionAPI.TextToImage(&stable_diffusion_api.TextToImageRequest{ 563 | Prompt: newGeneration.Prompt, 564 | NegativePrompt: newGeneration.NegativePrompt, 565 | Width: newGeneration.Width, 566 | Height: newGeneration.Height, 567 | RestoreFaces: newGeneration.RestoreFaces, 568 | EnableHR: newGeneration.EnableHR, 569 | HRResizeX: newGeneration.HiresWidth, 570 | HRResizeY: newGeneration.HiresHeight, 571 | DenoisingStrength: newGeneration.DenoisingStrength, 572 | BatchSize: newGeneration.BatchSize, 573 | Seed: newGeneration.Seed, 574 | Subseed: newGeneration.Subseed, 575 | SubseedStrength: newGeneration.SubseedStrength, 576 | SamplerName: newGeneration.SamplerName, 577 | CfgScale: newGeneration.CfgScale, 578 | Steps: newGeneration.Steps, 579 | NIter: newGeneration.BatchCount, 580 | }) 581 | if err != nil { 582 | log.Printf("Error processing image: %v\n", err) 583 | 584 | errorContent := "I'm sorry, but I had a problem imagining your image." 
585 | 586 | _, err = q.botSession.InteractionResponseEdit(imagine.DiscordInteraction, &discordgo.WebhookEdit{ 587 | Content: &errorContent, 588 | }) 589 | 590 | return err 591 | } 592 | 593 | generationDone <- true 594 | 595 | finishedContent := imagineMessageContent(newGeneration, imagine.DiscordInteraction.Member.User, 1) 596 | 597 | log.Printf("Seeds: %v Subseeds:%v", resp.Seeds, resp.Subseeds) 598 | 599 | imageBufs := make([]*bytes.Buffer, len(resp.Images)) 600 | 601 | for idx, image := range resp.Images { 602 | decodedImage, decodeErr := base64.StdEncoding.DecodeString(image) 603 | if decodeErr != nil { 604 | log.Printf("Error decoding image: %v\n", decodeErr) 605 | } 606 | 607 | imageBuf := bytes.NewBuffer(decodedImage) 608 | 609 | imageBufs[idx] = imageBuf 610 | } 611 | 612 | for idx := range resp.Seeds { 613 | subGeneration := &entities.ImageGeneration{ 614 | InteractionID: newGeneration.InteractionID, 615 | MessageID: newGeneration.MessageID, 616 | MemberID: newGeneration.MemberID, 617 | SortOrder: idx + 1, 618 | Prompt: newGeneration.Prompt, 619 | NegativePrompt: newGeneration.NegativePrompt, 620 | Width: newGeneration.Width, 621 | Height: newGeneration.Height, 622 | RestoreFaces: newGeneration.RestoreFaces, 623 | EnableHR: newGeneration.EnableHR, 624 | HiresWidth: newGeneration.HiresWidth, 625 | HiresHeight: newGeneration.HiresHeight, 626 | DenoisingStrength: newGeneration.DenoisingStrength, 627 | BatchCount: newGeneration.BatchCount, 628 | BatchSize: newGeneration.BatchSize, 629 | Seed: resp.Seeds[idx], 630 | Subseed: resp.Subseeds[idx], 631 | SubseedStrength: newGeneration.SubseedStrength, 632 | SamplerName: newGeneration.SamplerName, 633 | CfgScale: newGeneration.CfgScale, 634 | Steps: newGeneration.Steps, 635 | Processed: true, 636 | } 637 | 638 | _, createErr := q.imageGenerationRepo.Create(context.Background(), subGeneration) 639 | if createErr != nil { 640 | log.Printf("Error creating image generation record: %v\n", createErr) 641 | } 642 | } 643 
| 644 | compositeImage, err := q.compositeRenderer.TileImages(imageBufs) 645 | if err != nil { 646 | log.Printf("Error tiling images: %v\n", err) 647 | 648 | return err 649 | } 650 | 651 | _, err = q.botSession.InteractionResponseEdit(imagine.DiscordInteraction, &discordgo.WebhookEdit{ 652 | Content: &finishedContent, 653 | Files: []*discordgo.File{ 654 | { 655 | ContentType: "image/png", 656 | Name: "imagine.png", 657 | Reader: compositeImage, 658 | }, 659 | }, 660 | Components: &[]discordgo.MessageComponent{ 661 | discordgo.ActionsRow{ 662 | Components: []discordgo.MessageComponent{ 663 | discordgo.Button{ 664 | // Label is what the user will see on the button. 665 | Label: "Re-roll", 666 | // Style provides coloring of the button. There are not so many styles tho. 667 | Style: discordgo.PrimaryButton, 668 | // Disabled allows bot to disable some buttons for users. 669 | Disabled: false, 670 | // CustomID is a thing telling Discord which data to send when this button will be pressed. 671 | CustomID: "imagine_reroll", 672 | Emoji: discordgo.ComponentEmoji{ 673 | Name: "🎲", 674 | }, 675 | }, 676 | discordgo.Button{ 677 | // Label is what the user will see on the button. 678 | Label: "V1", 679 | // Style provides coloring of the button. There are not so many styles tho. 680 | Style: discordgo.SecondaryButton, 681 | // Disabled allows bot to disable some buttons for users. 682 | Disabled: false, 683 | // CustomID is a thing telling Discord which data to send when this button will be pressed. 684 | CustomID: "imagine_variation_1", 685 | Emoji: discordgo.ComponentEmoji{ 686 | Name: "♻️", 687 | }, 688 | }, 689 | discordgo.Button{ 690 | // Label is what the user will see on the button. 691 | Label: "V2", 692 | // Style provides coloring of the button. There are not so many styles tho. 693 | Style: discordgo.SecondaryButton, 694 | // Disabled allows bot to disable some buttons for users. 
695 | Disabled: false, 696 | // CustomID is a thing telling Discord which data to send when this button will be pressed. 697 | CustomID: "imagine_variation_2", 698 | Emoji: discordgo.ComponentEmoji{ 699 | Name: "♻️", 700 | }, 701 | }, 702 | discordgo.Button{ 703 | // Label is what the user will see on the button. 704 | Label: "V3", 705 | // Style provides coloring of the button. There are not so many styles tho. 706 | Style: discordgo.SecondaryButton, 707 | // Disabled allows bot to disable some buttons for users. 708 | Disabled: false, 709 | // CustomID is a thing telling Discord which data to send when this button will be pressed. 710 | CustomID: "imagine_variation_3", 711 | Emoji: discordgo.ComponentEmoji{ 712 | Name: "♻️", 713 | }, 714 | }, 715 | discordgo.Button{ 716 | // Label is what the user will see on the button. 717 | Label: "V4", 718 | // Style provides coloring of the button. There are not so many styles tho. 719 | Style: discordgo.SecondaryButton, 720 | // Disabled allows bot to disable some buttons for users. 721 | Disabled: false, 722 | // CustomID is a thing telling Discord which data to send when this button will be pressed. 723 | CustomID: "imagine_variation_4", 724 | Emoji: discordgo.ComponentEmoji{ 725 | Name: "♻️", 726 | }, 727 | }, 728 | }, 729 | }, 730 | discordgo.ActionsRow{ 731 | Components: []discordgo.MessageComponent{ 732 | discordgo.Button{ 733 | // Label is what the user will see on the button. 734 | Label: "U1", 735 | // Style provides coloring of the button. There are not so many styles tho. 736 | Style: discordgo.SecondaryButton, 737 | // Disabled allows bot to disable some buttons for users. 738 | Disabled: false, 739 | // CustomID is a thing telling Discord which data to send when this button will be pressed. 740 | CustomID: "imagine_upscale_1", 741 | Emoji: discordgo.ComponentEmoji{ 742 | Name: "⬆️", 743 | }, 744 | }, 745 | discordgo.Button{ 746 | // Label is what the user will see on the button. 
747 | Label: "U2", 748 | // Style provides coloring of the button. There are not so many styles tho. 749 | Style: discordgo.SecondaryButton, 750 | // Disabled allows bot to disable some buttons for users. 751 | Disabled: false, 752 | // CustomID is a thing telling Discord which data to send when this button will be pressed. 753 | CustomID: "imagine_upscale_2", 754 | Emoji: discordgo.ComponentEmoji{ 755 | Name: "⬆️", 756 | }, 757 | }, 758 | discordgo.Button{ 759 | // Label is what the user will see on the button. 760 | Label: "U3", 761 | // Style provides coloring of the button. There are not so many styles tho. 762 | Style: discordgo.SecondaryButton, 763 | // Disabled allows bot to disable some buttons for users. 764 | Disabled: false, 765 | // CustomID is a thing telling Discord which data to send when this button will be pressed. 766 | CustomID: "imagine_upscale_3", 767 | Emoji: discordgo.ComponentEmoji{ 768 | Name: "⬆️", 769 | }, 770 | }, 771 | discordgo.Button{ 772 | // Label is what the user will see on the button. 773 | Label: "U4", 774 | // Style provides coloring of the button. There are not so many styles tho. 775 | Style: discordgo.SecondaryButton, 776 | // Disabled allows bot to disable some buttons for users. 777 | Disabled: false, 778 | // CustomID is a thing telling Discord which data to send when this button will be pressed. 779 | CustomID: "imagine_upscale_4", 780 | Emoji: discordgo.ComponentEmoji{ 781 | Name: "⬆️", 782 | }, 783 | }, 784 | }, 785 | }, 786 | }, 787 | }) 788 | if err != nil { 789 | log.Printf("Error editing interaction: %v\n", err) 790 | 791 | return err 792 | } 793 | 794 | return nil 795 | } 796 | 797 | func upscaleMessageContent(user *discordgo.User, fetchProgress, upscaleProgress float64) string { 798 | if fetchProgress >= 0 && fetchProgress <= 1 && upscaleProgress < 1 { 799 | if upscaleProgress == 0 { 800 | return fmt.Sprintf("Currently upscaling the image for you... 
Fetch progress: %.0f%%", fetchProgress*100) 801 | } else { 802 | return fmt.Sprintf("Currently upscaling the image for you... Fetch progress: %.0f%% Upscale progress: %.0f%%", 803 | fetchProgress*100, upscaleProgress*100) 804 | } 805 | } else { 806 | return fmt.Sprintf("<@%s> asked me to upscale their image. Here's the result:", 807 | user.ID) 808 | } 809 | } 810 | 811 | func (q *queueImpl) processUpscaleImagine(imagine *QueueItem) { 812 | interactionID := imagine.DiscordInteraction.ID 813 | messageID := "" 814 | 815 | if imagine.DiscordInteraction.Message != nil { 816 | messageID = imagine.DiscordInteraction.Message.ID 817 | } 818 | 819 | log.Printf("Upscaling image: %v, Message: %v, Upscale Index: %d", 820 | interactionID, messageID, imagine.InteractionIndex) 821 | 822 | generation, err := q.imageGenerationRepo.GetByMessageAndSort(context.Background(), messageID, imagine.InteractionIndex) 823 | if err != nil { 824 | log.Printf("Error getting image generation: %v", err) 825 | 826 | return 827 | } 828 | 829 | log.Printf("Found generation: %v", generation) 830 | 831 | newContent := upscaleMessageContent(imagine.DiscordInteraction.Member.User, 0, 0) 832 | 833 | _, err = q.botSession.InteractionResponseEdit(imagine.DiscordInteraction, &discordgo.WebhookEdit{ 834 | Content: &newContent, 835 | }) 836 | if err != nil { 837 | log.Printf("Error editing interaction: %v", err) 838 | } 839 | 840 | generationDone := make(chan bool) 841 | 842 | go func() { 843 | lastProgress := float64(0) 844 | fetchProgress := float64(0) 845 | upscaleProgress := float64(0) 846 | 847 | for { 848 | select { 849 | case <-generationDone: 850 | return 851 | case <-time.After(1 * time.Second): 852 | progress, progressErr := q.stableDiffusionAPI.GetCurrentProgress() 853 | if progressErr != nil { 854 | log.Printf("Error getting current progress: %v", progressErr) 855 | 856 | return 857 | } 858 | 859 | if progress.Progress == 0 { 860 | continue 861 | } 862 | 863 | if progress.Progress < lastProgress || 
upscaleProgress > 0 { 864 | upscaleProgress = progress.Progress 865 | fetchProgress = 1 866 | } else { 867 | fetchProgress = progress.Progress 868 | } 869 | 870 | lastProgress = progress.Progress 871 | 872 | progressContent := upscaleMessageContent(imagine.DiscordInteraction.Member.User, fetchProgress, upscaleProgress) 873 | 874 | _, progressErr = q.botSession.InteractionResponseEdit(imagine.DiscordInteraction, &discordgo.WebhookEdit{ 875 | Content: &progressContent, 876 | }) 877 | if progressErr != nil { 878 | log.Printf("Error editing interaction: %v", err) 879 | } 880 | } 881 | } 882 | }() 883 | 884 | resp, err := q.stableDiffusionAPI.UpscaleImage(&stable_diffusion_api.UpscaleRequest{ 885 | ResizeMode: 0, 886 | UpscalingResize: 2, 887 | Upscaler1: "ESRGAN_4x", 888 | TextToImageRequest: &stable_diffusion_api.TextToImageRequest{ 889 | Prompt: generation.Prompt, 890 | NegativePrompt: generation.NegativePrompt, 891 | Width: generation.Width, 892 | Height: generation.Height, 893 | RestoreFaces: generation.RestoreFaces, 894 | EnableHR: generation.EnableHR, 895 | HRResizeX: generation.HiresWidth, 896 | HRResizeY: generation.HiresHeight, 897 | DenoisingStrength: generation.DenoisingStrength, 898 | BatchSize: 1, 899 | Seed: generation.Seed, 900 | Subseed: generation.Subseed, 901 | SubseedStrength: generation.SubseedStrength, 902 | SamplerName: generation.SamplerName, 903 | CfgScale: generation.CfgScale, 904 | Steps: generation.Steps, 905 | NIter: 1, 906 | }, 907 | }) 908 | if err != nil { 909 | log.Printf("Error processing image upscale: %v\n", err) 910 | 911 | errorContent := "I'm sorry, but I had a problem upscaling your image." 
912 | 913 | _, err = q.botSession.InteractionResponseEdit(imagine.DiscordInteraction, &discordgo.WebhookEdit{ 914 | Content: &errorContent, 915 | }) 916 | 917 | return 918 | } 919 | 920 | generationDone <- true 921 | 922 | decodedImage, decodeErr := base64.StdEncoding.DecodeString(resp.Image) 923 | if decodeErr != nil { 924 | log.Printf("Error decoding image: %v\n", decodeErr) 925 | 926 | return 927 | } 928 | 929 | imageBuf := bytes.NewBuffer(decodedImage) 930 | 931 | log.Printf("Successfully upscaled image: %v, Message: %v, Upscale Index: %d", 932 | interactionID, messageID, imagine.InteractionIndex) 933 | 934 | finishedContent := fmt.Sprintf("<@%s> asked me to upscale their image. Here's the result:", 935 | imagine.DiscordInteraction.Member.User.ID) 936 | 937 | _, err = q.botSession.InteractionResponseEdit(imagine.DiscordInteraction, &discordgo.WebhookEdit{ 938 | Content: &finishedContent, 939 | Files: []*discordgo.File{ 940 | { 941 | ContentType: "image/png", 942 | Name: "imagine.png", 943 | Reader: imageBuf, 944 | }, 945 | }, 946 | }) 947 | if err != nil { 948 | log.Printf("Error editing interaction: %v\n", err) 949 | 950 | return 951 | } 952 | } 953 | --------------------------------------------------------------------------------