├── .gitignore ├── API.md ├── Makefile ├── README.md ├── cmd └── civitai-downloader │ ├── cmd │ ├── clean.go │ ├── cmd_download_api.go │ ├── cmd_download_processing.go │ ├── cmd_download_setup.go │ ├── cmd_download_types.go │ ├── cmd_download_worker.go │ ├── cmd_images_run.go │ ├── cmd_images_setup.go │ ├── cmd_images_worker.go │ ├── cmd_search_images.go │ ├── cmd_search_models.go │ ├── db.go │ ├── download.go │ ├── images.go │ ├── root.go │ ├── search.go │ ├── search_logic.go │ └── torrent.go │ ├── main.go │ └── main_integration_test.go ├── config.toml.example ├── go.mod ├── go.sum ├── index └── index.go └── internal ├── api ├── client.go └── logging_transport.go ├── config └── config.go ├── database └── bitcask.go ├── downloader └── downloader.go ├── helpers ├── helpers.go └── helpers_test.go └── models └── models.go /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries 2 | /civitai-downloader 3 | 4 | # Log files 5 | *.log 6 | api.log 7 | 8 | # OS-specific files 9 | .DS_Store 10 | 11 | # IDE/Editor directories 12 | .vscode/ 13 | .idea/ 14 | 15 | # Database files (Bitcask might create related files) 16 | /civitai.db 17 | /test.db 18 | 19 | # Downloaded models and metadata 20 | downloads/ 21 | 22 | config.toml 23 | 24 | /release 25 | -------------------------------------------------------------------------------- /API.md: -------------------------------------------------------------------------------- 1 | Introduction 2 | 3 | This article describes how to use the Civitai REST API. We are going to be describing the HTTP method, path, and parameters for every operation. The API will return the response status code, response headers, and a response body. 
4 | 5 | This is still in active development and will be updated once more endpoints are made available for the public 6 | 7 | Civitai API v1 8 | 9 | Authorization 10 | 11 | Creators 12 | 13 | GET /api/v1/creators 14 | 15 | Images 16 | 17 | GET /api/v1/images 18 | 19 | Models 20 | 21 | GET /api/v1/models 22 | 23 | Model 24 | 25 | GET /api/v1/models/:modelId 26 | 27 | Model Version 28 | 29 | GET /api/v1/model-versions/:modelVersionId 30 | GET /api/v1/model-versions/by-hash/:hash 31 | 32 | Tags 33 | 34 | GET /api/v1/tags 35 | 36 | Authorization 37 | 38 | To make authorized requests as a user you must use an API Key. You can generate an API Key from your User Account Settings. 39 | 40 | Once you have an API Key you can authenticate with either an Authorization Header or Query String. 41 | 42 | Creators can require that people be logged in to download their resources. That is an option we provide but not something we require – it's entirely up to the resource owner. 43 | 44 | Please see the Guide to Downloading via API for more details and open an issue if you are still having trouble downloading. 45 | Authorization Header 46 | 47 | You can pass the API token as a Bearer token using the Authorization header: 48 | 49 | GET https://civitai.com/api/v1/models 50 | Content-Type: application/json 51 | Authorization: Bearer {api_key} 52 | 53 | Query String 54 | 55 | You can pass the API token as a query parameter using the ?token= parameter: 56 | 57 | GET https://civitai.com/api/v1/models?token={api_key} 58 | Content-Type: application/json 59 | 60 | This method may be easier in some notebooks and scripts. 61 | GET /api/v1/creators 62 | Endpoint URL 63 | 64 | https://civitai.com/api/v1/creators 65 | Query Parameters 66 | Name Type Description 67 | limit (OPTIONAL) number The number of results to be returned per page. This can be a number between 0 and 200. By default, each page will return 20 results. 
If set to 0, it'll return all the creators 68 | page (OPTIONAL) number The page from which to start fetching creators 69 | query (OPTIONAL) string Search query to filter creators by username 70 | Response Fields 71 | Name Type Description 72 | username string The username of the creator 73 | modelCount number The amount of models linked to this user 74 | link string Url to get all models from this user 75 | metadata.totalItems string The total number of items available 76 | metadata.currentPage string The the current page you are at 77 | metadata.pageSize string The the size of the batch 78 | metadata.totalPages string The total number of pages 79 | metadata.nextPage string The url to get the next batch of items 80 | metadata.prevPage string The url to get the previous batch of items 81 | Example 82 | 83 | The following example shows a request to get the first 3 model tags from our database: 84 | 85 | curl https://civitai.com/api/v1/creators?limit=3 \ 86 | -H "Content-Type: application/json" \ 87 | -X GET 88 | 89 | This would yield the following response: 90 | 91 | { 92 | "items": [ 93 | { 94 | "username": "Civitai", 95 | "modelCount": 848, 96 | "link": "https://civitai.com/api/v1/models?username=Civitai" 97 | }, 98 | { 99 | "username": "JustMaier", 100 | "modelCount": 8, 101 | "link": "https://civitai.com/api/v1/models?username=JustMaier" 102 | }, 103 | { 104 | "username": "maxhulker", 105 | "modelCount": 2, 106 | "link": "https://civitai.com/api/v1/models?username=maxhulker" 107 | } 108 | ], 109 | "metadata": { 110 | "totalItems": 46, 111 | "currentPage": 1, 112 | "pageSize": 3, 113 | "totalPages": 16, 114 | "nextPage": "https://civitai.com/api/v1/creators?limit=3&page=2" 115 | } 116 | } 117 | 118 | GET /api/v1/images 119 | Endpoint URL 120 | 121 | https://civitai.com/api/v1/images 122 | Query Parameters 123 | Name Type Description 124 | limit (OPTIONAL) number The number of results to be returned per page. This can be a number between 0 and 200. 
By default, each page will return 100 results. 125 | postId (OPTIONAL) number The ID of a post to get images from 126 | modelId (OPTIONAL) number The ID of a model to get images from (model gallery) 127 | modelVersionId (OPTIONAL) number The ID of a model version to get images from (model gallery filtered to version) 128 | username (OPTIONAL) string Filter to images from a specific user 129 | nsfw (OPTIONAL) boolean | enum (None, Soft, Mature, X) Filter to images that contain mature content flags or not (undefined returns all) 130 | sort (OPTIONAL) enum (Most Reactions, Most Comments, Newest) The order in which you wish to sort the results 131 | period (OPTIONAL) enum (AllTime, Year, Month, Week, Day) The time frame in which the images will be sorted 132 | page (OPTIONAL) number The page from which to start fetching creators 133 | Response Fields 134 | Name Type Description 135 | id number The id of the image 136 | url string The url of the image at it's source resolution 137 | hash string The blurhash of the image 138 | width number The width of the image 139 | height number The height of the image 140 | nsfw boolean If the image has any mature content labels 141 | nsfwLevel enum (None, Soft, Mature, X) The NSFW level of the image 142 | createdAt date The date the image was posted 143 | postId number The ID of the post the image belongs to 144 | stats.cryCount number The number of cry reactions 145 | stats.laughCount number The number of laugh reactions 146 | stats.likeCount number The number of like reactions 147 | stats.heartCount number The number of heart reactions 148 | stats.commentCount number The number of comment reactions 149 | meta object The generation parameters parsed or input for the image 150 | username string The username of the creator 151 | metadata.nextCursor number The id of the first image in the next batch 152 | metadata.currentPage number The the current page you are at (if paging) 153 | metadata.pageSize number The the size of the batch 
(if paging) 154 | metadata.nextPage string The url to get the next batch of items 155 | Example 156 | 157 | The following example shows a request to get the first image: 158 | 159 | curl https://civitai.com/api/v1/images?limit=1 \ 160 | -H "Content-Type: application/json" \ 161 | -X GET 162 | 163 | This would yield the following response: 164 | Click to Expand 165 | 166 | Notes: 167 | 168 | On July 2, 2023 we switched from a paging system to a cursor-based system due to the volume of data and requests for this endpoint. 169 | Whether you use paging or cursors, you can use metadata.nextPage to get the next page of results 170 | 171 | GET /api/v1/models 172 | Endpoint URL 173 | 174 | https://civitai.com/api/v1/models 175 | Query Parameters 176 | Name Type Description 177 | limit (OPTIONAL) number The number of results to be returned per page. This can be a number between 1 and 100. By default, each page will return 100 results 178 | page (OPTIONAL) number The page from which to start fetching models 179 | query (OPTIONAL) string Search query to filter models by name 180 | tag (OPTIONAL) string Search query to filter models by tag 181 | username (OPTIONAL) string Search query to filter models by user 182 | types (OPTIONAL) enum[] (Checkpoint, TextualInversion, Hypernetwork, AestheticGradient, LORA, Controlnet, Poses) The type of model you want to filter with. If none is specified, it will return all types 183 | sort (OPTIONAL) enum (Highest Rated, Most Downloaded, Newest) The order in which you wish to sort the results 184 | period (OPTIONAL) enum (AllTime, Year, Month, Week, Day) The time frame in which the models will be sorted 185 | rating (OPTIONAL) (Deprecated) number The rating you wish to filter the models with. 
If none is specified, it will return models with any rating 186 | favorites (OPTIONAL) (AUTHED) boolean Filter to favorites of the authenticated user (this requires an API token or session cookie) 187 | hidden (OPTIONAL) (AUTHED) boolean Filter to hidden models of the authenticated user (this requires an API token or session cookie) 188 | primaryFileOnly (OPTIONAL) boolean Only include the primary file for each model (This will use your preferred format options if you use an API token or session cookie) 189 | allowNoCredit (OPTIONAL) boolean Filter to models that require or don't require crediting the creator 190 | allowDerivatives (OPTIONAL) boolean Filter to models that allow or don't allow creating derivatives 191 | allowDifferentLicenses (OPTIONAL) boolean Filter to models that allow or don't allow derivatives to have a different license 192 | allowCommercialUse (OPTIONAL) enum (None, Image, Rent, Sell) Filter to models based on their commercial permissions 193 | nsfw (OPTIONAL) boolean If false, will return safer images and hide models that don't have safe images 194 | supportsGeneration (OPTIONAL) boolean If true, will return models that support generation 195 | Response Fields 196 | Name Type Description 197 | id number The identifier for the model 198 | name string The name of the model 199 | description string The description of the model (HTML) 200 | type enum (Checkpoint, TextualInversion, Hypernetwork, AestheticGradient, LORA, Controlnet, Poses) The model type 201 | nsfw boolean Whether the model is NSFW or not 202 | tags string[] The tags associated with the model 203 | mode enum (Archived, TakenDown) | null The mode in which the model is currently on. If Archived, files field will be empty. 
If TakenDown, images field will be empty 204 | creator.username string The name of the creator 205 | creator.image string | null The url of the creators avatar 206 | stats.downloadCount number The number of downloads the model has 207 | stats.favoriteCount number The number of favorites the model has 208 | stats.commentCount number The number of comments the model has 209 | stats.ratingCount number The number of ratings the model has 210 | stats.rating number The average rating of the model 211 | modelVersions.id number The identifier for the model version 212 | modelVersions.name string The name of the model version 213 | modelVersions.description string The description of the model version (usually a changelog) 214 | modelVersions.createdAt Date The date in which the version was created 215 | modelVersions.downloadUrl string The download url to get the model file for this specific version 216 | modelVersions.trainedWords string[] The words used to trigger the model 217 | modelVersions.files.sizeKb number The size of the model file 218 | modelVersions.files.pickleScanResult string Status of the pickle scan ('Pending', 'Success', 'Danger', 'Error') 219 | modelVersions.files.virusScanResult string Status of the virus scan ('Pending', 'Success', 'Danger', 'Error') 220 | modelVersions.files.scannedAt Date | null The date in which the file was scanned 221 | modelVersions.files.primary boolean | undefined If the file is the primary file for the model version 222 | modelVersions.files.metadata.fp enum (fp16, fp32) | undefined The specified floating point for the file 223 | modelVersions.files.metadata.size enum (full, pruned) | undefined The specified model size for the file 224 | modelVersions.files.metadata.format enum (SafeTensor, PickleTensor, Other) | undefined The specified model format for the file 225 | modelVersions.images.id string The id for the image 226 | modelVersions.images.url string The url for the image 227 | modelVersions.images.nsfw string Whether or 
not the image is NSFW (note: if the model is NSFW, treat all images on the model as NSFW) 228 | modelVersions.images.width number The original width of the image 229 | modelVersions.images.height number The original height of the image 230 | modelVersions.images.hash string The blurhash of the image 231 | modelVersions.images.meta object | null The generation params of the image 232 | modelVersions.stats.downloadCount number The number of downloads the model has 233 | modelVersions.stats.ratingCount number The number of ratings the model has 234 | modelVersions.stats.rating number The average rating of the model 235 | metadata.totalItems string The total number of items available 236 | metadata.currentPage string The the current page you are at 237 | metadata.pageSize string The the size of the batch 238 | metadata.totalPages string The total number of pages 239 | metadata.nextPage string The url to get the next batch of items 240 | metadata.prevPage string The url to get the previous batch of items 241 | 242 | Note: The download url uses a content-disposition header to set the filename correctly. Be sure to enable that header when fetching the download. 
For example, with wget: 243 | 244 | wget https://civitai.com/api/download/models/{modelVersionId} --content-disposition 245 | 246 | If the creator of the asset that you are trying to download requires authentication, then you will need an API Key to download it: 247 | 248 | wget https://civitai.com/api/download/models/{modelVersionId}?token={api_key} --content-disposition 249 | 250 | Example 251 | 252 | The following example shows a request to get the first 3 TextualInversion models from our database: 253 | 254 | curl https://civitai.com/api/v1/models?limit=3&types=TextualInversion \ 255 | -H "Content-Type: application/json" \ 256 | -X GET 257 | 258 | This would yield the following response: 259 | Click to expand 260 | 261 | GET /api/v1/models/:modelId 262 | Endpoint URL 263 | 264 | https://civitai.com/api/v1/models/:modelId 265 | Response Fields 266 | Name Type Description 267 | id number The identifier for the model 268 | name string The name of the model 269 | description string The description of the model (HTML) 270 | type enum (Checkpoint, TextualInversion, Hypernetwork, AestheticGradient, LORA, Controlnet, Poses) The model type 271 | nsfw boolean Whether the model is NSFW or not 272 | tags string[] The tags associated with the model 273 | mode enum (Archived, TakenDown) | null The mode in which the model is currently on. If Archived, files field will be empty. 
If TakenDown, images field will be empty 274 | creator.username string The name of the creator 275 | creator.image string | null The url of the creators avatar 276 | modelVersions.id number The identifier for the model version 277 | modelVersions.name string The name of the model version 278 | modelVersions.description string The description of the model version (usually a changelog) 279 | modelVersions.createdAt Date The date in which the version was created 280 | modelVersions.downloadUrl string The download url to get the model file for this specific version 281 | modelVersions.trainedWords string[] The words used to trigger the model 282 | modelVersions.files.sizeKb number The size of the model file 283 | modelVersions.files.pickleScanResult string Status of the pickle scan ('Pending', 'Success', 'Danger', 'Error') 284 | modelVersions.files.virusScanResult string Status of the virus scan ('Pending', 'Success', 'Danger', 'Error') 285 | modelVersions.files.scannedAt Date | null The date in which the file was scanned 286 | modelVersions.files.metadata.fp enum (fp16, fp32) | undefined The specified floating point for the file 287 | modelVersions.files.metadata.size enum (full, pruned) | undefined The specified model size for the file 288 | modelVersions.files.metadata.format enum (SafeTensor, PickleTensor, Other) | undefined The specified model format for the file 289 | modelVersions.images.url string The url for the image 290 | modelVersions.images.nsfw string Whether or not the image is NSFW (note: if the model is NSFW, treat all images on the model as NSFW) 291 | modelVersions.images.width number The original width of the image 292 | modelVersions.images.height number The original height of the image 293 | modelVersions.images.hash string The blurhash of the image 294 | modelVersions.images.meta object | null The generation params of the image 295 | 296 | Note: The download url uses a content-disposition header to set the filename correctly. 
Be sure to enable that header when fetching the download. For example, with wget: 297 | 298 | wget https://civitai.com/api/download/models/{modelVersionId} --content-disposition 299 | 300 | Example 301 | 302 | The following example shows a request to get a specific model by its ID: 303 | 304 | curl https://civitai.com/api/v1/models/1102 \ 305 | -H "Content-Type: application/json" \ 306 | -X GET 307 | 308 | This would yield the following response: 309 | Click to expand 310 | 311 | GET /api/v1/model-versions/:modelVersionId 312 | Endpoint URL 313 | 314 | https://civitai.com/api/v1/model-versions/:id 315 | Response Fields 316 | Name Type Description 317 | id number The identifier for the model version 318 | name string The name of the model version 319 | description string The description of the model version (usually a changelog) 320 | model.name string The name of the model 321 | model.type enum (Checkpoint, TextualInversion, Hypernetwork, AestheticGradient, LORA, Controlnet, Poses) The model type 322 | model.nsfw boolean Whether the model is NSFW or not 323 | model.poi boolean Whether the model is of a person of interest or not 324 | model.mode enum (Archived, TakenDown) | null The mode in which the model is currently on. If Archived, files field will be empty. 
If TakenDown, images field will be empty 325 | modelId number The identifier for the model 326 | createdAt Date The date in which the version was created 327 | downloadUrl string The download url to get the model file for this specific version 328 | trainedWords string[] The words used to trigger the model 329 | files.sizeKb number The size of the model file 330 | files.pickleScanResult string Status of the pickle scan ('Pending', 'Success', 'Danger', 'Error') 331 | files.virusScanResult string Status of the virus scan ('Pending', 'Success', 'Danger', 'Error') 332 | files.scannedAt Date | null The date in which the file was scanned 333 | files.metadata.fp enum (fp16, fp32) | undefined The specified floating point for the file 334 | files.metadata.size enum (full, pruned) | undefined The specified model size for the file 335 | files.metadata.format enum (SafeTensor, PickleTensor, Other) | undefined The specified model format for the file 336 | stats.downloadCount number The number of downloads the model has 337 | stats.ratingCount number The number of ratings the model has 338 | stats.rating number The average rating of the model 339 | images.url string The url for the image 340 | images.nsfw string Whether or not the image is NSFW (note: if the model is NSFW, treat all images on the model as NSFW) 341 | images.width number The original width of the image 342 | images.height number The original height of the image 343 | images.hash string The blurhash of the image 344 | images.meta object | null The generation params of the image 345 | 346 | Note: The download url uses a content-disposition header to set the filename correctly. Be sure to enable that header when fetching the download. 
For example, with wget: 347 | 348 | wget https://civitai.com/api/download/models/{modelVersionId} --content-disposition 349 | 350 | Example 351 | 352 | The following example shows a request to get a model version from our database: 353 | 354 | curl https://civitai.com/api/v1/model-versions/1318 \ 355 | -H "Content-Type: application/json" \ 356 | -X GET 357 | 358 | This would yield the following response: 359 | Click to expand 360 | 361 | GET /api/v1/model-versions/by-hash/:hash 362 | Endpoint URL 363 | 364 | https://civitai.com/api/v1/model-versions/by-hash/:hash 365 | Response Fields 366 | 367 | Same as standard model-versions endpoint 368 | 369 | Note: We support the following hash algorithms: AutoV1, AutoV2, SHA256, CRC32, and Blake3 370 | 371 | Note 2: We are still in the process of hashing older files, so these results are incomplete 372 | GET /api/v1/tags 373 | Endpoint URL 374 | 375 | https://civitai.com/api/v1/tags 376 | Query Parameters 377 | Name Type Description 378 | limit (OPTIONAL) number The number of results to be returned per page. This can be a number between 1 and 200. By default, each page will return 20 results. 
If set to 0, it'll return all the tags 379 | page (OPTIONAL) number The page from which to start fetching tags 380 | query (OPTIONAL) string Search query to filter tags by name 381 | Response Fields 382 | Name Type Description 383 | name string The name of the tag 384 | modelCount number The amount of models linked to this tag 385 | link string Url to get all models from this tag 386 | metadata.totalItems string The total number of items available 387 | metadata.currentPage string The the current page you are at 388 | metadata.pageSize string The the size of the batch 389 | metadata.totalPages string The total number of pages 390 | metadata.nextPage string The url to get the next batch of items 391 | metadata.prevPage string The url to get the previous batch of items 392 | Example 393 | 394 | The following example shows a request to get the first 3 model tags from our database: 395 | 396 | curl https://civitai.com/api/v1/tags?limit=3 \ 397 | -H "Content-Type: application/json" \ 398 | -X GET 399 | 400 | This would yield the following response: 401 | 402 | { 403 | "items": [ 404 | { 405 | "name": "Pepe Larraz", 406 | "modelCount": 1, 407 | "link": "https://civitai.com/api/v1/models?tag=Pepe Larraz" 408 | }, 409 | { 410 | "name": "comic book", 411 | "modelCount": 7, 412 | "link": "https://civitai.com/api/v1/models?tag=comic book" 413 | }, 414 | { 415 | "name": "style", 416 | "modelCount": 91, 417 | "link": "https://civitai.com/api/v1/models?tag=style" 418 | } 419 | ], 420 | "metadata": { 421 | "totalItems": 200, 422 | "currentPage": 1, 423 | "pageSize": 3, 424 | "totalPages": 67, 425 | "nextPage": "https://civitai.com/api/v1/tags?limit=3&page=2" 426 | } 427 | } -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Define the name for the output binary 2 | BINARY_NAME=civitai-downloader 3 | 4 | # Define the path to the main package 5 | 
MAIN_PKG=./cmd/civitai-downloader 6 | 7 | # Define the Go command 8 | GO=go 9 | 10 | # Build the application 11 | build: 12 | @echo "Building $(BINARY_NAME)..." 13 | $(GO) build -o $(BINARY_NAME) $(MAIN_PKG) 14 | @echo "$(BINARY_NAME) built successfully." 15 | 16 | # Run the application (passes arguments after --) 17 | run: build 18 | @echo "Running $(BINARY_NAME)..." 19 | ./$(BINARY_NAME) $(ARGS) 20 | 21 | # Run tests 22 | test: 23 | @echo "Running tests..." 24 | $(GO) test ./... -v 25 | 26 | # Clean build artifacts 27 | clean: 28 | @echo "Cleaning..." 29 | rm -f $(BINARY_NAME) 30 | rm -rf ./release 31 | @echo "Clean complete." 32 | 33 | # Build release binaries for multiple platforms 34 | release: clean 35 | @echo "Building release binaries..." 36 | GOOS=linux GOARCH=amd64 $(GO) build -o release/$(BINARY_NAME)-linux-amd64 $(MAIN_PKG) 37 | GOOS=linux GOARCH=arm64 $(GO) build -o release/$(BINARY_NAME)-linux-arm64 $(MAIN_PKG) 38 | GOOS=windows GOARCH=amd64 $(GO) build -o release/$(BINARY_NAME)-windows-amd64.exe $(MAIN_PKG) 39 | GOOS=darwin GOARCH=amd64 $(GO) build -o release/$(BINARY_NAME)-darwin-amd64 $(MAIN_PKG) 40 | GOOS=darwin GOARCH=arm64 $(GO) build -o release/$(BINARY_NAME)-darwin-arm64 $(MAIN_PKG) 41 | @echo "Release binaries built successfully in ./release directory." 
42 | 43 | # Default target 44 | all: build 45 | 46 | # Phony targets (targets that don't represent files) 47 | .PHONY: all build run test clean release -------------------------------------------------------------------------------- /cmd/civitai-downloader/cmd/clean.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path/filepath" 7 | "strings" 8 | 9 | // Use correct relative paths for internal packages 10 | 11 | log "github.com/sirupsen/logrus" 12 | "github.com/spf13/cobra" 13 | ) 14 | 15 | func init() { 16 | // Assumes rootCmd is defined in root.go within the same package 17 | rootCmd.AddCommand(cleanCmd) 18 | 19 | cleanCmd.Flags().BoolP("torrents", "t", false, "Also remove *.torrent files") 20 | cleanCmd.Flags().BoolP("magnets", "m", false, "Also remove *-magnet.txt files") 21 | } 22 | 23 | var cleanCmd = &cobra.Command{ 24 | Use: "clean", 25 | Short: "Remove temporary (.tmp) files from the download directory", 26 | Long: `Recursively scans the configured SavePath and removes any files ending with the .tmp extension. 27 | Optionally removes *.torrent and *-magnet.txt files as well.`, 28 | Run: runClean, 29 | } 30 | 31 | func runClean(cmd *cobra.Command, args []string) { 32 | // Access the globally loaded config from root.go's PersistentPreRunE 33 | cfg := globalConfig // Use the globalConfig variable 34 | savePath := cfg.SavePath 35 | 36 | // Get flag values 37 | cleanTorrents, _ := cmd.Flags().GetBool("torrents") 38 | cleanMagnets, _ := cmd.Flags().GetBool("magnets") 39 | 40 | // --- Path Validation --- (Moved up slightly) 41 | if savePath == "" { 42 | if cfg.DatabasePath != "" { 43 | savePath = filepath.Dir(cfg.DatabasePath) 44 | log.Warnf("SavePath is empty, inferring base directory from DatabasePath: %s", savePath) 45 | } else { 46 | log.Error("SavePath is not configured (and cannot be inferred from DatabasePath). 
Cannot determine where to clean.") 47 | os.Exit(1) 48 | } 49 | } 50 | info, err := os.Stat(savePath) 51 | if os.IsNotExist(err) { 52 | log.Errorf("SavePath directory does not exist: %s", savePath) 53 | os.Exit(1) 54 | } 55 | if err != nil { 56 | log.Errorf("Error accessing SavePath %q: %v", savePath, err) 57 | os.Exit(1) 58 | } 59 | if !info.IsDir() { 60 | log.Errorf("SavePath is not a directory: %s", savePath) 61 | os.Exit(1) 62 | } 63 | // --- End Path Validation --- 64 | 65 | logLine := fmt.Sprintf("Scanning for .tmp files in %s", savePath) 66 | if cleanTorrents { 67 | logLine += " (and *.torrent files)" 68 | } 69 | if cleanMagnets { 70 | logLine += " (and *-magnet.txt files)" 71 | } 72 | log.Info(logLine + "...") 73 | 74 | var tmpRemoved, torrentRemoved, magnetRemoved int64 75 | var filesFailed int64 76 | 77 | walkErr := filepath.Walk(savePath, func(path string, info os.FileInfo, err error) error { 78 | if err != nil { 79 | log.Warnf("Error accessing path %q during scan: %v", path, err) 80 | return nil 81 | } 82 | if info.IsDir() { 83 | return nil // Skip directories 84 | } 85 | 86 | lowerName := strings.ToLower(info.Name()) 87 | shouldRemove := false 88 | fileType := "" 89 | 90 | // Check file types based on flags 91 | if strings.HasSuffix(lowerName, ".tmp") { 92 | shouldRemove = true 93 | fileType = ".tmp" 94 | } else if cleanTorrents && strings.HasSuffix(lowerName, ".torrent") { 95 | shouldRemove = true 96 | fileType = ".torrent" 97 | } else if cleanMagnets && strings.HasSuffix(lowerName, "-magnet.txt") { 98 | shouldRemove = true 99 | fileType = "-magnet.txt" 100 | } 101 | 102 | if shouldRemove { 103 | log.Debugf("Found %s file: %s", fileType, path) 104 | err := os.Remove(path) 105 | if err != nil { 106 | if os.IsNotExist(err) { 107 | log.Warnf("Attempted to remove %s file %q, but it was already gone.", fileType, path) 108 | } else { 109 | log.Errorf("Failed to remove %s file %q: %v", fileType, path, err) 110 | filesFailed++ 111 | } 112 | } else { 113 | 
log.Infof("Removed %s file: %s", fileType, path) 114 | // Increment specific counter 115 | switch fileType { 116 | case ".tmp": 117 | tmpRemoved++ 118 | case ".torrent": 119 | torrentRemoved++ 120 | case "-magnet.txt": 121 | magnetRemoved++ 122 | } 123 | } 124 | } 125 | return nil // Continue walking 126 | }) 127 | 128 | if walkErr != nil { 129 | log.Errorf("Error during directory walk of %q: %v", savePath, walkErr) 130 | } 131 | 132 | // Build summary string 133 | var summaryParts []string 134 | if tmpRemoved > 0 { 135 | summaryParts = append(summaryParts, fmt.Sprintf("%d .tmp file(s)", tmpRemoved)) 136 | } 137 | if torrentRemoved > 0 { 138 | summaryParts = append(summaryParts, fmt.Sprintf("%d .torrent file(s)", torrentRemoved)) 139 | } 140 | if magnetRemoved > 0 { 141 | summaryParts = append(summaryParts, fmt.Sprintf("%d -magnet.txt file(s)", magnetRemoved)) 142 | } 143 | 144 | summary := "Clean complete. Removed: " 145 | if len(summaryParts) > 0 { 146 | summary += strings.Join(summaryParts, ", ") 147 | } else { 148 | summary += "0 files" 149 | } 150 | 151 | if filesFailed > 0 { 152 | summary += fmt.Sprintf(". 
Failed to remove %d file(s).", filesFailed) 153 | } 154 | log.Info(summary) 155 | 156 | if filesFailed > 0 || walkErr != nil { 157 | os.Exit(1) 158 | } 159 | } 160 | -------------------------------------------------------------------------------- /cmd/civitai-downloader/cmd/cmd_download_processing.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "fmt" 7 | "net/url" 8 | "os" 9 | "path/filepath" 10 | "strings" 11 | "sync" 12 | "sync/atomic" 13 | "time" 14 | 15 | "go-civitai-download/internal/database" 16 | "go-civitai-download/internal/downloader" 17 | "go-civitai-download/internal/helpers" 18 | "go-civitai-download/internal/models" 19 | 20 | log "github.com/sirupsen/logrus" 21 | "github.com/spf13/viper" 22 | ) 23 | 24 | // --- Structs for Concurrent Image Downloads --- START --- 25 | type imageDownloadJob struct { 26 | SourceURL string 27 | TargetPath string 28 | ImageID int // Keep ID for logging 29 | LogFilename string // Keep base filename for logging 30 | } 31 | 32 | // --- Structs for Concurrent Image Downloads --- END --- 33 | 34 | // --- Worker for Concurrent Image Downloads --- START --- 35 | func imageDownloadWorkerInternal(id int, jobs <-chan imageDownloadJob, imageDownloader *downloader.Downloader, wg *sync.WaitGroup, successCounter *int64, failureCounter *int64, logPrefix string) { 36 | defer wg.Done() 37 | log.Debugf("[%s-Worker-%d] Starting internal image worker", logPrefix, id) 38 | for job := range jobs { 39 | log.Debugf("[%s-Worker-%d] Received job for image ID %d -> %s", logPrefix, id, job.ImageID, job.TargetPath) 40 | 41 | // Check if image exists already 42 | if _, statErr := os.Stat(job.TargetPath); statErr == nil { 43 | log.Debugf("[%s-Worker-%d] Skipping image %s - already exists.", logPrefix, id, job.LogFilename) 44 | continue 45 | } else if !os.IsNotExist(statErr) { 46 | log.WithError(statErr).Warnf("[%s-Worker-%d] Failed to check status 
of target image file %s. Skipping.", logPrefix, id, job.TargetPath) 47 | atomic.AddInt64(failureCounter, 1) 48 | continue 49 | } 50 | 51 | // Download the image 52 | log.Debugf("[%s-Worker-%d] Downloading image %s from %s", logPrefix, id, job.LogFilename, job.SourceURL) 53 | _, dlErr := imageDownloader.DownloadFile(job.TargetPath, job.SourceURL, models.Hashes{}, 0) 54 | 55 | if dlErr != nil { 56 | log.WithError(dlErr).Errorf("[%s-Worker-%d] Failed to download image %s from %s", logPrefix, id, job.LogFilename, job.SourceURL) 57 | atomic.AddInt64(failureCounter, 1) 58 | } else { 59 | log.Debugf("[%s-Worker-%d] Downloaded image %s successfully.", logPrefix, id, job.LogFilename) 60 | atomic.AddInt64(successCounter, 1) 61 | } 62 | } 63 | log.Debugf("[%s-Worker-%d] Finishing internal image worker", logPrefix, id) 64 | } 65 | 66 | // --- Worker for Concurrent Image Downloads --- END --- 67 | 68 | // processPage filters downloads based on config and database status. 69 | // It returns the list of downloads that should be queued and their total size. 
// processPage filters one page of potential downloads against the config and
// the download database, deciding which files still need to be fetched.
//
// For each candidate it derives the DB key "v_{versionID}" and then:
//   - key missing            -> create a Pending entry and queue the download;
//   - status Downloaded      -> verify the file still exists on disk (re-queue
//     and reset to Pending if it vanished); if present, refresh mutable
//     metadata and optionally backfill a missing .json sidecar;
//   - status Pending/Error   -> reset to Pending and re-queue;
//   - any other status       -> skip with a warning.
//
// Returns the downloads to queue and their total size in bytes.
// NOTE(review): cfg is currently unused inside the body — presumably kept for
// interface stability with callers; confirm before removing.
func processPage(db *database.DB, pageDownloads []potentialDownload, cfg *models.Config) ([]potentialDownload, uint64) {
	downloadsToQueue := []potentialDownload{}
	var queuedSizeBytes uint64 = 0

	for _, pd := range pageDownloads {
		// Calculate DB Key using ModelVersion ID. A zero ID means we cannot
		// form a stable key, so the candidate is dropped.
		if pd.CleanedVersion.ID == 0 {
			log.Warnf("Skipping potential download %s for model %s - missing ModelVersion ID.", pd.File.Name, pd.ModelName)
			continue
		}
		// Use prefix "v_" to distinguish version keys
		dbKey := fmt.Sprintf("v_%d", pd.CleanedVersion.ID)

		// Check database
		// Get retrieves raw bytes, unmarshaling happens later if needed
		rawValue, err := db.Get([]byte(dbKey)) // Note: db.Get returns raw bytes

		shouldQueue := false
		// Use errors.Is to check for the specific ErrNotFound error from our database package
		if errors.Is(err, database.ErrNotFound) {
			log.Debugf("Model Version %d (Key: %s) not found in DB. Queuing for download.", pd.CleanedVersion.ID, dbKey)
			shouldQueue = true
			// Create initial entry using the correct DatabaseEntry fields
			newEntry := models.DatabaseEntry{
				ModelName:    pd.ModelName,
				ModelType:    pd.ModelType,
				Version:      pd.CleanedVersion, // Store the cleaned version struct
				File:         pd.File,           // Store the file struct
				Timestamp:    time.Now().Unix(), // Use Unix timestamp for AddedAt
				Creator:      pd.Creator,        // Store the creator struct
				Filename:     filepath.Base(pd.TargetFilepath), // Use the calculated filename
				Folder:       pd.Slug,                          // Use the calculated folder slug
				Status:       models.StatusPending,             // Use constant
				ErrorDetails: "",                               // Use correct field name
			}
			// Marshal the new entry to JSON before putting into DB
			entryBytes, marshalErr := json.Marshal(newEntry)
			if marshalErr != nil {
				log.WithError(marshalErr).Errorf("Failed to marshal new DB entry for key %s", dbKey)
				continue // Skip queuing if marshalling fails
			}
			// Put the marshalled bytes
			if errPut := db.Put([]byte(dbKey), entryBytes); errPut != nil {
				log.WithError(errPut).Errorf("Failed to add pending entry to DB for key %s", dbKey)
				// Decide if we should still attempt download? Maybe not.
				continue // Skip queuing if DB write fails
			}
		} else if err != nil {
			// Handle other potential DB errors during Get
			log.WithError(err).Errorf("Error checking database for key %s", dbKey)
			continue // Skip this file on DB error
		} else {
			// Entry exists, unmarshal it
			var entry models.DatabaseEntry
			if unmarshalErr := json.Unmarshal(rawValue, &entry); unmarshalErr != nil {
				log.WithError(unmarshalErr).Errorf("Failed to unmarshal existing DB entry for key %s", dbKey)
				continue // Skip if we can't parse the existing entry
			}

			log.Debugf("Model Version %d (Key: %s) found in DB with status: %s", entry.Version.ID, dbKey, entry.Status)
			switch entry.Status {
			case models.StatusDownloaded:
				log.Debugf("DB Status for %s (VersionID: %d, Key: %s) is Downloaded. Checking filesystem...", pd.FinalBaseFilename, pd.CleanedVersion.ID, dbKey)

				// Construct the path using the FILENAME STORED IN THE DB ENTRY, which includes the prepended ID.
				expectedPathFromDB := filepath.Join(filepath.Dir(pd.TargetFilepath), entry.Filename)
				log.Debugf("Checking for file existence at: %s (based on DB entry filename)", expectedPathFromDB)

				// Check if the file *actually* exists on disk using the DB filename
				if _, statErr := os.Stat(expectedPathFromDB); os.IsNotExist(statErr) {
					// File is missing despite DB saying downloaded!
					log.Warnf("File %s marked as downloaded in DB (Key: %s), but not found on disk! Re-queuing.", expectedPathFromDB, dbKey)
					shouldQueue = true
					// Update status back to Pending and clear error
					entry.Status = models.StatusPending
					entry.ErrorDetails = ""
					// Update other fields that might change
					entry.Folder = pd.Slug
					entry.Version = pd.CleanedVersion
					entry.File = pd.File
					// Update DB entry to reflect Pending status; if the write
					// fails, do NOT queue — the DB would disagree with reality.
					entryBytes, marshalErr := json.Marshal(entry)
					if marshalErr != nil {
						log.WithError(marshalErr).Errorf("Failed to marshal entry for re-queue update (missing file) %s", dbKey)
						shouldQueue = false // Don't queue if marshalling fails
					} else if errUpdate := db.Put([]byte(dbKey), entryBytes); errUpdate != nil {
						log.WithError(errUpdate).Errorf("Failed to update DB entry to Pending (missing file) for key %s", dbKey)
						shouldQueue = false // Don't queue if update fails
					}
					// End of handling missing file
				} else if statErr == nil {
					// File *does* exist, proceed with original skip logic + metadata check
					log.Infof("Skipping %s (VersionID: %d, Key: %s) - File exists and DB status is Downloaded.", pd.TargetFilepath, pd.CleanedVersion.ID, dbKey)
					// Update fields that might change between runs
					entry.Folder = pd.Slug
					entry.Version = pd.CleanedVersion // Update associated metadata version
					entry.File = pd.File              // Update file details (URL might change)

					// --- START: Save Metadata Check for Existing Download ---
					// Use Viper to check if metadata saving is enabled
					if viper.GetBool("savemetadata") {
						// Derive metadata path from the expected path based on the DB entry filename
						metadataPath := strings.TrimSuffix(expectedPathFromDB, filepath.Ext(expectedPathFromDB)) + ".json"

						if _, metaStatErr := os.Stat(metadataPath); os.IsNotExist(metaStatErr) {
							log.Infof("Model file exists, but metadata %s is missing. Saving metadata.", filepath.Base(metadataPath))
							// Marshal the FULL version info from the potential download struct
							jsonData, jsonErr := json.MarshalIndent(pd.FullVersion, "", " ")
							if jsonErr != nil {
								log.WithError(jsonErr).Warnf("Failed to marshal full version metadata for existing file %s", pd.TargetFilepath)
							} else {
								if writeErr := os.WriteFile(metadataPath, jsonData, 0600); writeErr != nil {
									log.WithError(writeErr).Warnf("Failed to write version metadata file %s", metadataPath)
								}
							}
						} else if metaStatErr != nil {
							// Log error if stating metadata file failed for other reasons
							log.WithError(metaStatErr).Warnf("Could not check status of metadata file %s", metadataPath)
						}
					}
					// --- END: Save Metadata Check for Existing Download ---

					// Update the entry in the database (keeping status Downloaded)
					entryBytes, marshalErr := json.Marshal(entry)
					if marshalErr != nil {
						log.WithError(marshalErr).Warnf("Failed to marshal updated downloaded entry %s", dbKey)
					} else if errUpdate := db.Put([]byte(dbKey), entryBytes); errUpdate != nil {
						log.WithError(errUpdate).Warnf("Failed to update metadata for downloaded entry %s", dbKey)
					}
					shouldQueue = false
				} else {
					// Some other error occurred when checking file existence
					log.WithError(statErr).Warnf("Error checking filesystem for %s (Key: %s). Skipping queue.", pd.TargetFilepath, dbKey)
					shouldQueue = false
					// Optionally update DB entry here too, or just skip?
				}
			case models.StatusPending, models.StatusError:
				log.Infof("Re-queuing %s (VersionID: %d, Key: %s) - Status is %s.", pd.TargetFilepath, pd.CleanedVersion.ID, dbKey, entry.Status)
				shouldQueue = true
				// Update status back to Pending and clear error if any
				entry.Status = models.StatusPending
				entry.ErrorDetails = ""
				// Update fields that might change
				entry.Folder = pd.Slug
				entry.Version = pd.CleanedVersion
				entry.File = pd.File
				// entry.Timestamp = time.Now().Unix() // Optionally update timestamp?

				entryBytes, marshalErr := json.Marshal(entry)
				if marshalErr != nil {
					log.WithError(marshalErr).Errorf("Failed to marshal entry for re-queue update %s", dbKey)
					shouldQueue = false // Don't queue if marshalling fails
				} else if errUpdate := db.Put([]byte(dbKey), entryBytes); errUpdate != nil {
					log.WithError(errUpdate).Errorf("Failed to update DB entry to Pending for key %s", dbKey)
					shouldQueue = false // Don't queue if update fails
				}
			default:
				log.Warnf("Skipping %s (VersionID: %d, Key: %s) - Unknown status '%s' in database.", pd.TargetFilepath, pd.CleanedVersion.ID, dbKey, entry.Status)
				shouldQueue = false
			}
		}

		if shouldQueue {
			downloadsToQueue = append(downloadsToQueue, pd)
			// SizeKB is a float; the conversion truncates sub-byte remainders.
			queuedSizeBytes += uint64(pd.File.SizeKB * 1024)
			log.Debugf("Added confirmed download to queue: %s (Model: %s)", pd.File.Name, pd.ModelName)
		}
	}

	return downloadsToQueue, queuedSizeBytes
}
244 | func saveModelInfoFile(model models.Model, modelBaseDir string) error { 245 | // The base directory is now passed directly 246 | infoDirPath := modelBaseDir 247 | 248 | // Ensure the directory exists 249 | if err := os.MkdirAll(infoDirPath, 0750); err != nil { 250 | log.WithError(err).Errorf("Failed to create model info directory: %s", infoDirPath) 251 | return fmt.Errorf("failed to create directory %s: %w", infoDirPath, err) 252 | } 253 | 254 | // Construct the file path within the model base directory 255 | // Use {modelID}-{modelNameSlug}.json format 256 | modelNameSlug := helpers.ConvertToSlug(model.Name) 257 | if modelNameSlug == "" { 258 | modelNameSlug = "unknown_model" 259 | } 260 | fileName := fmt.Sprintf("%d-%s.json", model.ID, modelNameSlug) 261 | filePath := filepath.Join(infoDirPath, fileName) 262 | 263 | // Marshal the full model info 264 | jsonData, jsonErr := json.MarshalIndent(model, "", " ") 265 | if jsonErr != nil { 266 | log.WithError(jsonErr).Warnf("Failed to marshal full model info for model %d (%s)", model.ID, model.Name) 267 | return fmt.Errorf("failed to marshal model info for %d: %w", model.ID, jsonErr) 268 | } 269 | 270 | // Write the file (overwrite if exists) 271 | if writeErr := os.WriteFile(filePath, jsonData, 0600); writeErr != nil { 272 | log.WithError(writeErr).Warnf("Failed to write model info file %s", filePath) 273 | return fmt.Errorf("failed to write model info file %s: %w", filePath, writeErr) 274 | } 275 | 276 | log.Debugf("Saved full model info to %s", filePath) 277 | return nil 278 | } 279 | 280 | // downloadImages handles downloading a list of images concurrently to a specified directory. 
// downloadImages downloads a list of images concurrently into baseDir using a
// pool of numWorkers goroutines fed through a buffered job channel.
//
// Filenames are derived as {imageID}{ext}; when the image ID is missing or the
// URL is unparseable, the last URL path segment (or "image_{index}.jpg") is
// used as a fallback, so collisions between fallback names are possible on
// repeated indices across calls.
//
// Returns the number of successful and failed downloads. If the downloader is
// nil or the base directory cannot be created, all images count as failed.
func downloadImages(logPrefix string, images []models.ModelImage, baseDir string, imageDownloader *downloader.Downloader, numWorkers int) (finalSuccessCount, finalFailCount int) {
	if imageDownloader == nil {
		log.Warnf("[%s] Image downloader is nil, cannot download images.", logPrefix)
		return 0, len(images) // Count all as failed if downloader doesn't exist
	}
	if len(images) == 0 {
		log.Debugf("[%s] No images provided to download.", logPrefix)
		return 0, 0
	}
	if numWorkers <= 0 {
		log.Warnf("[%s] Invalid concurrency level %d for image download, defaulting to 1.", logPrefix, numWorkers)
		numWorkers = 1
	}

	log.Infof("[%s] Attempting concurrent download for %d images to %s (Concurrency: %d)", logPrefix, len(images), baseDir, numWorkers)

	if err := os.MkdirAll(baseDir, 0750); err != nil {
		log.WithError(err).Errorf("[%s] Failed to create base directory for images: %s", logPrefix, baseDir)
		return 0, len(images) // Cannot proceed, count all as failed
	}

	// --- Setup Concurrency ---
	jobs := make(chan imageDownloadJob, numWorkers*2) // Buffered channel
	var wg sync.WaitGroup
	var successCounter int64 = 0
	var failureCounter int64 = 0

	// --- Start Workers --- (started before queuing so the buffered channel
	// never deadlocks the producer loop below)
	log.Debugf("[%s] Starting %d internal image download workers...", logPrefix, numWorkers)
	for w := 1; w <= numWorkers; w++ {
		wg.Add(1)
		go imageDownloadWorkerInternal(w, jobs, imageDownloader, &wg, &successCounter, &failureCounter, logPrefix)
	}

	// --- Queue Jobs --- Loop through images and send jobs
	queuedCount := 0
	for imgIdx, image := range images {
		// Construct image filename: {imageID}.{ext} (Copied from previous sequential logic)
		imgUrlParsed, urlErr := url.Parse(image.URL)
		var imgFilename string

		if urlErr != nil || image.ID == 0 {
			fallbackName := fmt.Sprintf("image_%d.jpg", imgIdx) // Default fallback
			// Try to get filename from URL path as a better fallback
			if urlErr == nil { // Only try if URL parsing itself didn't fail
				pathSegments := strings.Split(imgUrlParsed.Path, "/")
				if len(pathSegments) > 0 {
					lastSegment := pathSegments[len(pathSegments)-1]
					// Basic check if it looks like a filename (has an extension, not empty)
					if strings.Contains(lastSegment, ".") && len(lastSegment) > 1 {
						fallbackName = lastSegment
						log.Debugf("[%s] Using filename '%s' extracted from URL path as fallback.", logPrefix, fallbackName)
					} else {
						log.Debugf("[%s] Last URL path segment '%s' does not look like a usable filename.", logPrefix, lastSegment)
					}
				}
			}
			// Log the warning, indicating which fallback name is being used
			log.WithError(urlErr).Debugf("[%s] Cannot determine filename/ID for image %d (URL: %s). Using fallback: %s", logPrefix, imgIdx, image.URL, fallbackName)
			imgFilename = fallbackName
		} else {
			// Normal logic using image.ID
			ext := filepath.Ext(imgUrlParsed.Path)
			if ext == "" || len(ext) > 5 { // Basic check for valid extension
				log.Warnf("[%s] Image URL %s has unusual/missing extension '%s', defaulting to .jpg", logPrefix, image.URL, ext)
				ext = ".jpg"
			}
			imgFilename = fmt.Sprintf("%d%s", image.ID, ext)
		}
		imgTargetPath := filepath.Join(baseDir, imgFilename)

		// Create and send job
		job := imageDownloadJob{
			SourceURL:   image.URL,
			TargetPath:  imgTargetPath,
			ImageID:     image.ID,
			LogFilename: imgFilename, // Pass for consistent logging
		}
		log.Debugf("[%s] Queueing image job: ID %d -> %s", logPrefix, job.ImageID, job.TargetPath)
		jobs <- job
		queuedCount++
	}

	close(jobs) // Signal no more jobs
	log.Debugf("[%s] All %d image jobs queued. Waiting for workers...", logPrefix, queuedCount)

	// --- Wait for Completion ---
	wg.Wait()
	log.Infof("[%s] Image download complete. Success: %d, Failed: %d", logPrefix, atomic.LoadInt64(&successCounter), atomic.LoadInt64(&failureCounter))

	return int(atomic.LoadInt64(&successCounter)), int(atomic.LoadInt64(&failureCounter))
}
37 | 38 | // initLogging configures logrus based on persistent flags 39 | func initLogging() { 40 | level, err := log.ParseLevel(logLevel) 41 | if err != nil { 42 | log.WithError(err).Warnf("Invalid log level '%s', using default 'info'", logLevel) 43 | level = log.InfoLevel 44 | } 45 | log.SetLevel(level) 46 | 47 | switch logFormat { 48 | case "json": 49 | log.SetFormatter(&log.JSONFormatter{}) 50 | case "text": 51 | log.SetFormatter(&log.TextFormatter{FullTimestamp: true}) 52 | default: 53 | log.Warnf("Invalid log format '%s', using default 'text'", logFormat) 54 | log.SetFormatter(&log.TextFormatter{FullTimestamp: true}) 55 | } 56 | 57 | log.Infof("Logging configured: Level=%s, Format=%s", log.GetLevel(), logFormat) 58 | } 59 | 60 | // setupQueryParams initializes the query parameters using Viper for flag/config precedence. 61 | func setupQueryParams(cfg *models.Config, cmd *cobra.Command) models.QueryParameters { 62 | // Viper keys should match the keys used in viper.BindPFlag in init() 63 | 64 | // Use viper.Get* for values that can be set by flags 65 | limit := viper.GetInt("limit") // Viper key from download.go init 66 | if limit <= 0 || limit > 100 { 67 | if limit != 0 { // Don't warn if just using default 68 | log.Warnf("Invalid Limit value '%d' from flag/config, using default 100", limit) 69 | } 70 | limit = 100 // API default/max 71 | } 72 | 73 | // Use global Viper directly now that TOML parsing is fixed 74 | sort := viper.GetString("sort") 75 | if sort == "" { 76 | sort = "Most Downloaded" 77 | } else if _, ok := allowedSortOrders[sort]; !ok { 78 | log.Warnf("Invalid Sort value '%s' from flag/config, using default 'Most Downloaded'", sort) 79 | sort = "Most Downloaded" 80 | } 81 | 82 | period := viper.GetString("period") 83 | if period == "" { 84 | period = "AllTime" 85 | } else if _, ok := allowedPeriods[period]; !ok { 86 | log.Warnf("Invalid Period value '%s' from flag/config, using default 'AllTime'", period) 87 | period = "AllTime" 88 | } 89 | 90 | 
baseModels := viper.GetStringSlice("basemodels") // Viper should handle precedence correctly now 91 | 92 | params := models.QueryParameters{ 93 | Limit: limit, 94 | Page: 1, 95 | Query: viper.GetString("query"), 96 | Tag: viper.GetString("tag"), 97 | Username: viper.GetString("username"), 98 | Types: viper.GetStringSlice("modeltypes"), 99 | Sort: sort, 100 | Period: period, 101 | PrimaryFileOnly: viper.GetBool("primaryonly"), 102 | AllowNoCredit: true, 103 | AllowDerivatives: true, 104 | AllowDifferentLicenses: true, 105 | AllowCommercialUse: "Any", 106 | Nsfw: viper.GetBool("nsfw"), 107 | BaseModels: baseModels, // Use value directly from Viper 108 | } 109 | 110 | log.WithField("params", fmt.Sprintf("%+v", params)).Debug("Final query parameters set") 111 | return params 112 | } 113 | -------------------------------------------------------------------------------- /cmd/civitai-downloader/cmd/cmd_download_types.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import "go-civitai-download/internal/models" 4 | 5 | // potentialDownload holds information about a file identified during the metadata scan phase. 6 | type potentialDownload struct { 7 | ModelName string 8 | ModelType string 9 | VersionName string 10 | BaseModel string 11 | Creator models.Creator 12 | File models.File // Contains URL, Hashes, SizeKB etc. 13 | ModelVersionID int // Add Model Version ID 14 | TargetFilepath string // Full calculated path for download 15 | Slug string // Folder structure 16 | FinalBaseFilename string // Base filename part without ID prefix or metadata suffix (e.g., wan_cowgirl_v1.3.safetensors) 17 | // Store cleaned version separately for potential later use in DB entry 18 | CleanedVersion models.ModelVersion 19 | FullVersion models.ModelVersion 20 | OriginalImages []models.ModelImage // Add original images for potential download 21 | } 22 | 23 | // Represents a download task to be processed by a worker. 
24 | type downloadJob struct { 25 | PotentialDownload potentialDownload // Embed potential download info 26 | DatabaseKey string // Key for DB updates 27 | } 28 | -------------------------------------------------------------------------------- /cmd/civitai-downloader/cmd/cmd_download_worker.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "os" 7 | "path/filepath" 8 | "strings" 9 | "sync" 10 | "time" 11 | 12 | index "go-civitai-download/index" 13 | "go-civitai-download/internal/database" 14 | "go-civitai-download/internal/downloader" 15 | "go-civitai-download/internal/models" 16 | 17 | "github.com/blevesearch/bleve/v2" 18 | "github.com/gosuri/uilive" 19 | log "github.com/sirupsen/logrus" 20 | "github.com/spf13/viper" 21 | ) 22 | 23 | // updateDbEntry encapsulates the logic for getting, updating, and putting a database entry. 24 | // It takes the database connection, the key, the new status (string), and an optional function 25 | // to apply further modifications to the entry before saving. 26 | func updateDbEntry(db *database.DB, key string, newStatus string, updateFunc func(*models.DatabaseEntry)) error { 27 | rawValue, errGet := db.Get([]byte(key)) 28 | if errGet != nil { 29 | // If the key isn't found, we can't update it. Log and return error. 30 | // If it's another error, log that too. 
31 | log.WithError(errGet).Errorf("Failed to get DB entry '%s' for update", key) 32 | return fmt.Errorf("failed to get DB entry '%s': %w", key, errGet) 33 | } 34 | 35 | var entry models.DatabaseEntry 36 | if errUnmarshal := json.Unmarshal(rawValue, &entry); errUnmarshal != nil { 37 | log.WithError(errUnmarshal).Errorf("Failed to unmarshal DB entry '%s' for update", key) 38 | return fmt.Errorf("failed to unmarshal DB entry '%s': %w", key, errUnmarshal) 39 | } 40 | 41 | // Update the status 42 | entry.Status = newStatus 43 | 44 | // Apply additional modifications if provided 45 | if updateFunc != nil { 46 | updateFunc(&entry) 47 | } 48 | 49 | // Marshal updated entry back to JSON 50 | updatedEntryBytes, marshalErr := json.Marshal(entry) 51 | if marshalErr != nil { 52 | log.WithError(marshalErr).Errorf("Failed to marshal updated DB entry '%s' (Status: %s)", key, newStatus) 53 | return fmt.Errorf("failed to marshal DB entry '%s': %w", key, marshalErr) 54 | } 55 | 56 | // Save updated entry back to DB 57 | if errPut := db.Put([]byte(key), updatedEntryBytes); errPut != nil { 58 | log.WithError(errPut).Errorf("Failed to update DB entry '%s' to status %s", key, newStatus) 59 | return fmt.Errorf("failed to put DB entry '%s': %w", key, errPut) 60 | } 61 | 62 | log.Debugf("Successfully updated DB entry '%s' to status %s", key, newStatus) 63 | return nil 64 | } 65 | 66 | // handleMetadataSaving checks the config and calls saveMetadataFile if needed. 
67 | func handleMetadataSaving(logPrefix string, pd potentialDownload, finalPath string, finalStatus string, writer *uilive.Writer) { 68 | if viper.GetBool("savemetadata") { 69 | if finalStatus == models.StatusDownloaded { 70 | log.Debugf("[%s] Saving metadata for successfully downloaded file: %s", logPrefix, finalPath) 71 | if metaErr := saveMetadataFile(pd, finalPath); metaErr != nil { 72 | // Error already logged by saveMetadataFile 73 | if writer != nil { 74 | fmt.Fprintf(writer.Newline(), "[%s] Error saving metadata for %s: %v\n", logPrefix, filepath.Base(finalPath), metaErr) 75 | } 76 | } 77 | } else { 78 | log.Debugf("[%s] Skipping metadata save for %s due to download status: %s.", logPrefix, pd.TargetFilepath, finalStatus) 79 | } 80 | } else { 81 | log.Debugf("[%s] Skipping metadata save (disabled by config) for %s.", logPrefix, finalPath) 82 | } 83 | } 84 | 85 | // downloadWorker handles the actual download of a file and updates the database. 86 | // It now also accepts an imageDownloader, bleveIndex, and concurrencyLevel. 
// downloadWorker is a goroutine body that drains the jobs channel, downloads
// each model file, and records the outcome in the database. On success it also
// (a) indexes the item in Bleve when an index is provided, (b) saves the
// version-metadata sidecar when enabled, and (c) downloads version images
// when enabled, using concurrencyLevel workers. Progress lines are written to
// the shared uilive writer.
// NOTE(review): writes to writer from multiple workers appear unsynchronized
// here — presumably uilive serializes internally; confirm.
func downloadWorker(id int, jobs <-chan downloadJob, db *database.DB, fileDownloader *downloader.Downloader, imageDownloader *downloader.Downloader, wg *sync.WaitGroup, writer *uilive.Writer, concurrencyLevel int, bleveIndex bleve.Index) {
	defer wg.Done()
	log.Debugf("Worker %d starting", id)
	for job := range jobs {
		pd := job.PotentialDownload
		dbKey := job.DatabaseKey // Use the key passed in the job
		log.Infof("Worker %d: Processing job for %s", id, pd.TargetFilepath)
		fmt.Fprintf(writer.Newline(), "Worker %d: Preparing %s...\n", id, filepath.Base(pd.TargetFilepath))

		// Ensure directory exists
		dirPath := filepath.Dir(pd.TargetFilepath)
		if err := os.MkdirAll(dirPath, 0700); err != nil {
			log.WithError(err).Errorf("Worker %d: Failed to create directory %s", id, dirPath)
			// Update DB status to Error using the helper
			updateErr := updateDbEntry(db, dbKey, models.StatusError, func(entry *models.DatabaseEntry) {
				entry.ErrorDetails = fmt.Sprintf("Failed to create directory: %v", err)
			})
			if updateErr != nil {
				// Log the error from the helper function
				log.Errorf("Worker %d: Failed to update DB status after mkdir error: %v", id, updateErr)
			}
			fmt.Fprintf(writer.Newline(), "Worker %d: Error creating directory for %s: %v\n", id, filepath.Base(pd.TargetFilepath), err)
			continue // Skip to next job
		}

		// --- Perform Download ---
		startTime := time.Now()
		fmt.Fprintf(writer.Newline(), "Worker %d: Checking/Downloading %s...\n", id, filepath.Base(pd.TargetFilepath))

		// Initiate download - it returns the final path and error
		finalPath, downloadErr := fileDownloader.DownloadFile(pd.TargetFilepath, pd.File.DownloadUrl, pd.File.Hashes, pd.ModelVersionID)

		// --- Update DB Based on Result ---
		finalStatus := models.StatusError // Default to error
		errMsg := ""
		if downloadErr != nil {
			errMsg = downloadErr.Error()
			finalStatus = models.StatusError
		} else {
			finalStatus = models.StatusDownloaded
		}

		// Use the helper function to update the DB entry. The closure carries
		// both the failure cleanup and the success bookkeeping/indexing.
		updateErr := updateDbEntry(db, dbKey, finalStatus, func(entry *models.DatabaseEntry) {
			if downloadErr != nil {
				// Update error details on failure
				entry.ErrorDetails = errMsg
				log.WithError(downloadErr).Errorf("Worker %d: Failed to download %s", id, pd.TargetFilepath)
				fmt.Fprintf(writer.Newline(), "Worker %d: Error downloading %s: %v\n", id, filepath.Base(pd.TargetFilepath), downloadErr)

				// Attempt to remove partially downloaded file
				if removeErr := os.Remove(pd.TargetFilepath); removeErr != nil && !os.IsNotExist(removeErr) {
					log.WithError(removeErr).Warnf("Worker %d: Failed to remove potentially partial file %s after download error", id, pd.TargetFilepath)
				}
			} else {
				// Update fields on success
				duration := time.Since(startTime)
				log.Infof("Worker %d: Successfully downloaded %s in %v", id, finalPath, duration)
				entry.ErrorDetails = ""                   // Clear any previous error
				entry.Filename = filepath.Base(finalPath) // Update filename in DB
				entry.File = pd.File                      // Update File struct
				entry.Version = pd.CleanedVersion         // Update Version struct
				fmt.Fprintf(writer.Newline(), "Worker %d: Success downloading %s\n", id, filepath.Base(finalPath))

				// --- Index Item with Bleve --- START ---
				if bleveIndex != nil {
					// Calculate directory paths (file -> baseModel -> model ancestry)
					directoryPath := filepath.Dir(finalPath)
					baseModelPath := filepath.Dir(directoryPath)
					modelPath := filepath.Dir(baseModelPath)

					// Parse PublishedAt timestamp; try RFC3339Nano first,
					// then plain RFC3339, else keep the zero time.
					publishedAtTime := time.Time{}
					if pd.FullVersion.PublishedAt != "" {
						var errParse error
						publishedAtTime, errParse = time.Parse(time.RFC3339Nano, pd.FullVersion.PublishedAt)
						if errParse != nil {
							publishedAtTime, errParse = time.Parse(time.RFC3339, pd.FullVersion.PublishedAt)
							if errParse != nil {
								log.WithError(errParse).Warnf("Worker %d: Failed to parse PublishedAt time '%s' for indexing", id, pd.FullVersion.PublishedAt)
								// Keep publishedAtTime as zero time
							}
						}
					}

					// Get file metadata
					fileFormat := pd.File.Metadata.Format  // Already string
					filePrecision := pd.File.Metadata.Fp   // Already string
					fileSizeType := pd.File.Metadata.Size  // Already string

					itemToIndex := index.Item{
						ID:            fmt.Sprintf("v_%d", pd.ModelVersionID), // Use the same key format as DB
						Type:          "model_file",
						Name:          pd.File.Name,                 // Use the original file name
						Description:   pd.CleanedVersion.Description, // Use model version description if available
						FilePath:      finalPath,
						DirectoryPath: directoryPath,
						BaseModelPath: baseModelPath,
						ModelPath:     modelPath,
						ModelName:     pd.ModelName,
						VersionName:   pd.VersionName,
						BaseModel:     pd.BaseModel,
						CreatorName:   pd.Creator.Username,
						Tags:          pd.FullVersion.TrainedWords, // Use TrainedWords as tags for now
						// New Fields
						PublishedAt:          publishedAtTime,                              // Parsed time.Time
						VersionDownloadCount: float64(pd.FullVersion.Stats.DownloadCount),  // Convert int to float64
						VersionRating:        pd.FullVersion.Stats.Rating,                  // float64
						VersionRatingCount:   float64(pd.FullVersion.Stats.RatingCount),    // Convert int to float64
						FileSizeKB:           pd.File.SizeKB,                               // float64
						FileFormat:           fileFormat,                                   // string
						FilePrecision:        filePrecision,                                // string
						FileSizeType:         fileSizeType,                                 // string
					}
					if indexErr := index.IndexItem(bleveIndex, itemToIndex); indexErr != nil {
						log.WithError(indexErr).Errorf("Worker %d: Failed to index downloaded item %s (ID: %s)", id, finalPath, itemToIndex.ID)
						// Don't treat indexing failure as a download failure
					} else {
						log.Debugf("Worker %d: Successfully indexed item %s (ID: %s)", id, finalPath, itemToIndex.ID)
					}
				}
				// --- Index Item with Bleve --- END ---
			}
		})

		if updateErr != nil {
			// Log error from the helper function, but continue with other tasks like image download if download was successful
			log.Errorf("Worker %d: Failed to update DB status after download attempt: %v", id, updateErr)
			fmt.Fprintf(writer.Newline(), "Worker %d: DB Error updating status for %s\n", id, pd.FinalBaseFilename)
		}

		// --- Metadata Saving ---
		logPrefix := fmt.Sprintf("Worker %d", id)
		handleMetadataSaving(logPrefix, pd, finalPath, finalStatus, writer)

		// --- Download Version Images if Enabled and Successful ---
		saveVersionImages := viper.GetBool("saveversionimages")
		if saveVersionImages && finalStatus == models.StatusDownloaded {
			logPrefix := fmt.Sprintf("Worker %d Img", id)
			log.Infof("[%s] Downloading version images for %s (%s)...", logPrefix, pd.ModelName, pd.VersionName)
			modelFileDir := filepath.Dir(finalPath) // Use finalPath from model download
			versionImagesDir := filepath.Join(modelFileDir, "images")

			// Add log before calling downloadImages
			log.Debugf("[%s] Calling downloadImages for %d images...", logPrefix, len(pd.OriginalImages))
			// Call the helper function, passing concurrencyLevel, removing writer
			imgSuccess, imgFail := downloadImages(logPrefix, pd.OriginalImages, versionImagesDir, imageDownloader, concurrencyLevel)
			log.Infof("[%s] Finished downloading version images for %s (%s). Success: %d, Failed: %d",
				logPrefix, pd.ModelName, pd.VersionName, imgSuccess, imgFail)
		}
		// --- End Download Version Images ---
	}
	log.Debugf("Worker %d finished", id)
	fmt.Fprintf(writer.Newline(), "Worker %d: Finished job processing.\n", id) // Final update for the worker
}

// saveMetadataFile saves the full model-version metadata to a .json file.
// The metadata path is the model file path with its extension replaced by
// ".json". Returns a wrapped error on mkdir, marshal, or write failure.
func saveMetadataFile(pd potentialDownload, modelFilePath string) error {
	// Calculate metadata path based on the model file path
	metadataPath := strings.TrimSuffix(modelFilePath, filepath.Ext(modelFilePath)) + ".json"
	// Ensure the target directory exists
	dirPath := filepath.Dir(metadataPath)
	if err := os.MkdirAll(dirPath, 0700); err != nil {
		log.WithError(err).Errorf("Failed to create directory for metadata file: %s", dirPath)
		return fmt.Errorf("failed to create directory %s: %w", dirPath, err)
	}

	// Marshal the full version info
	jsonData, jsonErr := json.MarshalIndent(pd.FullVersion, "", " ")
	if jsonErr != nil {
		log.WithError(jsonErr).Warnf("Failed to marshal metadata for %s", modelFilePath)
		return fmt.Errorf("failed to marshal metadata for %s: %w", pd.ModelName, jsonErr)
	}

	// Write the file
	if writeErr := os.WriteFile(metadataPath, jsonData, 0600); writeErr != nil {
		log.WithError(writeErr).Warnf("Failed to write metadata file %s", metadataPath)
		return fmt.Errorf("failed to write metadata file %s: %w", metadataPath, writeErr)
	}

	log.Debugf("Saved metadata to %s", metadataPath)
	return nil
}
"bufio" 5 | "encoding/json" 6 | "errors" 7 | "fmt" 8 | "io" 9 | "net/http" 10 | "net/url" 11 | "os" 12 | "path/filepath" 13 | "strconv" 14 | "strings" 15 | "sync" 16 | "sync/atomic" 17 | "time" 18 | 19 | "github.com/gosuri/uilive" 20 | log "github.com/sirupsen/logrus" 21 | "github.com/spf13/cobra" 22 | "github.com/spf13/viper" 23 | 24 | index "go-civitai-download/index" 25 | "go-civitai-download/internal/downloader" 26 | "go-civitai-download/internal/models" 27 | ) 28 | 29 | // runImages orchestrates the fetching and downloading of images based on command-line flags. 30 | func runImages(cmd *cobra.Command, args []string) { 31 | // Read flags 32 | modelID := viper.GetInt("images.modelId") 33 | modelVersionID := viper.GetInt("images.modelVersionId") 34 | username := viper.GetString("images.username") 35 | limit := viper.GetInt("images.limit") 36 | period := viper.GetString("images.period") 37 | sort := viper.GetString("images.sort") 38 | nsfw := viper.GetString("images.nsfw") 39 | targetDir := viper.GetString("images.output_dir") 40 | saveMeta := viper.GetBool("images.metadata") 41 | numWorkers := viper.GetInt("images.concurrency") 42 | maxPages := viper.GetInt("images.max_pages") 43 | postID := viper.GetInt("images.postId") 44 | 45 | // --- Early Exit for Debug Print API URL --- START --- 46 | if printUrl, _ := cmd.Flags().GetBool("debug-print-api-url"); printUrl { 47 | log.Info("--- Debug API URL (--debug-print-api-url) for Images ---") 48 | // Construct URL parameters (logic duplicated/extracted from below) 49 | baseURL := "https://civitai.com/api/v1/images" 50 | params := url.Values{} 51 | if modelVersionID != 0 { 52 | params.Set("modelVersionId", strconv.Itoa(modelVersionID)) 53 | } else if modelID != 0 { 54 | params.Set("modelId", strconv.Itoa(modelID)) 55 | } else if username != "" { 56 | params.Set("username", username) 57 | } else if postID != 0 { 58 | params.Set("postId", strconv.Itoa(postID)) 59 | } 60 | if limit > 0 && limit <= 200 { 61 | 
params.Set("limit", strconv.Itoa(limit)) 62 | } else if limit != 100 { 63 | log.Warnf("Invalid limit %d, using API default (100). Actual API call might use different default.", limit) 64 | params.Set("limit", "100") 65 | } 66 | if period != "" { 67 | params.Set("period", period) 68 | } 69 | if sort != "" { 70 | params.Set("sort", sort) 71 | } 72 | if nsfw != "" { 73 | params.Set("nsfw", nsfw) 74 | } 75 | // Note: Does not include cursor logic, as this prints the base URL for the first page. 76 | requestURL := baseURL + "?" + params.Encode() 77 | fmt.Println(requestURL) // Print only the URL to stdout 78 | log.Info("Exiting after printing images API URL.") 79 | os.Exit(0) // Exit immediately 80 | } 81 | // --- Early Exit for Debug Print API URL --- END --- 82 | 83 | // --- Display Effective Config & Confirm --- START --- 84 | // Skip display/confirmation if global --yes flag is provided 85 | if !viper.GetBool("skipconfirmation") { 86 | log.Info("--- Review Effective Configuration (Images Command) ---") 87 | 88 | // 1. Global Settings (Relevant to Images) 89 | globalSettings := map[string]interface{}{ 90 | "SavePath": viper.GetString("savepath"), 91 | "OutputDir": viper.GetString("images.output_dir"), // Display explicit output dir 92 | "ApiKeySet": viper.GetString("apikey") != "", // Show if API key is present 93 | "ApiClientTimeoutSec": viper.GetInt("apiclienttimeoutsec"), 94 | "ApiDelayMs": viper.GetInt("apidelayms"), 95 | "LogApiRequests": viper.GetBool("logapirequests"), 96 | "Concurrency": viper.GetInt("images.concurrency"), // Show image-specific concurrency 97 | } 98 | globalJSON, _ := json.MarshalIndent(globalSettings, " ", " ") 99 | fmt.Println(" --- Global Settings (Relevant to Images) ---") 100 | fmt.Println(" " + strings.ReplaceAll(string(globalJSON), "\n", "\n ")) 101 | 102 | // 2. 
Image API Parameters 103 | imageAPIParams := map[string]interface{}{ 104 | "ModelID": viper.GetInt("images.modelId"), 105 | "ModelVersionID": viper.GetInt("images.modelVersionId"), 106 | "PostID": viper.GetInt("images.postId"), 107 | "Username": viper.GetString("images.username"), 108 | "Limit": viper.GetInt("images.limit"), 109 | "Period": viper.GetString("images.period"), 110 | "Sort": viper.GetString("images.sort"), 111 | "NSFW": viper.GetString("images.nsfw"), 112 | "MaxPages": viper.GetInt("images.max_pages"), 113 | "SaveMetadata": viper.GetBool("images.metadata"), 114 | } 115 | apiParamsJSON, _ := json.MarshalIndent(imageAPIParams, " ", " ") 116 | fmt.Println("\n --- Image API Parameters ---") 117 | fmt.Println(" " + strings.ReplaceAll(string(apiParamsJSON), "\n", "\n ")) 118 | 119 | // Confirmation Prompt 120 | reader := bufio.NewReader(os.Stdin) 121 | fmt.Print("\nProceed with these settings? (y/N): ") 122 | input, _ := reader.ReadString('\n') 123 | input = strings.ToLower(strings.TrimSpace(input)) 124 | 125 | if input != "y" { 126 | log.Info("Operation cancelled by user.") 127 | os.Exit(0) 128 | } 129 | log.Info("Configuration confirmed.") 130 | } else { 131 | log.Info("Skipping configuration review due to --yes flag or config setting.") 132 | } 133 | // --- Display Effective Config & Confirm --- END --- 134 | 135 | // Add log to confirm concurrency level 136 | log.Infof("Using image download concurrency level: %d", numWorkers) 137 | 138 | // Default output dir if not provided 139 | if targetDir == "" { 140 | if globalConfig.SavePath == "" { 141 | log.Fatal("Required configuration 'SavePath' is not set and --output-dir flag was not provided.") 142 | } 143 | targetDir = filepath.Join(globalConfig.SavePath, "images") 144 | log.Infof("Output directory not specified, using default: %s", targetDir) 145 | } 146 | 147 | // Validate flags 148 | if modelID == 0 && modelVersionID == 0 && username == "" { 149 | log.Fatal("At least one of --model-id, 
--model-version-id, or --username must be provided") 150 | } 151 | if modelVersionID != 0 { 152 | log.Infof("Filtering images by Model Version ID: %d (overrides --model-id)", modelVersionID) 153 | modelID = 0 154 | } 155 | 156 | // --- API Client Setup (standard http client) --- 157 | if globalHttpTransport == nil { 158 | log.Warn("Global HTTP transport not initialized, using default.") 159 | globalHttpTransport = http.DefaultTransport 160 | } 161 | apiClient := &http.Client{ 162 | Transport: globalHttpTransport, 163 | Timeout: time.Duration(globalConfig.ApiClientTimeoutSec) * time.Second, 164 | } 165 | 166 | // --- Fetch Image List --- 167 | log.Info("Fetching image list from Civitai API...") 168 | 169 | var allImages []models.ImageApiItem 170 | baseURL := "https://civitai.com/api/v1/images" 171 | params := url.Values{} 172 | userTotalLimit := viper.GetInt("images.limit") // User's intended total limit (0 = unlimited) 173 | 174 | if modelVersionID != 0 { 175 | params.Set("modelVersionId", strconv.Itoa(modelVersionID)) 176 | } else if modelID != 0 { 177 | params.Set("modelId", strconv.Itoa(modelID)) 178 | } else if username != "" { 179 | params.Set("username", username) 180 | } else if postID != 0 { 181 | params.Set("postId", strconv.Itoa(postID)) 182 | } 183 | 184 | // Use API default/max limit per page (e.g., 100 or 200) for efficiency. 185 | // Do NOT send the user's total limit here. 
186 | params.Set("limit", "100") // Request a reasonable number per page 187 | 188 | // These parameters are still valid API parameters to send 189 | if period != "" { 190 | params.Set("period", period) 191 | } 192 | if sort != "" { 193 | params.Set("sort", sort) 194 | } 195 | if nsfw != "" { 196 | params.Set("nsfw", nsfw) 197 | } 198 | 199 | pageCount := 0 200 | var nextCursor string 201 | var loopErr error 202 | 203 | log.Info("--- Starting Image Fetching ---") 204 | 205 | for { 206 | pageCount++ 207 | if maxPages > 0 && pageCount > maxPages { 208 | log.Infof("Reached max pages limit (%d). Stopping.", maxPages) 209 | break 210 | } 211 | 212 | currentParams := params 213 | if nextCursor != "" { 214 | currentParams.Set("cursor", nextCursor) 215 | } 216 | requestURL := baseURL + "?" + currentParams.Encode() 217 | 218 | log.Debugf("Requesting Image URL (Page %d inferred, Cursor: %s): %s", pageCount, nextCursor, requestURL) 219 | 220 | // --- Check for debug flag --- NEW 221 | if printUrl, _ := cmd.Flags().GetBool("debug-print-api-url"); printUrl { 222 | fmt.Println(requestURL) // Print only the URL to stdout 223 | os.Exit(0) // Exit immediately 224 | } 225 | // --- End check for debug flag --- NEW 226 | 227 | req, err := http.NewRequest("GET", requestURL, nil) 228 | if err != nil { 229 | loopErr = fmt.Errorf("failed to create request for page %d: %w", pageCount, err) 230 | break 231 | } 232 | if globalConfig.ApiKey != "" { 233 | req.Header.Add("Authorization", "Bearer "+globalConfig.ApiKey) 234 | } 235 | 236 | resp, err := apiClient.Do(req) 237 | if err != nil { 238 | if urlErr, ok := err.(*url.Error); ok && urlErr.Timeout() { 239 | log.WithError(err).Warnf("Timeout fetching image metadata page %d. 
Retrying after delay...", pageCount) 240 | time.Sleep(5 * time.Second) 241 | continue 242 | } 243 | loopErr = fmt.Errorf("failed to fetch image metadata page %d: %w", pageCount, err) 244 | break 245 | } 246 | 247 | bodyBytes, readErr := io.ReadAll(resp.Body) 248 | if closeErr := resp.Body.Close(); closeErr != nil { 249 | log.WithError(closeErr).Warn("Error closing image API response body") 250 | } 251 | 252 | if readErr != nil { 253 | loopErr = fmt.Errorf("failed to read response body (Page %d): %w", pageCount, readErr) 254 | break 255 | } 256 | 257 | if resp.StatusCode != http.StatusOK { 258 | errMsg := fmt.Sprintf("Image API request failed (Page %d inferred) with status %s", pageCount, resp.Status) 259 | if len(bodyBytes) > 0 { 260 | maxLen := 200 261 | bodyStr := string(bodyBytes) 262 | if len(bodyStr) > maxLen { 263 | bodyStr = bodyStr[:maxLen] + "..." 264 | } 265 | errMsg += fmt.Sprintf(". Response: %s", bodyStr) 266 | } 267 | log.Error(errMsg) 268 | if resp.StatusCode == http.StatusTooManyRequests { 269 | log.Warn("Rate limited. Applying longer delay...") 270 | delay := time.Duration(globalConfig.ApiDelayMs)*time.Millisecond*5 + 5*time.Second 271 | time.Sleep(delay) 272 | continue 273 | } 274 | loopErr = errors.New(errMsg) 275 | break 276 | } 277 | 278 | var response models.ImageApiResponse 279 | if err := json.Unmarshal(bodyBytes, &response); err != nil { 280 | loopErr = fmt.Errorf("failed to decode image API response (Page %d): %w", pageCount, err) 281 | log.WithError(err).Errorf("Response body sample: %s", string(bodyBytes[:min(len(bodyBytes), 200)])) 282 | break 283 | } 284 | 285 | if len(response.Items) == 0 { 286 | log.Info("Received empty items list from API. Assuming end of results.") 287 | break 288 | } 289 | 290 | log.Infof("Received %d images from API page %d. Total collected: %d", len(response.Items), pageCount, len(allImages)) 291 | allImages = append(allImages, response.Items...) 
292 | 293 | // --- Check Total Limit --- START --- 294 | if userTotalLimit > 0 && len(allImages) >= userTotalLimit { 295 | log.Infof("Reached total image limit (%d). Stopping image fetching.", userTotalLimit) 296 | allImages = allImages[:userTotalLimit] // Truncate to exact limit 297 | break // Stop fetching more pages 298 | } 299 | // --- Check Total Limit --- END --- 300 | 301 | nextCursor = response.Metadata.NextCursor 302 | if nextCursor == "" { 303 | log.Info("No next cursor found. Finished fetching.") 304 | break 305 | } 306 | 307 | log.Debugf("Next cursor found: %s", nextCursor) 308 | 309 | if globalConfig.ApiDelayMs > 0 { 310 | log.Debugf("Applying API delay: %d ms", globalConfig.ApiDelayMs) 311 | time.Sleep(time.Duration(globalConfig.ApiDelayMs) * time.Millisecond) 312 | } 313 | } 314 | 315 | if loopErr != nil { 316 | log.WithError(loopErr).Error("Image fetching stopped due to error.") 317 | if len(allImages) == 0 { 318 | log.Fatal("Exiting as no images were fetched before the error.") 319 | } 320 | log.Warnf("Proceeding with %d images fetched before the error.", len(allImages)) 321 | } else { 322 | log.Info("--- Finished Image Fetching ---") 323 | } 324 | 325 | if len(allImages) == 0 { 326 | log.Info("No images found matching the criteria after fetching.") 327 | return 328 | } 329 | log.Infof("Found %d total images to potentially download.", len(allImages)) 330 | 331 | // --- Initialize Bleve Index --- START --- 332 | // Use targetDir as base for index path, ensuring it's consistent 333 | indexPath := globalConfig.BleveIndexPath 334 | if indexPath == "" { 335 | indexPath = filepath.Join(targetDir, "civitai_images.bleve") // Default if config is empty 336 | log.Warnf("BleveIndexPath not set in config, defaulting index path for image downloads to: %s", indexPath) 337 | } else { 338 | // If a shared index path is provided, images might go into the same index 339 | // Or we could append a sub-directory like "images"? For now, use the path directly. 
340 | // Example: If BleveIndexPath = /path/to/index, index will be at /path/to/index 341 | } 342 | log.Infof("Opening/Creating Bleve index at: %s", indexPath) 343 | bleveIndex, err := index.OpenOrCreateIndex(indexPath) 344 | if err != nil { 345 | log.Fatalf("Failed to open or create Bleve index: %v", err) 346 | } 347 | defer func() { 348 | log.Info("Closing Bleve index.") 349 | if err := bleveIndex.Close(); err != nil { 350 | log.Errorf("Error closing Bleve index: %v", err) 351 | } 352 | }() 353 | log.Info("Bleve index opened successfully.") 354 | // --- Initialize Bleve Index --- END --- 355 | 356 | // --- Downloader Setup --- 357 | downloadClient := &http.Client{ 358 | Transport: globalHttpTransport, 359 | Timeout: 0, 360 | } 361 | dl := downloader.NewDownloader(downloadClient, globalConfig.ApiKey) 362 | 363 | // --- Target Directory --- 364 | finalBaseTargetDir := targetDir 365 | log.Infof("Ensuring base target directory exists: %s", finalBaseTargetDir) 366 | if err := os.MkdirAll(finalBaseTargetDir, 0750); err != nil { 367 | log.WithError(err).Fatalf("Failed to create base target directory: %s", finalBaseTargetDir) 368 | } 369 | 370 | // --- Download Workers --- 371 | var wg sync.WaitGroup 372 | jobs := make(chan imageJob, len(allImages)) 373 | writer := uilive.New() 374 | writer.Start() 375 | 376 | var successCount int64 377 | var failureCount int64 378 | 379 | log.Infof("Starting %d image download workers...", numWorkers) 380 | for w := 1; w <= numWorkers; w++ { 381 | wg.Add(1) 382 | go imageDownloadWorker(w, jobs, dl, &wg, writer, &successCount, &failureCount, saveMeta, finalBaseTargetDir, bleveIndex) 383 | } 384 | 385 | // --- Queue Jobs --- 386 | log.Info("Queueing image download jobs...") 387 | queuedCount := 0 388 | for _, image := range allImages { 389 | if image.URL == "" { 390 | log.Warnf("Image ID %d has no URL, skipping.", image.ID) 391 | continue 392 | } 393 | 394 | job := imageJob{ 395 | SourceURL: image.URL, 396 | ImageID: image.ID, 397 | 
Metadata: image, 398 | } 399 | jobs <- job 400 | queuedCount++ 401 | } 402 | close(jobs) 403 | log.Infof("Queued %d image jobs.", queuedCount) 404 | 405 | // --- Wait for Completion --- 406 | log.Info("Waiting for image download workers to finish...") 407 | wg.Wait() 408 | writer.Stop() 409 | 410 | // --- Final Report --- 411 | finalSuccessCount := atomic.LoadInt64(&successCount) 412 | finalFailureCount := atomic.LoadInt64(&failureCount) 413 | 414 | log.Infof("Image download process completed.") 415 | log.Infof("Successfully downloaded: %d images", finalSuccessCount) 416 | log.Infof("Failed to download: %d images", finalFailureCount) 417 | 418 | if finalFailureCount > 0 { 419 | log.Warn("Some image downloads failed. Check logs for details.") 420 | } 421 | 422 | fmt.Println("----- Download Summary -----") 423 | fmt.Printf(" Target Base Directory: %s\n", finalBaseTargetDir) 424 | fmt.Printf(" Total Images Found API: %d\n", len(allImages)) 425 | fmt.Printf(" Images Queued: %d\n", queuedCount) 426 | fmt.Printf(" Successfully Downloaded: %d\n", finalSuccessCount) 427 | fmt.Printf(" Failed Downloads: %d\n", finalFailureCount) 428 | fmt.Printf(" Metadata Saved: %t\n", saveMeta) 429 | fmt.Println("--------------------------") 430 | } 431 | -------------------------------------------------------------------------------- /cmd/civitai-downloader/cmd/cmd_images_setup.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "github.com/spf13/viper" 5 | ) 6 | 7 | func init() { 8 | // imagesCmd is defined in images.go 9 | rootCmd.AddCommand(imagesCmd) 10 | 11 | // --- Flags for Image Command --- 12 | imagesCmd.Flags().Int("limit", 100, "Max images per page (1-200).") 13 | imagesCmd.Flags().Int("post-id", 0, "Filter by Post ID.") 14 | imagesCmd.Flags().Int("model-id", 0, "Filter by Model ID.") 15 | imagesCmd.Flags().Int("model-version-id", 0, "Filter by Model Version ID (overrides model-id and post-id if set).") 16 | 
imagesCmd.Flags().StringP("username", "u", "", "Filter by username.") 17 | // Use string for nsfw flag to handle both boolean and enum values easily 18 | imagesCmd.Flags().String("nsfw", "", "Filter by NSFW level (None, Soft, Mature, X) or boolean (true/false). Empty means all.") 19 | imagesCmd.Flags().StringP("sort", "s", "Newest", "Sort order (Most Reactions, Most Comments, Newest).") 20 | imagesCmd.Flags().StringP("period", "p", "AllTime", "Time period for sorting (AllTime, Year, Month, Week, Day).") 21 | imagesCmd.Flags().Int("page", 1, "Starting page number (API defaults to 1).") // API uses page-based for images 22 | imagesCmd.Flags().Int("max-pages", 0, "Maximum number of API pages to fetch (0 for no limit)") 23 | imagesCmd.Flags().StringP("output-dir", "o", "", "Directory to save images (default: [SavePath]/images).") 24 | // Define a local variable for image command's concurrency flag 25 | var imageConcurrency int 26 | imagesCmd.Flags().IntVarP(&imageConcurrency, "concurrency", "c", 4, "Number of concurrent image downloads") 27 | // Add the save-metadata flag 28 | imagesCmd.Flags().Bool("metadata", false, "Save a .json metadata file alongside each downloaded image.") 29 | 30 | // Hidden flag for testing API URL generation 31 | imagesCmd.Flags().Bool("debug-print-api-url", false, "Print the constructed API URL for image fetching and exit") 32 | imagesCmd.Flags().MarkHidden("debug-print-api-url") // Hide from help output 33 | 34 | // Bind flags to Viper (optional) 35 | viper.BindPFlag("images.limit", imagesCmd.Flags().Lookup("limit")) 36 | viper.BindPFlag("images.postId", imagesCmd.Flags().Lookup("post-id")) 37 | viper.BindPFlag("images.modelId", imagesCmd.Flags().Lookup("model-id")) 38 | viper.BindPFlag("images.modelVersionId", imagesCmd.Flags().Lookup("model-version-id")) 39 | viper.BindPFlag("images.username", imagesCmd.Flags().Lookup("username")) 40 | viper.BindPFlag("images.nsfw", imagesCmd.Flags().Lookup("nsfw")) 41 | viper.BindPFlag("images.sort", 
imagesCmd.Flags().Lookup("sort")) 42 | viper.BindPFlag("images.period", imagesCmd.Flags().Lookup("period")) 43 | viper.BindPFlag("images.page", imagesCmd.Flags().Lookup("page")) 44 | viper.BindPFlag("images.max_pages", imagesCmd.Flags().Lookup("max-pages")) 45 | viper.BindPFlag("images.output_dir", imagesCmd.Flags().Lookup("output-dir")) 46 | viper.BindPFlag("images.concurrency", imagesCmd.Flags().Lookup("concurrency")) 47 | // Bind the new flag 48 | viper.BindPFlag("images.metadata", imagesCmd.Flags().Lookup("metadata")) 49 | } 50 | -------------------------------------------------------------------------------- /cmd/civitai-downloader/cmd/cmd_images_worker.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "net/url" 7 | "os" 8 | "path/filepath" 9 | "strings" 10 | "sync" 11 | "sync/atomic" 12 | "time" 13 | 14 | index "go-civitai-download/index" 15 | "go-civitai-download/internal/downloader" 16 | "go-civitai-download/internal/helpers" 17 | "go-civitai-download/internal/models" 18 | 19 | "github.com/blevesearch/bleve/v2" 20 | "github.com/gosuri/uilive" 21 | log "github.com/sirupsen/logrus" 22 | ) 23 | 24 | // Represents an image download task 25 | type imageJob struct { 26 | SourceURL string 27 | ImageID int 28 | Metadata models.ImageApiItem 29 | } 30 | 31 | // --- Helper to save metadata --- START --- 32 | func saveMetadataJSON(id int, job imageJob, targetPath string, writer *uilive.Writer) { 33 | baseFilename := filepath.Base(targetPath) 34 | metadataPath := strings.TrimSuffix(targetPath, filepath.Ext(targetPath)) + ".json" 35 | jsonData, jsonErr := json.MarshalIndent(job.Metadata, "", " ") 36 | if jsonErr != nil { 37 | log.WithError(jsonErr).Warnf("Worker %d: Failed to marshal image metadata for %s", id, baseFilename) 38 | fmt.Fprintf(writer.Newline(), "Worker %d: Error marshalling metadata for %s\n", id, baseFilename) 39 | } else { 40 | if writeErr := 
os.WriteFile(metadataPath, jsonData, 0600); writeErr != nil { 41 | log.WithError(writeErr).Warnf("Worker %d: Failed to write image metadata file %s", id, metadataPath) 42 | fmt.Fprintf(writer.Newline(), "Worker %d: Error writing metadata file for %s\n", id, baseFilename) 43 | } else { 44 | log.Infof("Worker %d: Saved image metadata to %s", id, metadataPath) // Info level for explicit save 45 | fmt.Fprintf(writer.Newline(), "Worker %d: Saved metadata for %s\n", id, baseFilename) 46 | } 47 | } 48 | } 49 | 50 | // --- Helper to save metadata --- END --- 51 | 52 | // imageDownloadWorker handles the download of a single image. 53 | // Added baseOutputDir and bleveIndex parameters. 54 | func imageDownloadWorker(id int, jobs <-chan imageJob, downloader *downloader.Downloader, wg *sync.WaitGroup, writer *uilive.Writer, successCounter *int64, failureCounter *int64, saveMeta bool, baseOutputDir string, bleveIndex bleve.Index) { 55 | defer wg.Done() 56 | log.Debugf("Image Worker %d starting", id) 57 | for job := range jobs { 58 | 59 | // --- Construct Target Path --- START --- 60 | // Create subdirectory based on username 61 | authorSlug := helpers.ConvertToSlug(job.Metadata.Username) 62 | if authorSlug == "" { 63 | authorSlug = "unknown_author" // Fallback 64 | } 65 | // Add BaseModel subdirectory 66 | baseModelSlug := helpers.ConvertToSlug(job.Metadata.BaseModel) 67 | if baseModelSlug == "" { 68 | baseModelSlug = "unknown_base_model" 69 | } 70 | targetSubDir := filepath.Join(baseOutputDir, authorSlug, baseModelSlug) // Include baseModelSlug 71 | 72 | // Construct filename: {id}-{url_filename_base}.{ext} 73 | var filename string 74 | imgURLParsed, urlErr := url.Parse(job.SourceURL) // Need to import "net/url" 75 | if urlErr != nil { 76 | log.WithError(urlErr).Warnf("Worker %d: Could not parse image URL %s for image ID %d. 
Using generic filename.", id, job.SourceURL, job.ImageID) 77 | filename = fmt.Sprintf("%d.image", job.ImageID) // Fallback includes ID 78 | } else { 79 | base := filepath.Base(imgURLParsed.Path) 80 | ext := filepath.Ext(base) 81 | nameOnly := strings.TrimSuffix(base, ext) 82 | safeName := helpers.ConvertToSlug(nameOnly) 83 | if safeName == "" { 84 | safeName = "image" 85 | } 86 | if ext == "" { 87 | // Guess extension based on typical Civitai usage or headers if possible 88 | // For now, default to jpg 89 | ext = ".jpg" 90 | log.Debugf("Worker %d: Could not determine extension for %s (ID %d), defaulting to .jpg", id, base, job.ImageID) 91 | } 92 | filename = fmt.Sprintf("%d-%s%s", job.ImageID, safeName, ext) 93 | } 94 | 95 | // Ensure the target subdirectory exists 96 | if err := os.MkdirAll(targetSubDir, 0750); err != nil { 97 | log.WithError(err).Errorf("Worker %d: Failed to create target directory %s for image %d, skipping download.", id, targetSubDir, job.ImageID) 98 | fmt.Fprintf(writer.Newline(), "Worker %d: Error creating dir for %s, skipping\n", id, filename) 99 | atomic.AddInt64(failureCounter, 1) // Count as failure 100 | continue 101 | } 102 | 103 | targetPath := filepath.Join(targetSubDir, filename) 104 | // --- Construct Target Path --- END --- 105 | 106 | baseFilename := filepath.Base(targetPath) // Use calculated base filename 107 | fmt.Fprintf(writer.Newline(), "Worker %d: Preparing %s (ID: %d)...\n", id, baseFilename, job.ImageID) 108 | 109 | // Check if image file already exists 110 | if _, err := os.Stat(targetPath); err == nil { 111 | log.Infof("Worker %d: Image file %s (ID: %d) already exists.", id, baseFilename, job.ImageID) 112 | // If file exists, check if metadata needs saving 113 | if saveMeta { 114 | metadataPath := strings.TrimSuffix(targetPath, filepath.Ext(targetPath)) + ".json" 115 | if _, metaErr := os.Stat(metadataPath); os.IsNotExist(metaErr) { 116 | log.Infof("Worker %d: Image exists, but metadata %s is missing. 
Saving metadata.", id, filepath.Base(metadataPath)) 117 | saveMetadataJSON(id, job, targetPath, writer) // Call helper to save 118 | } else if metaErr == nil { 119 | log.Debugf("Worker %d: Metadata file %s also exists.", id, filepath.Base(metadataPath)) 120 | } else { 121 | // Log error if stating metadata file failed for other reasons 122 | log.WithError(metaErr).Warnf("Worker %d: Could not check status of metadata file %s", id, metadataPath) 123 | } 124 | } 125 | // Skip the download 126 | fmt.Fprintf(writer.Newline(), "Worker %d: Skipping %s (Exists)\n", id, baseFilename) 127 | continue // Skip download steps 128 | } 129 | 130 | // --- Download section (only runs if file doesn't exist) --- 131 | fmt.Fprintf(writer.Newline(), "Worker %d: Downloading %s (ID: %d)...\n", id, baseFilename, job.ImageID) 132 | startTime := time.Now() 133 | 134 | // Use DownloadFile with the constructed targetPath 135 | _, dlErr := downloader.DownloadFile(targetPath, job.SourceURL, models.Hashes{}, 0) 136 | 137 | if dlErr != nil { 138 | log.WithError(dlErr).Errorf("Worker %d: Failed to download image %s from %s", id, targetPath, job.SourceURL) 139 | fmt.Fprintf(writer.Newline(), "Worker %d: Error downloading %s: %v\n", id, baseFilename, dlErr) 140 | // Attempt to remove partial file 141 | if removeErr := os.Remove(targetPath); removeErr != nil && !os.IsNotExist(removeErr) { 142 | log.WithError(removeErr).Warnf("Worker %d: Failed to remove partial image %s after error", id, targetPath) 143 | } 144 | atomic.AddInt64(failureCounter, 1) 145 | } else { 146 | duration := time.Since(startTime) 147 | log.Infof("Worker %d: Successfully downloaded %s in %v", id, targetPath, duration) 148 | fmt.Fprintf(writer.Newline(), "Worker %d: Success downloading %s (%v)\n", id, baseFilename, duration.Round(time.Millisecond)) 149 | // Increment success counter 150 | atomic.AddInt64(successCounter, 1) 151 | 152 | // --- Save Metadata if Enabled (after successful download) --- 153 | if saveMeta { 154 | 
saveMetadataJSON(id, job, targetPath, writer) // Call helper to save 155 | } 156 | // --- End Save Metadata --- 157 | 158 | // --- Index Item with Bleve --- START --- 159 | if bleveIndex != nil { 160 | // Extract data from meta with type assertions 161 | var tags []string 162 | var prompt string 163 | var modelName string // Field not directly available, might be in meta? 164 | 165 | if metaMap, ok := job.Metadata.Meta.(map[string]interface{}); ok && metaMap != nil { 166 | if p, ok := metaMap["prompt"].(string); ok { 167 | prompt = p 168 | } 169 | if t, ok := metaMap["tags"].([]interface{}); ok { 170 | for _, tagInterface := range t { 171 | if tagStr, ok := tagInterface.(string); ok { 172 | tags = append(tags, tagStr) 173 | } 174 | } 175 | } 176 | // Check for model name in meta (unlikely standard field) 177 | if mn, ok := metaMap["modelName"].(string); ok { 178 | modelName = mn 179 | } else if mn, ok := metaMap["model"].(string); ok { // Common alternative key 180 | modelName = mn 181 | } 182 | } 183 | 184 | itemToIndex := index.Item{ 185 | ID: fmt.Sprintf("img_%d", job.ImageID), 186 | Type: "image", 187 | Name: baseFilename, // Use the calculated filename 188 | Description: prompt, // Use extracted prompt as description 189 | FilePath: targetPath, 190 | ModelName: modelName, // Use extracted model name if found 191 | BaseModel: job.Metadata.BaseModel, 192 | CreatorName: job.Metadata.Username, 193 | Tags: tags, // Use extracted tags 194 | Prompt: prompt, 195 | NsfwLevel: job.Metadata.NsfwLevel, 196 | } 197 | if indexErr := index.IndexItem(bleveIndex, itemToIndex); indexErr != nil { 198 | log.WithError(indexErr).Errorf("Worker %d: Failed to index downloaded image %s (ID: %s)", id, targetPath, itemToIndex.ID) 199 | } else { 200 | log.Debugf("Worker %d: Successfully indexed image %s (ID: %s)", id, targetPath, itemToIndex.ID) 201 | } 202 | } 203 | // --- Index Item with Bleve --- END --- 204 | } 205 | } 206 | log.Debugf("Image Worker %d finished", id) 207 | 
fmt.Fprintf(writer.Newline(), "Worker %d: Finished image job processing.\n", id) 208 | } 209 | -------------------------------------------------------------------------------- /cmd/civitai-downloader/cmd/cmd_search_images.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "path/filepath" 5 | 6 | log "github.com/sirupsen/logrus" 7 | "github.com/spf13/cobra" 8 | // No Bleve/index import needed here, logic is in runSearchLogic 9 | ) 10 | 11 | // searchImagesCmd represents the command to search the images index 12 | var searchImagesCmd = &cobra.Command{ 13 | Use: "images", 14 | Short: "Search the Bleve index for downloaded images", 15 | Long: `Performs a search against the Bleve index for downloaded images. 16 | This typically searches the index located at '[SavePath]/images/civitai_images.bleve' 17 | unless 'BleveIndexPath' is set in the configuration (in which case it searches that path). 18 | 19 | Supports Bleve's query string syntax. 
20 | 21 | The following fields (using their lowercase JSON tag names) are typically relevant for images: 22 | - id (string): Unique ID (e.g., img_67890) 23 | - type (string): Should be "image" 24 | - name (string): Image file name 25 | - filePath (string): Full path to the downloaded image file 26 | - directoryPath (string): Directory containing the file (often the same as filePath for images) 27 | - modelName (string): Name of the parent model (if image is from a model) 28 | - versionName (string): Name of the model version (if image is from a version) 29 | - baseModel (string): Base model associated with the image 30 | - creatorName (string): Username of the image creator 31 | - tags ([]string): Associated image tags (often from generation data) 32 | - prompt (string): Image generation prompt 33 | - nsfwLevel (string): Image NSFW level (e.g., "None", "Soft", "Mature", "X") 34 | 35 | Examples: 36 | civitai-downloader search images -q "cat" 37 | civitai-downloader search images -q "+creatorName:some_creator +prompt:landscape" 38 | civitai-downloader search images -q "+tags:photorealistic"`, 39 | Run: runSearchImages, 40 | } 41 | 42 | func init() { 43 | searchCmd.AddCommand(searchImagesCmd) // Add to parent search command 44 | 45 | // Share the searchQuery variable with the models command 46 | searchImagesCmd.Flags().StringVarP(&searchQuery, "query", "q", "", "Search query (uses Bleve query string syntax)") 47 | _ = searchImagesCmd.MarkFlagRequired("query") 48 | } 49 | 50 | // runSearchImages determines the image index path and calls the shared search logic. 
51 | func runSearchImages(cmd *cobra.Command, args []string) { 52 | initLogging() // Initialize logging 53 | log.Info("Starting Search Images Command") 54 | 55 | // Determine the index path for images 56 | indexPath := globalConfig.BleveIndexPath // Use path from config if set 57 | if indexPath == "" { 58 | // Default path determination for images 59 | if globalConfig.SavePath == "" { 60 | log.Fatal("Cannot determine default Bleve index path: SavePath and BleveIndexPath are not set in config.") 61 | } 62 | // Images default index is inside the default images directory 63 | defaultImageDir := filepath.Join(globalConfig.SavePath, "images") 64 | indexPath = filepath.Join(defaultImageDir, "civitai_images.bleve") 65 | log.Infof("BleveIndexPath not set, using default image index: %s", indexPath) 66 | } 67 | 68 | // Call the shared search logic 69 | runSearchLogic(indexPath, searchQuery) 70 | } 71 | -------------------------------------------------------------------------------- /cmd/civitai-downloader/cmd/cmd_search_models.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "path/filepath" 5 | 6 | log "github.com/sirupsen/logrus" 7 | "github.com/spf13/cobra" 8 | // No Bleve/index import needed here, logic is in runSearchLogic 9 | ) 10 | 11 | // searchModelsCmd represents the command to search the models index 12 | var searchModelsCmd = &cobra.Command{ 13 | Use: "models", 14 | Short: "Search the Bleve index for downloaded models", 15 | Long: `Performs a search against the Bleve index for downloaded models. 16 | This typically searches the index located at '[SavePath]/civitai.bleve' unless 17 | 'BleveIndexPath' is set in the configuration. 18 | 19 | Supports Bleve's query string syntax. 
20 | 21 | The following fields (using their lowercase JSON tag names) are typically relevant for models: 22 | - id (string): Unique ID (e.g., v_12345) 23 | - type (string): Should be "model_file" 24 | - name (string): File name of the model 25 | - description (string): Model or version description text 26 | - filePath (string): Full path to the downloaded file 27 | - directoryPath (string): Directory containing the file 28 | - baseModelPath (string): Path up to the base model slug 29 | - modelPath (string): Path up to the model name slug 30 | - modelName (string): Name of the parent model 31 | - versionName (string): Name of the model version 32 | - baseModel (string): Base model name (e.g., "SDXL 1.0") 33 | - creatorName (string): Username of the creator 34 | - tags ([]string): Associated model or version tags 35 | - publishedAt (time): Version publication timestamp (e.g., +publishedAt:>[2024-01-01]) 36 | - versionDownloadCount (numeric): Version download count 37 | - versionRating (numeric): Version rating 38 | - versionRatingCount (numeric): Version rating count 39 | - fileSizeKB (numeric): File size in KB 40 | - fileFormat (string): File format (e.g., "SafeTensor") 41 | - filePrecision (string): File precision (e.g., "fp16") 42 | - fileSizeType (string): File size type (e.g., "pruned") 43 | - torrentPath (string): Path to the downloaded .torrent file (if any) 44 | - magnetLink (string): Magnet link for the torrent (if any) 45 | 46 | Examples: 47 | civitai-downloader search models -q "lora" 48 | civitai-downloader search models -q "+modelName:MyModel +baseModel:sdxl*" 49 | civitai-downloader search models -q "+tags:style"`, 50 | Run: runSearchModels, 51 | } 52 | 53 | func init() { 54 | searchCmd.AddCommand(searchModelsCmd) // Add to parent search command 55 | 56 | // Use PersistentFlags if you want flags to be available to potential sub-subcommands 57 | // Use Flags for flags specific to this command 58 | searchModelsCmd.Flags().StringVarP(&searchQuery, "query", 
"q", "", "Search query (uses Bleve query string syntax)") 59 | _ = searchModelsCmd.MarkFlagRequired("query") 60 | } 61 | 62 | // runSearchModels determines the model index path and calls the shared search logic. 63 | func runSearchModels(cmd *cobra.Command, args []string) { 64 | initLogging() // Initialize logging 65 | log.Info("Starting Search Models Command") 66 | 67 | // Determine the index path for models 68 | indexPath := globalConfig.BleveIndexPath // Use path from config if set 69 | if indexPath == "" { 70 | if globalConfig.SavePath == "" { 71 | log.Fatal("Cannot determine default Bleve index path: SavePath and BleveIndexPath are not set in config.") 72 | } 73 | indexPath = filepath.Join(globalConfig.SavePath, "civitai.bleve") 74 | log.Infof("BleveIndexPath not set, using default model index: %s", indexPath) 75 | } 76 | 77 | // Call the shared search logic 78 | runSearchLogic(indexPath, searchQuery) 79 | } 80 | -------------------------------------------------------------------------------- /cmd/civitai-downloader/cmd/images.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "github.com/spf13/cobra" 5 | ) 6 | 7 | // imagesCmd represents the images command 8 | var imagesCmd = &cobra.Command{ 9 | Use: "images", 10 | Short: "Download images based on various criteria (model, user, etc.)", 11 | Long: `Downloads images from Civitai based on filters like model ID, model version ID, 12 | or username. Allows specifying limits, sorting, and NSFW preferences. 
13 | 14 | Examples: 15 | # Download latest 20 images for model ID 123 16 | civitai-downloader images --model-id 123 --limit 20 17 | 18 | # Download all SFW images for model version ID 456, sorted by most reactions 19 | civitai-downloader images --model-version-id 456 --sort "Most Reactions" --nsfw=None 20 | 21 | # Download the 50 most popular images of all time from user 'exampleUser' 22 | civitai-downloader images --username exampleUser --limit 50 --period AllTime --sort MostPopular`, 23 | Run: runImages, 24 | } 25 | -------------------------------------------------------------------------------- /cmd/civitai-downloader/cmd/root.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "os" 7 | "path/filepath" 8 | "strings" 9 | 10 | log "github.com/sirupsen/logrus" // Import logrus for config loading message 11 | "github.com/spf13/cobra" 12 | "github.com/spf13/viper" 13 | 14 | "go-civitai-download/internal/api" 15 | "go-civitai-download/internal/models" 16 | ) 17 | 18 | // cfgFile holds the path to the config file specified by the user 19 | var cfgFile string 20 | 21 | // logApiFlag holds the value of the --log-api flag 22 | var logApiFlag bool 23 | 24 | // savePathFlag holds the value of the --save-path flag 25 | var savePathFlag string 26 | 27 | // apiDelayFlag holds the value of the --api-delay flag 28 | var apiDelayFlag int 29 | 30 | // apiTimeoutFlag holds the value of the --api-timeout flag 31 | var apiTimeoutFlag int 32 | 33 | // logLevel and logFormat are declared elsewhere (e.g., cmd_download_setup.go) 34 | // var logLevel string 35 | // var logFormat string 36 | 37 | // globalConfig holds the loaded configuration 38 | var globalConfig models.Config 39 | 40 | // globalHttpTransport holds the globally configured HTTP transport (base or logging-wrapped) 41 | var globalHttpTransport http.RoundTripper 42 | 43 | // rootCmd represents the base command when called without 
any subcommands 44 | var rootCmd = &cobra.Command{ 45 | Use: "civitai-downloader", 46 | Short: "A tool to download models from Civitai", 47 | Long: `Civitai Downloader allows you to fetch and manage models 48 | from Civitai.com based on specified criteria.`, 49 | PersistentPreRunE: loadGlobalConfig, // Load config before any command runs 50 | // Uncomment the following line if your bare application 51 | // has an action associated with it: 52 | // Run: func(cmd *cobra.Command, args []string) { }, 53 | } 54 | 55 | // Execute adds all child commands to the root command and sets flags appropriately. 56 | // This is called by main.main(). It only needs to happen once to the rootCmd. 57 | func Execute() { 58 | // cobra.OnInitialize(initConfig) // We use PersistentPreRunE now 59 | err := rootCmd.Execute() 60 | if err != nil { 61 | fmt.Fprintf(os.Stderr, "Error executing command: %v\n", err) 62 | os.Exit(1) 63 | } 64 | } 65 | 66 | func init() { 67 | // Add persistent flags that apply to all commands 68 | rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "config.toml", "Configuration file path") 69 | 70 | // Add persistent flags for logging 71 | rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "Logging level (trace, debug, info, warn, error, fatal, panic)") 72 | rootCmd.PersistentFlags().StringVar(&logFormat, "log-format", "text", "Logging format (text, json)") 73 | // NOTE: Viper binding for log level/format is not strictly necessary 74 | // as they are handled directly in initLogging() before Viper might be fully ready, 75 | // but we can add them for consistency if needed elsewhere. 
76 | // viper.BindPFlag("loglevel", rootCmd.PersistentFlags().Lookup("log-level")) 77 | // viper.BindPFlag("logformat", rootCmd.PersistentFlags().Lookup("log-format")) 78 | 79 | // Add persistent flag for API logging 80 | rootCmd.PersistentFlags().BoolVar(&logApiFlag, "log-api", false, "Log API requests/responses to api.log (overrides config)") 81 | viper.BindPFlag("logapirequests", rootCmd.PersistentFlags().Lookup("log-api")) 82 | 83 | // Add persistent flag for save path 84 | rootCmd.PersistentFlags().StringVar(&savePathFlag, "save-path", "", "Directory to save models (overrides config)") 85 | viper.BindPFlag("savepath", rootCmd.PersistentFlags().Lookup("save-path")) 86 | 87 | // Add persistent flag for API delay 88 | // Default value 0 or negative means "use config or viper default" 89 | rootCmd.PersistentFlags().IntVar(&apiDelayFlag, "api-delay", -1, "Delay between API calls in ms (overrides config, -1 uses config default)") 90 | viper.BindPFlag("apidelayms", rootCmd.PersistentFlags().Lookup("api-delay")) 91 | 92 | // Add persistent flag for API timeout 93 | // Default value 0 or negative means "use config or viper default" 94 | rootCmd.PersistentFlags().IntVar(&apiTimeoutFlag, "api-timeout", -1, "Timeout for API HTTP client in seconds (overrides config, -1 uses config default)") 95 | viper.BindPFlag("apiclienttimeoutsec", rootCmd.PersistentFlags().Lookup("api-timeout")) 96 | 97 | // Set Viper defaults (these are applied only if not set in config file or by flag) 98 | viper.SetDefault("apidelayms", 200) // Default polite delay 99 | viper.SetDefault("apiclienttimeoutsec", 60) // Default timeout 100 | 101 | // Bind persistent flags defined above 102 | _ = viper.BindPFlag("logapirequests", rootCmd.PersistentFlags().Lookup("log-api")) 103 | _ = viper.BindPFlag("savepath", rootCmd.PersistentFlags().Lookup("save-path")) 104 | _ = viper.BindPFlag("apidelayms", rootCmd.PersistentFlags().Lookup("api-delay")) 105 | _ = viper.BindPFlag("apiclienttimeoutsec", 
rootCmd.PersistentFlags().Lookup("api-timeout")) 106 | _ = viper.BindPFlag("bleveindexpath", rootCmd.PersistentFlags().Lookup("bleve-index-path")) 107 | 108 | // Cobra also supports local flags, which will only run 109 | // when this action is called directly. 110 | // rootCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle") 111 | } 112 | 113 | // loadGlobalConfig attempts to load the configuration and applies flag overrides. 114 | // It also sets up the global HTTP transport based on logging settings. 115 | func loadGlobalConfig(cmd *cobra.Command, args []string) error { 116 | // --- Configure Viper to read the config file --- 117 | if cfgFile != "" { 118 | // Use config file from the flag. 119 | viper.SetConfigFile(cfgFile) 120 | } else { 121 | // Find home directory. 122 | home, err := os.UserHomeDir() 123 | cobra.CheckErr(err) 124 | 125 | // Search config in home directory with name ".go-civitai-downloader" (without extension). 126 | viper.AddConfigPath(home) 127 | // Add current directory path 128 | viper.AddConfigPath(".") 129 | viper.SetConfigName("config") // Name of config file (without extension) 130 | viper.SetConfigType("toml") // REQUIRED if the config file does not have the extension in the name 131 | } 132 | 133 | viper.AutomaticEnv() // read in environment variables that match 134 | 135 | // Normalize keys (e.g., from config like BaseModels to BASMODELS) 136 | // Might help resolve precedence issues with bound flags 137 | viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_", "-", "_")) 138 | 139 | // If a config file is found, read it in. 140 | if err := viper.ReadInConfig(); err == nil { 141 | log.Infof("Using configuration file: %s", viper.ConfigFileUsed()) 142 | // Try merging again AFTER potential flag bindings? Seems redundant but maybe forces precedence. 
143 | // if mergeErr := viper.MergeInConfig(); mergeErr != nil { 144 | // log.WithError(mergeErr).Warnf("Error merging config file after read: %s", viper.ConfigFileUsed()) 145 | // } 146 | } else { 147 | // Handle errors reading the config file 148 | if _, ok := err.(viper.ConfigFileNotFoundError); ok { 149 | // Config file not found; ignore error if desired 150 | log.Warnf("Config file not found. Using defaults and flags.") 151 | } else { 152 | // Config file was found but another error was produced 153 | log.WithError(err).Warnf("Error reading config file: %s", viper.ConfigFileUsed()) 154 | // Don't make it fatal, let flags/defaults take over 155 | } 156 | } 157 | // --- Try merging config AFTER reading and AutomaticEnv --- 158 | if viper.ConfigFileUsed() != "" { // Only merge if a config file was actually used 159 | if err := viper.MergeInConfig(); err != nil { 160 | log.WithError(err).Warnf("Error explicitly merging config file: %s", viper.ConfigFileUsed()) 161 | } 162 | } 163 | // --- End Viper config file reading --- 164 | 165 | // --- Unmarshal directly from the global viper instance AFTER ReadInConfig/Merge --- 166 | if err := viper.Unmarshal(&globalConfig); err != nil { 167 | log.WithError(err).Warnf("Error unmarshalling config into globalConfig struct: %v", err) 168 | // Don't make it fatal, allow commands to proceed with defaults/flags if possible. 169 | } 170 | 171 | // --- REMOVED: Manual merge of loaded config values into Viper --- 172 | 173 | log.Debug("Config loaded (or attempted). 
Viper will manage value precedence.") 174 | 175 | baseTransport := http.DefaultTransport 176 | 177 | // Check if API logging is enabled using Viper 178 | globalHttpTransport = baseTransport // Default to base transport 179 | log.Debugf("Initial globalHttpTransport type: %T", globalHttpTransport) 180 | 181 | if viper.GetBool("logapirequests") { 182 | log.Debug("API request logging enabled (via Viper), wrapping global HTTP transport.") 183 | // Define log file path 184 | logFilePath := "api.log" 185 | // Attempt to resolve relative to SavePath if possible, otherwise use current dir 186 | // Get SavePath using Viper 187 | savePath := viper.GetString("savepath") 188 | if savePath != "" { 189 | // Ensure SavePath exists (it might not if config loading failed partially) 190 | if _, statErr := os.Stat(savePath); statErr == nil { 191 | logFilePath = filepath.Join(savePath, logFilePath) 192 | } else { 193 | log.Warnf("SavePath '%s' (from Viper) not found, saving api.log to current directory.", savePath) 194 | } 195 | } 196 | log.Infof("API logging to file: %s", logFilePath) 197 | 198 | // Initialize the logging transport 199 | loggingTransport, err := api.NewLoggingTransport(baseTransport, logFilePath) 200 | if err != nil { 201 | log.WithError(err).Error("Failed to initialize API logging transport, logging disabled.") 202 | // Keep globalHttpTransport as baseTransport 203 | } else { 204 | globalHttpTransport = loggingTransport // Use the wrapped transport 205 | } 206 | } 207 | // --- End Setup Global HTTP Transport --- 208 | 209 | // If successful or partially successful, globalConfig is populated for use by commands. 210 | // BUT: Rely on viper.Get*() for values potentially overridden by flags. 
211 | return nil 212 | } 213 | -------------------------------------------------------------------------------- /cmd/civitai-downloader/cmd/search.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | 5 | // Import bleve package directly 6 | 7 | "github.com/spf13/cobra" 8 | ) 9 | 10 | // Variable shared by subcommands 11 | var searchQuery string 12 | 13 | // searchCmd represents the base search command when called without subcommands. 14 | var searchCmd = &cobra.Command{ 15 | Use: "search", 16 | Short: "Search the Bleve index for downloaded models or images", 17 | Long: `Provides subcommands to search the Bleve index created during downloads. 18 | Use 'search models' or 'search images'.`, 19 | // No Run function, this is a parent command 20 | } 21 | 22 | func init() { 23 | rootCmd.AddCommand(searchCmd) 24 | 25 | // No flags defined here, they belong to subcommands (models, images) 26 | } 27 | 28 | // runSearch has been moved to search_logic.go as runSearchLogic 29 | // func runSearch(cmd *cobra.Command, args []string) { ... } 30 | -------------------------------------------------------------------------------- /cmd/civitai-downloader/cmd/search_logic.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "fmt" 5 | 6 | index "go-civitai-download/index" 7 | 8 | "github.com/blevesearch/bleve/v2" // Import bleve package directly 9 | log "github.com/sirupsen/logrus" 10 | // Note: No cobra import needed here as flags are handled by subcommands 11 | ) 12 | 13 | // runSearchLogic executes the search against a specific index path. 14 | // It's called by the subcommand Run functions. 
15 | func runSearchLogic(indexPath string, query string) { 16 | // Logging should already be initialized by the time this is called 17 | log.Debugf("runSearchLogic called with indexPath: %s, query: %s", indexPath, query) 18 | 19 | if query == "" { 20 | // This check should ideally happen in the subcommand before calling, 21 | // but double-checking here. 22 | log.Error("Search query cannot be empty.") 23 | return 24 | } 25 | 26 | if indexPath == "" { 27 | log.Error("Index path cannot be empty.") 28 | return 29 | } 30 | 31 | log.Infof("Opening Bleve index at: %s", indexPath) 32 | // Use Open instead of OpenOrCreateIndex to avoid creating index during search 33 | bleveIndex, err := bleve.Open(indexPath) 34 | if err != nil { 35 | if err == bleve.ErrorIndexPathDoesNotExist { // Check against bleve's error constant 36 | log.Errorf("Bleve index not found at %s. Run the corresponding download command first to create it.", indexPath) 37 | } else { 38 | log.Errorf("Failed to open Bleve index at %s: %v", indexPath, err) 39 | } 40 | return // Return instead of Fatal to allow potential multi-index search later 41 | } 42 | defer func() { 43 | log.Debug("Closing Bleve index.") 44 | if err := bleveIndex.Close(); err != nil { 45 | log.Errorf("Error closing Bleve index: %v", err) 46 | } 47 | }() 48 | 49 | log.Infof("Performing search with query: %s", query) 50 | 51 | searchResults, err := index.SearchIndex(bleveIndex, query) 52 | if err != nil { 53 | log.Errorf("Error performing search: %v", err) 54 | return 55 | } 56 | 57 | log.Infof("Search finished. 
Hits: %d, Total: %d, Took: %s", 58 | len(searchResults.Hits), 59 | searchResults.Total, 60 | searchResults.Took) 61 | 62 | if searchResults.Total > 0 { 63 | fmt.Println("--- Search Results ---") 64 | for i, hit := range searchResults.Hits { 65 | fmt.Printf("[%d] ID: %s (Score: %.2f)\n", i+1, hit.ID, hit.Score) 66 | // Print requested fields (all fields are requested by SearchIndex) 67 | for field, value := range hit.Fields { 68 | fmt.Printf(" %s: %v\n", field, value) 69 | } 70 | fmt.Println("---") 71 | } 72 | } else { 73 | fmt.Println("No results found matching your query.") 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /cmd/civitai-downloader/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "go-civitai-download/cmd/civitai-downloader/cmd" 5 | "go-civitai-download/internal/api" 6 | ) 7 | 8 | func main() { 9 | // Ensure all API log file buffers are flushed and files closed on exit 10 | defer api.CloseAllLoggingTransports() 11 | 12 | // Execute the root command (defined in cmd/root.go) 13 | cmd.Execute() 14 | } 15 | -------------------------------------------------------------------------------- /config.toml.example: -------------------------------------------------------------------------------- 1 | # Civitai Downloader Configuration Example 2 | 3 | # --- Connection/Auth --- 4 | # Your Civitai API Key (Required for downloading models) 5 | ApiKey = "" 6 | 7 | # --- Paths --- 8 | # Default directory to save downloaded files 9 | SavePath = "downloads" 10 | # Path to the BoltDB database file used to track downloads 11 | # If empty, defaults to [SavePath]/civitai_download_db 12 | DatabasePath = "civitai.db" 13 | # Path to the Bleve search index directory. 
14 | # If empty, defaults to separate indexes within [SavePath] (e.g., [SavePath]/civitai.bleve, [SavePath]/civitai_images.bleve) 15 | BleveIndexPath = "" 16 | 17 | # --- Filtering - Model/Version Level --- 18 | # Optional search query string (corresponds to --query flag) 19 | Query = "" 20 | # Optional list of tags to filter by (API currently uses single tag via --tag flag) 21 | # Tags = ["tag1", "tag2"] 22 | # Optional list of usernames to filter by (API currently uses single username via --username flag) 23 | # Usernames = ["user1", "user2"] 24 | # Filter by specific model types (e.g., Checkpoint, LORA, LoCon). Empty will attempt to fetch all types. 25 | ModelTypes = [] 26 | # Filter by specific base models (e.g., "SD 1.5", "SDXL 1.0"). Empty will attempt to fetch all base models. 27 | BaseModels = [] 28 | # List of base model names (substrings) to ignore during download 29 | IgnoreBaseModels = [] 30 | # Whether to include models marked as NSFW (Not Safe For Work) 31 | Nsfw = true 32 | # Download ONLY a specific model version ID, ignoring other filters (0 means disabled) 33 | # ModelVersionID = 12345 34 | # Download all versions of matched models, not just the latest one 35 | AllVersions = false # Corresponds to --all-versions flag 36 | 37 | # --- Filtering - File Level --- 38 | # Only download files marked as "Primary" by the uploader 39 | PrimaryOnly = false 40 | # For Checkpoint models, only download files marked as "pruned" 41 | Pruned = false 42 | # For Checkpoint models, only download files marked as "fp16" (float16 precision) 43 | Fp16 = false 44 | # List of case-insensitive strings. If a filename contains any of these, it will be ignored. 
45 | IgnoreFileNameStrings = [] 46 | 47 | # --- API Query Behavior --- 48 | # Sorting order for model search results ("Highest Rated", "Most Downloaded", "Newest") 49 | Sort = "Most Downloaded" 50 | # Time period for sorting ("AllTime", "Year", "Month", "Week", "Day") 51 | Period = "AllTime" 52 | # Maximum number of models to request per API page (max 100) 53 | Limit = 100 54 | # Maximum number of API pages to fetch (0 for no limit) 55 | MaxPages = 0 56 | 57 | # --- Downloader Behavior --- 58 | # Number of concurrent download workers 59 | Concurrency = 4 60 | # Save a .json file containing model/version metadata alongside each downloaded file 61 | Metadata = true # Corresponds to --metadata flag 62 | # Only download and save metadata files, skip actual model file download 63 | MetaOnly = false # Corresponds to --meta-only flag 64 | # Save a full model info JSON (including all versions) to 'model_info/' directory 65 | ModelInfo = true # Corresponds to --model-info flag 66 | # Download preview images associated with the specific downloaded model version 67 | # Saves to '[ModelDir]/version_images/[VersionID]/' 68 | VersionImages = true # Corresponds to --version-images flag 69 | # When ModelInfo is true, also download all images for all versions of the model 70 | # Saves to '[ModelInfoDir]/images/[VersionID]/' 71 | ModelImages = false # Corresponds to --model-images flag 72 | # Skip the confirmation prompt before starting downloads 73 | SkipConfirmation = false # Corresponds to --yes flag 74 | # Delay in milliseconds between consecutive API calls (helps avoid rate limiting) 75 | ApiDelayMs = 200 76 | # Timeout in seconds for HTTP client requests (API calls and downloads) 77 | ApiClientTimeoutSec = 120 78 | 79 | # --- Other --- 80 | # Log API requests and responses to a file (api.log) 81 | LogApiRequests = false -------------------------------------------------------------------------------- /go.mod: 
-------------------------------------------------------------------------------- 1 | module go-civitai-download 2 | 3 | go 1.23 4 | 5 | toolchain go1.24.2 6 | 7 | require ( 8 | git.mills.io/prologic/bitcask v1.0.2 9 | github.com/BurntSushi/toml v1.3.2 10 | github.com/anacrolix/torrent v1.58.1 11 | github.com/blevesearch/bleve/v2 v2.5.0 12 | github.com/gosuri/uilive v0.0.4 13 | github.com/sirupsen/logrus v1.8.1 14 | github.com/spf13/cobra v1.9.1 15 | github.com/spf13/viper v1.20.1 16 | github.com/stretchr/testify v1.10.0 17 | github.com/zeebo/blake3 v0.2.4 18 | ) 19 | 20 | require ( 21 | github.com/RoaringBitmap/roaring/v2 v2.4.5 // indirect 22 | github.com/abcum/lcp v0.0.0-20201209214815-7a3f3840be81 // indirect 23 | github.com/anacrolix/generics v0.0.3-0.20240902042256-7fb2702ef0ca // indirect 24 | github.com/anacrolix/missinggo v1.3.0 // indirect 25 | github.com/anacrolix/missinggo/v2 v2.7.4 // indirect 26 | github.com/bits-and-blooms/bitset v1.22.0 // indirect 27 | github.com/blevesearch/bleve_index_api v1.2.7 // indirect 28 | github.com/blevesearch/geo v0.1.20 // indirect 29 | github.com/blevesearch/go-faiss v1.0.25 // indirect 30 | github.com/blevesearch/go-porterstemmer v1.0.3 // indirect 31 | github.com/blevesearch/gtreap v0.1.1 // indirect 32 | github.com/blevesearch/mmap-go v1.0.4 // indirect 33 | github.com/blevesearch/scorch_segment_api/v2 v2.3.9 // indirect 34 | github.com/blevesearch/segment v0.9.1 // indirect 35 | github.com/blevesearch/snowballstem v0.9.0 // indirect 36 | github.com/blevesearch/upsidedown_store_api v1.0.2 // indirect 37 | github.com/blevesearch/vellum v1.1.0 // indirect 38 | github.com/blevesearch/zapx/v11 v11.4.1 // indirect 39 | github.com/blevesearch/zapx/v12 v12.4.1 // indirect 40 | github.com/blevesearch/zapx/v13 v13.4.1 // indirect 41 | github.com/blevesearch/zapx/v14 v14.4.1 // indirect 42 | github.com/blevesearch/zapx/v15 v15.4.1 // indirect 43 | github.com/blevesearch/zapx/v16 v16.2.2 // indirect 44 | 
github.com/bradfitz/iter v0.0.0-20191230175014-e8f45d346db8 // indirect 45 | github.com/davecgh/go-spew v1.1.1 // indirect 46 | github.com/fsnotify/fsnotify v1.8.0 // indirect 47 | github.com/go-viper/mapstructure/v2 v2.2.1 // indirect 48 | github.com/gofrs/flock v0.8.0 // indirect 49 | github.com/golang/geo v0.0.0-20210211234256-740aa86cb551 // indirect 50 | github.com/golang/protobuf v1.5.3 // indirect 51 | github.com/golang/snappy v0.0.4 // indirect 52 | github.com/huandu/xstrings v1.3.2 // indirect 53 | github.com/inconshreveable/mousetrap v1.1.0 // indirect 54 | github.com/json-iterator/go v1.1.11 // indirect 55 | github.com/klauspost/cpuid/v2 v2.2.3 // indirect 56 | github.com/mattn/go-isatty v0.0.16 // indirect 57 | github.com/minio/sha256-simd v1.0.0 // indirect 58 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect 59 | github.com/modern-go/reflect2 v1.0.1 // indirect 60 | github.com/mr-tron/base58 v1.2.0 // indirect 61 | github.com/mschoch/smat v0.2.0 // indirect 62 | github.com/multiformats/go-multihash v0.2.3 // indirect 63 | github.com/multiformats/go-varint v0.0.6 // indirect 64 | github.com/pelletier/go-toml/v2 v2.2.3 // indirect 65 | github.com/pkg/errors v0.9.1 // indirect 66 | github.com/plar/go-adaptive-radix-tree v1.0.4 // indirect 67 | github.com/pmezard/go-difflib v1.0.0 // indirect 68 | github.com/sagikazarmark/locafero v0.7.0 // indirect 69 | github.com/sourcegraph/conc v0.3.0 // indirect 70 | github.com/spaolacci/murmur3 v1.1.0 // indirect 71 | github.com/spf13/afero v1.12.0 // indirect 72 | github.com/spf13/cast v1.7.1 // indirect 73 | github.com/spf13/pflag v1.0.6 // indirect 74 | github.com/subosito/gotenv v1.6.0 // indirect 75 | go.etcd.io/bbolt v1.4.0 // indirect 76 | go.uber.org/atomic v1.10.0 // indirect 77 | go.uber.org/multierr v1.9.0 // indirect 78 | golang.org/x/crypto v0.32.0 // indirect 79 | golang.org/x/exp v0.0.0-20240823005443-9b4947da3948 // indirect 80 | golang.org/x/sys v0.29.0 // indirect 81 
| golang.org/x/text v0.21.0 // indirect 82 | google.golang.org/protobuf v1.36.1 // indirect 83 | gopkg.in/yaml.v3 v3.0.1 // indirect 84 | lukechampine.com/blake3 v1.1.6 // indirect 85 | ) 86 | -------------------------------------------------------------------------------- /index/index.go: -------------------------------------------------------------------------------- 1 | package index 2 | 3 | import ( 4 | "log" 5 | "os" 6 | "time" 7 | 8 | "github.com/blevesearch/bleve/v2" 9 | ) 10 | 11 | const defaultIndexPath = "civitai.bleve" 12 | 13 | // Item represents a generic item to be indexed. 14 | // We might need more specific structs later for models, images, etc. 15 | // By default, all fields defined here are indexed and searchable using their 16 | // lowercase JSON tag names (e.g., query '+creatorName:someuser' or '+tags:tagname'). 17 | type Item struct { 18 | ID string `json:"id"` // Unique ID (e.g., v_, img_) 19 | Type string `json:"type"` // Type of item (e.g., "model_file", "image") 20 | Name string `json:"name"` // Name of the item (file name, image name if available) 21 | Description string `json:"description"` // Description or other text content 22 | FilePath string `json:"filePath"` // Path where the item is downloaded 23 | DirectoryPath string `json:"directoryPath,omitempty"` // Directory containing the file 24 | BaseModelPath string `json:"baseModelPath,omitempty"` // Path up to the base model slug 25 | ModelPath string `json:"modelPath,omitempty"` // Path up to the model name slug 26 | ModelName string `json:"modelName,omitempty"` // Name of the parent model (for model files/images) 27 | VersionName string `json:"versionName,omitempty"` // Name of the model version (for model files) 28 | BaseModel string `json:"baseModel,omitempty"` // Base model (e.g., SDXL 1.0) 29 | CreatorName string `json:"creatorName,omitempty"` // Username of the creator 30 | Tags []string `json:"tags,omitempty"` // Associated tags (if available) 31 | Prompt string 
`json:"prompt,omitempty"` // Image generation prompt (for images) 32 | NsfwLevel string `json:"nsfwLevel,omitempty"` // NSFW Level (for images) 33 | 34 | // New Fields 35 | PublishedAt time.Time `json:"publishedAt,omitempty"` // Version publication timestamp 36 | VersionDownloadCount float64 `json:"versionDownloadCount,omitempty"` // Version download count 37 | VersionRating float64 `json:"versionRating,omitempty"` // Version rating 38 | VersionRatingCount float64 `json:"versionRatingCount,omitempty"` // Version rating count 39 | FileSizeKB float64 `json:"fileSizeKB,omitempty"` // File size in KB 40 | FileFormat string `json:"fileFormat,omitempty"` // File format (e.g., safetensor) 41 | FilePrecision string `json:"filePrecision,omitempty"` // File precision (e.g., fp16) 42 | FileSizeType string `json:"fileSizeType,omitempty"` // File size type (e.g., pruned) 43 | 44 | // Torrent Information (populated by the 'torrent' command) 45 | TorrentPath string `json:"torrentPath,omitempty"` // Path to the downloaded .torrent file 46 | MagnetLink string `json:"magnetLink,omitempty"` // Magnet link for the torrent 47 | } 48 | 49 | // OpenOrCreateIndex opens an existing Bleve index or creates a new one if it doesn't exist. 50 | func OpenOrCreateIndex(indexPath string) (bleve.Index, error) { 51 | if indexPath == "" { 52 | indexPath = defaultIndexPath 53 | } 54 | 55 | index, err := bleve.Open(indexPath) 56 | if err == bleve.ErrorIndexPathDoesNotExist { 57 | log.Printf("Creating new index at: %s", indexPath) 58 | mapping := bleve.NewIndexMapping() 59 | // Customize mapping here if needed (e.g., for specific fields) 60 | index, err = bleve.New(indexPath, mapping) 61 | if err != nil { 62 | return nil, err 63 | } 64 | } else if err != nil { 65 | return nil, err // Other error opening index 66 | } else { 67 | log.Printf("Opened existing index at: %s", indexPath) 68 | } 69 | return index, nil 70 | } 71 | 72 | // IndexItem adds or updates an item in the Bleve index. 
73 | func IndexItem(index bleve.Index, item Item) error { 74 | return index.Index(item.ID, item) 75 | } 76 | 77 | // SearchIndex performs a search query against the index. 78 | func SearchIndex(index bleve.Index, query string) (*bleve.SearchResult, error) { 79 | searchQuery := bleve.NewQueryStringQuery(query) 80 | searchRequest := bleve.NewSearchRequest(searchQuery) 81 | searchRequest.Fields = []string{"*"} // Request all stored fields 82 | searchResults, err := index.Search(searchRequest) 83 | if err != nil { 84 | return nil, err 85 | } 86 | return searchResults, nil 87 | } 88 | 89 | // DeleteIndex removes the index directory. Use with caution! 90 | func DeleteIndex(indexPath string) error { 91 | if indexPath == "" { 92 | indexPath = defaultIndexPath 93 | } 94 | log.Printf("Attempting to delete index at: %s", indexPath) 95 | return os.RemoveAll(indexPath) 96 | } 97 | -------------------------------------------------------------------------------- /internal/api/client.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "errors" 7 | "fmt" 8 | "io" 9 | "net/http" 10 | "net/http/httputil" 11 | "net/url" 12 | "os" 13 | "time" 14 | 15 | "go-civitai-download/internal/models" 16 | 17 | log "github.com/sirupsen/logrus" 18 | ) 19 | 20 | // Custom Error Types 21 | var ( 22 | ErrRateLimited = errors.New("API rate limit exceeded") 23 | ErrUnauthorized = errors.New("API request unauthorized (check API key)") 24 | ErrNotFound = errors.New("API resource not found") 25 | ErrServerError = errors.New("API server error") 26 | ) 27 | 28 | const CivitaiApiBaseUrl = "https://civitai.com/api/v1" 29 | 30 | // apiLogger is a dedicated logger for api.log 31 | var apiLogger = log.New() 32 | var apiLogFile *os.File 33 | 34 | // configureApiLogger sets up the apiLogger based on config. 35 | // This should be called once, perhaps from the main command setup or PersistentPreRun. 
36 | // For simplicity now, we'll call it within NewClient, though not ideal. 37 | func configureApiLogger(shouldLog bool) { 38 | log.Debugf("configureApiLogger called with shouldLog=%t", shouldLog) // Log entry 39 | if !shouldLog { 40 | apiLogger.SetOutput(io.Discard) 41 | log.Debug("API logging disabled by config/flag.") 42 | return 43 | } 44 | 45 | if apiLogFile == nil { 46 | log.Debug("apiLogFile is nil, attempting to open...") 47 | var err error 48 | apiLogFile, err = os.OpenFile("api.log", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600) 49 | if err != nil { 50 | log.WithError(err).Error("Failed to open api.log, API logging disabled.") 51 | apiLogger.SetOutput(io.Discard) 52 | return 53 | } 54 | log.Debug("api.log opened successfully.") 55 | apiLogger.SetOutput(apiLogFile) 56 | // Use a simple text formatter for the log file 57 | apiLogger.SetFormatter(&log.TextFormatter{ 58 | DisableColors: true, 59 | FullTimestamp: true, 60 | DisableQuote: true, 61 | QuoteEmptyFields: true, 62 | }) 63 | apiLogger.SetLevel(log.DebugLevel) // Log everything to the file if enabled 64 | apiLogger.Info("API Logger Initialized") 65 | } else { 66 | log.Debug("apiLogFile already open, reusing existing handle.") 67 | } 68 | } 69 | 70 | // CleanupApiLog closes the api.log file handle. Should be called on application exit. 
71 | func CleanupApiLog() { 72 | if apiLogFile != nil { 73 | apiLogger.Info("Closing API log file.") 74 | if err := apiLogFile.Close(); err != nil { 75 | apiLogger.WithError(err).Error("Error closing API log file") 76 | } 77 | } 78 | } 79 | 80 | // Client struct for interacting with the Civitai API 81 | // TODO: Add http.Client field for reuse 82 | type Client struct { 83 | ApiKey string 84 | HttpClient *http.Client // Use a shared client 85 | logApiRequests bool // Store the config setting 86 | } 87 | 88 | // NewClient creates a new API client 89 | // TODO: Initialize and pass a shared http.Client 90 | func NewClient(apiKey string, httpClient *http.Client, cfg models.Config) *Client { 91 | if httpClient == nil { 92 | httpClient = &http.Client{Timeout: 30 * time.Second} 93 | } 94 | // Log the value being passed 95 | log.Debugf("NewClient called, cfg.LogApiRequests value: %t", cfg.LogApiRequests) 96 | // Configure the logger based on the *global* config setting 97 | configureApiLogger(cfg.LogApiRequests) 98 | 99 | return &Client{ 100 | ApiKey: apiKey, 101 | HttpClient: httpClient, 102 | logApiRequests: cfg.LogApiRequests, // Store flag for use in methods 103 | } 104 | } 105 | 106 | // GetModels fetches models based on query parameters, using cursor pagination. 107 | // Accepts the cursor for the next page. Returns the next cursor and the response. 108 | func (c *Client) GetModels(cursor string, queryParams models.QueryParameters) (string, models.ApiResponse, error) { 109 | // Use the helper function to build base query parameters 110 | values := ConvertQueryParamsToURLValues(queryParams) 111 | 112 | // Add cursor *only if* it's provided (not empty) 113 | if cursor != "" { 114 | values.Add("cursor", cursor) 115 | } else { 116 | // For the first request (empty cursor), do not add 'page' either. 117 | // The API defaults to the first page/results without page/cursor. 
118 | } 119 | 120 | reqURL := fmt.Sprintf("%s/models?%s", CivitaiApiBaseUrl, values.Encode()) 121 | // No change to main logger here 122 | // log.Debugf("Requesting URL: %s", reqURL) 123 | 124 | req, err := http.NewRequest("GET", reqURL, nil) 125 | if err != nil { 126 | log.WithError(err).Errorf("Error creating request for %s", reqURL) 127 | // Wrap the underlying error 128 | return "", models.ApiResponse{}, fmt.Errorf("error creating request: %w", err) 129 | } 130 | 131 | req.Header.Set("Content-Type", "application/json") 132 | if c.ApiKey != "" { 133 | req.Header.Set("Authorization", "Bearer "+c.ApiKey) 134 | } 135 | 136 | // --- Log API Request --- 137 | if c.logApiRequests { 138 | reqDump, dumpErr := httputil.DumpRequestOut(req, true) // Dump outgoing request 139 | if dumpErr != nil { 140 | apiLogger.WithError(dumpErr).Error("Failed to dump API request") 141 | } else { 142 | apiLogger.Debugf("\n--- API Request ---\n%s\n--------------------", string(reqDump)) 143 | } 144 | } 145 | // --- End Log API Request --- 146 | 147 | var resp *http.Response 148 | var lastErr error 149 | maxRetries := 3 150 | 151 | for attempt := 0; attempt < maxRetries; attempt++ { 152 | resp, err = c.HttpClient.Do(req) 153 | 154 | // --- Log API Response (Attempt) --- 155 | if c.logApiRequests && resp != nil { // Log even on non-200 responses 156 | // Read body first for logging, then replace it for potential retries/final processing 157 | bodyBytes, readErr := io.ReadAll(resp.Body) 158 | if closeErr := resp.Body.Close(); closeErr != nil { // Close original body and check error 159 | apiLogger.WithError(closeErr).Warn("Error closing response body after reading for logging") 160 | } 161 | if readErr != nil { 162 | apiLogger.WithError(readErr).Errorf("Attempt %d: Failed to read response body for logging", attempt+1) 163 | } else { 164 | // Create a new io.ReadCloser from the read bytes 165 | resp.Body = io.NopCloser(bytes.NewReader(bodyBytes)) 166 | 167 | // Try to dump response (might fail 
on huge bodies) 168 | respDump, dumpErr := httputil.DumpResponse(resp, false) // false = don't dump body here 169 | if dumpErr != nil { 170 | apiLogger.WithError(dumpErr).Errorf("Attempt %d: Failed to dump API response headers", attempt+1) 171 | } else { 172 | apiLogger.Debugf("\n--- API Response (Attempt %d) ---\n%s\n--- Body (%d bytes) ---\n%s\n----------------------------- \n", 173 | attempt+1, string(respDump), len(bodyBytes), string(bodyBytes)) 174 | } 175 | } 176 | } else if c.logApiRequests && err != nil { 177 | // Log if the Do() call itself failed 178 | apiLogger.WithError(err).Errorf("Attempt %d: HTTP Client Do() failed", attempt+1) 179 | } 180 | // --- End Log API Response (Attempt) --- 181 | 182 | if err != nil { 183 | lastErr = fmt.Errorf("http request failed (attempt %d/%d): %w", attempt+1, maxRetries, err) 184 | if attempt < maxRetries-1 { // Only log retry warning if not the last attempt 185 | log.WithError(err).Warnf("Retrying (%d/%d)...", attempt+1, maxRetries) 186 | time.Sleep(time.Duration(attempt+1) * 2 * time.Second) // Exponential backoff 187 | continue 188 | } 189 | break // Max retries reached on HTTP error 190 | } 191 | 192 | switch resp.StatusCode { 193 | case http.StatusOK: 194 | lastErr = nil // Success 195 | goto ProcessResponse // Use goto to break out of switch and loop 196 | case http.StatusTooManyRequests: 197 | lastErr = ErrRateLimited 198 | case http.StatusUnauthorized, http.StatusForbidden: 199 | lastErr = ErrUnauthorized 200 | goto RequestFailed // Non-retryable auth error 201 | case http.StatusNotFound: 202 | lastErr = ErrNotFound 203 | goto RequestFailed // Non-retryable not found error 204 | case http.StatusServiceUnavailable: 205 | lastErr = fmt.Errorf("%w (status code 503)", ErrServerError) 206 | default: 207 | if resp.StatusCode >= 500 { 208 | lastErr = fmt.Errorf("%w (status code %d)", ErrServerError, resp.StatusCode) 209 | } else { 210 | // Other client-side errors (4xx) are likely not retryable 211 | lastErr = 
fmt.Errorf("API request failed with status %d", resp.StatusCode) 212 | goto RequestFailed 213 | } 214 | } 215 | 216 | // If we are here, it's a retryable error (Rate Limit or 5xx) 217 | // resp.Body was already closed and replaced during logging 218 | if attempt < maxRetries-1 { 219 | var sleepDuration time.Duration 220 | if resp.StatusCode == http.StatusTooManyRequests { 221 | // Longer backoff for rate limits 222 | sleepDuration = time.Duration(attempt+1) * 5 * time.Second 223 | log.WithError(lastErr).Warnf("Rate limited. Retrying (%d/%d) after %s...", attempt+1, maxRetries, sleepDuration) 224 | } else { // Server errors (5xx) 225 | sleepDuration = time.Duration(attempt+1) * 3 * time.Second 226 | log.WithError(lastErr).Warnf("Server error. Retrying (%d/%d) after %s...", attempt+1, maxRetries, sleepDuration) 227 | } 228 | time.Sleep(sleepDuration) 229 | } else { 230 | log.WithError(lastErr).Errorf("Request failed after %d attempts with status %d", maxRetries, resp.StatusCode) 231 | } 232 | } 233 | 234 | RequestFailed: 235 | if lastErr != nil { 236 | // Don't close body here, it should have been closed during logging or by defer on success path 237 | // if resp != nil { resp.Body.Close() } 238 | return "", models.ApiResponse{}, lastErr 239 | } 240 | 241 | ProcessResponse: 242 | // Body should already be replaced with a readable version from logging step 243 | defer resp.Body.Close() 244 | 245 | body, err := io.ReadAll(resp.Body) // Read the replaced body 246 | if err != nil { 247 | log.WithError(err).Error("Error reading final response body") 248 | return "", models.ApiResponse{}, fmt.Errorf("error reading response body: %w", err) 249 | } 250 | 251 | var response models.ApiResponse 252 | err = json.Unmarshal(body, &response) 253 | if err != nil { 254 | log.WithError(err).Errorf("Error unmarshalling response JSON") 255 | // Log the body that caused the error (already logged to api.log if enabled) 256 | log.Debugf("Response body causing unmarshal error: %s", 
string(body)) 257 | return "", models.ApiResponse{}, fmt.Errorf("error unmarshalling response JSON: %w", err) 258 | } 259 | 260 | // Return the next cursor provided by the API 261 | return response.Metadata.NextCursor, response, nil 262 | } 263 | 264 | // ConvertQueryParamsToURLValues converts the QueryParameters struct into url.Values for API requests. 265 | // This is used for constructing the request URL. 266 | func ConvertQueryParamsToURLValues(queryParams models.QueryParameters) url.Values { 267 | values := url.Values{} 268 | values.Add("sort", queryParams.Sort) 269 | values.Add("period", queryParams.Period) 270 | // Always include the nsfw parameter, converting the boolean to string "true" or "false" 271 | values.Add("nsfw", fmt.Sprintf("%t", queryParams.Nsfw)) 272 | values.Add("limit", fmt.Sprintf("%d", queryParams.Limit)) 273 | for _, t := range queryParams.Types { 274 | values.Add("types", t) 275 | } 276 | for _, t := range queryParams.BaseModels { 277 | values.Add("baseModels", t) 278 | } 279 | if queryParams.PrimaryFileOnly { 280 | values.Add("primaryFileOnly", fmt.Sprintf("%t", queryParams.PrimaryFileOnly)) 281 | } 282 | if queryParams.Query != "" { 283 | values.Add("query", queryParams.Query) 284 | } 285 | if queryParams.Tag != "" { 286 | values.Add("tag", queryParams.Tag) 287 | } 288 | if queryParams.Username != "" { 289 | values.Add("username", queryParams.Username) 290 | } 291 | 292 | // Note: Cursor/Page parameters are typically added separately based on pagination logic. 
293 | return values 294 | } 295 | 296 | // TODO: Add methods for other API endpoints (e.g., GetModelByID, GetModelVersionByID) 297 | -------------------------------------------------------------------------------- /internal/api/logging_transport.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "fmt" 7 | "io" 8 | "net/http" 9 | "net/http/httputil" 10 | "os" 11 | "strings" 12 | sync "sync" 13 | time "time" 14 | 15 | log "github.com/sirupsen/logrus" 16 | ) 17 | 18 | // Global slice to keep track of all logging transports created 19 | var ( 20 | activeLoggingTransports []*LoggingTransport 21 | transportsMu sync.Mutex 22 | ) 23 | 24 | // LoggingTransport wraps an http.RoundTripper to log request and response details. 25 | type LoggingTransport struct { 26 | Transport http.RoundTripper 27 | logFile *os.File 28 | mu sync.Mutex 29 | writer *bufio.Writer 30 | } 31 | 32 | // NewLoggingTransport creates a new LoggingTransport. 33 | // It opens the specified log file for appending. 34 | func NewLoggingTransport(transport http.RoundTripper, logFilePath string) (*LoggingTransport, error) { 35 | f, err := os.OpenFile(logFilePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600) 36 | if err != nil { 37 | return nil, fmt.Errorf("failed to open API log file %s: %w", logFilePath, err) 38 | } 39 | 40 | // Use default transport if none provided 41 | if transport == nil { 42 | transport = http.DefaultTransport 43 | } 44 | 45 | lt := &LoggingTransport{ 46 | Transport: transport, 47 | logFile: f, 48 | writer: bufio.NewWriter(f), // Use a buffered writer 49 | } 50 | 51 | // Register the new transport 52 | transportsMu.Lock() 53 | activeLoggingTransports = append(activeLoggingTransports, lt) 54 | transportsMu.Unlock() 55 | log.Debugf("Registered new LoggingTransport for file: %s. 
Total active: %d", logFilePath, len(activeLoggingTransports)) 56 | 57 | return lt, nil 58 | } 59 | 60 | // RoundTrip executes a single HTTP transaction, logging details. 61 | func (t *LoggingTransport) RoundTrip(req *http.Request) (*http.Response, error) { 62 | t.mu.Lock() 63 | defer t.mu.Unlock() 64 | 65 | log.Debug("[LogTransport] RoundTrip: Entered") // VERBOSE 66 | startTime := time.Now() 67 | 68 | // Log request 69 | log.Debug("[LogTransport] RoundTrip: Dumping request...") // VERBOSE 70 | reqDump, err := httputil.DumpRequestOut(req, true) 71 | if err != nil { 72 | log.WithError(err).Error("[LogTransport] Failed to dump API request for logging") 73 | // Proceed with the request anyway 74 | } else { 75 | log.Debug("[LogTransport] RoundTrip: Writing request dump...") // VERBOSE 76 | t.writeLog(fmt.Sprintf("--- Request (%s) ---\n%s\n", startTime.Format(time.RFC3339), string(reqDump))) 77 | log.Debug("[LogTransport] RoundTrip: Request dump written.") // VERBOSE 78 | } 79 | 80 | // Perform the actual request 81 | log.Debug("[LogTransport] RoundTrip: Performing underlying Transport.RoundTrip...") // VERBOSE 82 | resp, err := t.Transport.RoundTrip(req) 83 | log.Debugf("[LogTransport] RoundTrip: Underlying Transport.RoundTrip returned. 
Err: %v", err) // VERBOSE 84 | 85 | duration := time.Since(startTime) 86 | 87 | // Log response or error 88 | if err != nil { 89 | log.Debug("[LogTransport] RoundTrip: Writing response error...") // VERBOSE 90 | t.writeLog(fmt.Sprintf("--- Response Error (%s, Duration: %v) ---\n%s\n", time.Now().Format(time.RFC3339), duration, err.Error())) 91 | log.Debug("[LogTransport] RoundTrip: Response error written.") // VERBOSE 92 | } else { 93 | log.Debug("[LogTransport] RoundTrip: Processing response...") // VERBOSE 94 | // Check Content-Type to decide whether to log body 95 | contentType := resp.Header.Get("Content-Type") 96 | logBody := strings.HasPrefix(contentType, "application/json") 97 | log.Debugf("[LogTransport] RoundTrip: Response Content-Type: %s, LogBody: %t", contentType, logBody) // VERBOSE 98 | 99 | if logBody { 100 | log.Debug("[LogTransport] RoundTrip: Reading response body for logging...") // VERBOSE 101 | // Read the body for logging 102 | bodyBytes, readErr := io.ReadAll(resp.Body) 103 | if readErr != nil { 104 | log.WithError(readErr).Error("[LogTransport] Failed to read response body for logging") 105 | // Log headers only if body read fails 106 | respDump, dumpErr := httputil.DumpResponse(resp, false) // Headers only 107 | if dumpErr != nil { 108 | log.WithError(dumpErr).Error("[LogTransport] Failed to dump response headers after body read error") 109 | t.writeLog(fmt.Sprintf("--- Response Headers (%s, Duration: %v) ---\nStatus: %s\n(Failed to dump headers or read body)\n", time.Now().Format(time.RFC3339), duration, resp.Status)) 110 | } else { 111 | log.Debug("[LogTransport] RoundTrip: Writing response headers (body read failed)...") // VERBOSE 112 | t.writeLog(fmt.Sprintf("--- Response Headers (%s, Duration: %v) ---\n%s\n(Body read failed)\n", time.Now().Format(time.RFC3339), duration, string(respDump))) 113 | log.Debug("[LogTransport] RoundTrip: Response headers written (body read failed).") // VERBOSE 114 | } 115 | // Restore body with an empty 
reader? Or let the original error propagate? 116 | // Let the caller handle the read error; resp.Body is likely closed or unusable now. 117 | } else { 118 | log.Debug("[LogTransport] RoundTrip: Response body read successfully. Restoring body...") // VERBOSE 119 | // IMPORTANT: Restore the body so the caller can read it. 120 | if closeErr := resp.Body.Close(); closeErr != nil { 121 | // Log the error but don't necessarily stop the process, as body might still be readable by caller 122 | log.WithError(closeErr).Warn("[LogTransport] Failed to close original response body before replacing it") 123 | } 124 | // Use bytes.NewReader instead of bytes.NewBuffer for the replacement body 125 | resp.Body = io.NopCloser(bytes.NewReader(bodyBytes)) 126 | 127 | // Log headers and the body we read 128 | respDumpHeader, dumpErr := httputil.DumpResponse(resp, false) // Headers only 129 | if dumpErr != nil { 130 | log.WithError(dumpErr).Error("[LogTransport] Failed to dump response headers for logging") 131 | t.writeLog(fmt.Sprintf("--- Response (%s, Duration: %v) ---\nStatus: %s\n(Failed to dump headers, body logged below)\n%s\n", time.Now().Format(time.RFC3339), duration, resp.Status, string(bodyBytes))) 132 | } else { 133 | log.Debug("[LogTransport] RoundTrip: Writing response headers and body...") // VERBOSE 134 | // Log headers first, then body for clarity 135 | t.writeLog(fmt.Sprintf("--- Response Headers (%s, Duration: %v) ---\n%s\n--- Response Body (%s) ---\n%s\n", time.Now().Format(time.RFC3339), duration, string(respDumpHeader), contentType, string(bodyBytes))) 136 | log.Debug("[LogTransport] RoundTrip: Response headers and body written.") // VERBOSE 137 | } 138 | } 139 | } else { 140 | // Log only headers for non-JSON content types 141 | log.Debug("[LogTransport] RoundTrip: Dumping non-JSON response headers...") // VERBOSE 142 | respDump, dumpErr := httputil.DumpResponse(resp, false) 143 | if dumpErr != nil { 144 | log.WithError(dumpErr).Error("[LogTransport] Failed to 
dump non-JSON response headers for logging") 145 | t.writeLog(fmt.Sprintf("--- Response Headers (%s, Duration: %v, Type: %s) ---\nStatus: %s\n(Failed to dump headers)\n", time.Now().Format(time.RFC3339), duration, contentType, resp.Status)) 146 | } else { 147 | log.Debug("[LogTransport] RoundTrip: Writing non-JSON response headers...") // VERBOSE 148 | t.writeLog(fmt.Sprintf("--- Response Headers (%s, Duration: %v, Type: %s) ---\n%s\n(Body not logged)\n", time.Now().Format(time.RFC3339), duration, contentType, string(respDump))) 149 | log.Debug("[LogTransport] RoundTrip: Non-JSON response headers written.") // VERBOSE 150 | } 151 | } 152 | } 153 | 154 | // Ensure logs are written **immediately** after each request/response pair 155 | log.Debug("[LogTransport] RoundTrip: Flushing writer...") // VERBOSE 156 | if errFlush := t.writer.Flush(); errFlush != nil { 157 | // Log error if flushing fails 158 | log.WithError(errFlush).Error("[LogTransport] Failed to flush log writer") 159 | } 160 | log.Debug("[LogTransport] RoundTrip: Writer flushed.") // VERBOSE 161 | 162 | log.Debug("[LogTransport] RoundTrip: Exiting") // VERBOSE 163 | return resp, err 164 | } 165 | 166 | // writeLog writes a string to the buffered writer. 167 | func (t *LoggingTransport) writeLog(logString string) { 168 | _, err := t.writer.WriteString(logString + "\n\n") 169 | if err != nil { 170 | // Log to stderr if writing to file fails 171 | fmt.Fprintf(os.Stderr, "Error writing to API log file: %v\nLog message: %s\n", err, logString) 172 | } 173 | } 174 | 175 | // Close closes the underlying log file. 
176 | func (t *LoggingTransport) Close() error { 177 | t.mu.Lock() 178 | defer t.mu.Unlock() 179 | 180 | errFlush := t.writer.Flush() // Ensure buffer is flushed before closing 181 | errClose := t.logFile.Close() 182 | if errFlush != nil { 183 | return fmt.Errorf("failed to flush API log buffer: %w", errFlush) 184 | } 185 | return errClose // Return close error if flush was successful 186 | } 187 | 188 | // CloseAllLoggingTransports iterates over all created transports and closes them. 189 | func CloseAllLoggingTransports() { 190 | transportsMu.Lock() 191 | defer transportsMu.Unlock() 192 | 193 | log.Debugf("Attempting to close %d active logging transports.", len(activeLoggingTransports)) 194 | closedCount := 0 195 | for i, t := range activeLoggingTransports { 196 | log.Debugf("Closing transport #%d for file: %s", i+1, t.logFile.Name()) 197 | if err := t.Close(); err != nil { 198 | // Log error to stderr as the primary logger might also be closing 199 | fmt.Fprintf(os.Stderr, "Error closing logging transport for %s: %v\n", t.logFile.Name(), err) 200 | } else { 201 | closedCount++ 202 | } 203 | } 204 | log.Debugf("Successfully closed %d logging transports.", closedCount) 205 | // Clear the slice after closing 206 | activeLoggingTransports = []*LoggingTransport{} 207 | } 208 | 209 | // DeregisterLoggingTransport removes a specific transport from the active list. 210 | // This might be useful if a transport needs to be manually closed and removed earlier. 211 | // Note: Ensure Close() is called separately if needed before deregistering. 
212 | func DeregisterLoggingTransport(transportToDeregister *LoggingTransport) { 213 | transportsMu.Lock() 214 | defer transportsMu.Unlock() 215 | 216 | log.Debugf("Attempting to deregister logging transport for file: %s", transportToDeregister.logFile.Name()) 217 | found := false 218 | newActiveTransports := []*LoggingTransport{} 219 | for _, t := range activeLoggingTransports { 220 | if t != transportToDeregister { 221 | newActiveTransports = append(newActiveTransports, t) 222 | } else { 223 | found = true 224 | } 225 | } 226 | activeLoggingTransports = newActiveTransports 227 | if found { 228 | log.Debugf("Successfully deregistered transport. Remaining active: %d", len(activeLoggingTransports)) 229 | } else { 230 | log.Warnf("Attempted to deregister a transport that was not found.") 231 | } 232 | } 233 | -------------------------------------------------------------------------------- /internal/config/config.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "fmt" 5 | "go-civitai-download/internal/models" // Import models for the Config struct 6 | 7 | "github.com/BurntSushi/toml" 8 | log "github.com/sirupsen/logrus" // Use logrus 9 | ) 10 | 11 | // LoadConfig reads the configuration from the specified path (defaulting to "config.toml") 12 | // and populates the provided models.Config struct. 13 | // It returns the loaded config and any error encountered. 14 | // TODO: Make config path configurable. 
15 | func LoadConfig(configFilePath string) (models.Config, error) { 16 | if configFilePath == "" { 17 | configFilePath = "config.toml" // Default path 18 | } 19 | var cfg models.Config 20 | _, err := toml.DecodeFile(configFilePath, &cfg) 21 | if err != nil { 22 | // Return the error instead of logging fatal 23 | return models.Config{}, fmt.Errorf("error loading config file %s: %w", configFilePath, err) 24 | } 25 | 26 | // TODO: Add validation for required fields (e.g., SavePath, DatabasePath) 27 | if cfg.SavePath == "" { 28 | log.Warn("Warning: SavePath is not set in config.toml") 29 | // return models.Config{}, fmt.Errorf("SavePath is required in config file") 30 | } 31 | if cfg.DatabasePath == "" { 32 | log.Warn("Warning: DatabasePath is not set in config.toml") 33 | // return models.Config{}, fmt.Errorf("DatabasePath is required in config file") 34 | } 35 | 36 | log.Infof("Configuration loaded from %s", configFilePath) 37 | return cfg, nil 38 | } 39 | -------------------------------------------------------------------------------- /internal/database/bitcask.go: -------------------------------------------------------------------------------- 1 | package database 2 | 3 | import ( 4 | "bytes" // For buffer operations 5 | "compress/gzip" 6 | "errors" 7 | "fmt" 8 | "io" // For io.ReadAll 9 | "os" 10 | "path/filepath" 11 | "strconv" 12 | "sync" 13 | 14 | // For DatabaseEntry 15 | 16 | "git.mills.io/prologic/bitcask" 17 | log "github.com/sirupsen/logrus" // Use logrus aliased as log 18 | ) 19 | 20 | // ErrNotFound is returned when a key is not found in the database. 21 | var ErrNotFound = errors.New("key not found") 22 | 23 | // gzipMagicBytes are the first two bytes of a gzip file. 24 | var gzipMagicBytes = []byte{0x1f, 0x8b} 25 | 26 | // DB wraps the bitcask database instance and provides helper methods. 
27 | type DB struct { 28 | db *bitcask.Bitcask 29 | sync.RWMutex // Embed mutex for concurrent access control 30 | closeOnce sync.Once 31 | closed bool 32 | closeErr error // Store the error from the first Close call 33 | } 34 | 35 | // Open initializes and returns a DB instance. 36 | func Open(path string) (*DB, error) { 37 | // Ensure the directory exists 38 | dir := filepath.Dir(path) 39 | if dir != "." && dir != "/" { // Avoid trying to create root or current dir explicitly 40 | if err := os.MkdirAll(dir, 0700); err != nil { 41 | return nil, fmt.Errorf("failed to create database directory %s: %w", dir, err) 42 | } 43 | } 44 | 45 | dbInstance, err := bitcask.Open(path) 46 | if err != nil { 47 | return nil, fmt.Errorf("failed to open bitcask database at %s: %w", path, err) 48 | } 49 | log.Infof("Database opened successfully at %s", path) 50 | return &DB{db: dbInstance}, nil 51 | } 52 | 53 | // Lock acquires a write lock. 54 | func (d *DB) Lock() { 55 | d.RWMutex.Lock() 56 | } 57 | 58 | // Unlock releases a write lock. 59 | func (d *DB) Unlock() { 60 | d.RWMutex.Unlock() 61 | } 62 | 63 | // RLock acquires a read lock. 64 | func (d *DB) RLock() { 65 | d.RWMutex.RLock() 66 | } 67 | 68 | // RUnlock releases a read lock. 69 | func (d *DB) RUnlock() { 70 | d.RWMutex.RUnlock() 71 | } 72 | 73 | // Close safely closes the database connection, ensuring it only happens once. 
74 | func (d *DB) Close() error { 75 | d.closeOnce.Do(func() { 76 | log.Info("Closing database...") 77 | // Acquire write lock to ensure no operations are in progress during close 78 | d.Lock() 79 | defer d.Unlock() 80 | 81 | d.closeErr = d.db.Close() // Call the underlying close 82 | d.closed = true // Mark as closed 83 | 84 | if d.closeErr != nil { 85 | log.Errorf("Error during database close operation: %v", d.closeErr) 86 | } else { 87 | log.Info("Database closed successfully.") 88 | } 89 | }) 90 | 91 | // Return the error captured during the first Close attempt (if any) 92 | return d.closeErr 93 | } 94 | 95 | // Has checks if a key exists in the database. 96 | func (d *DB) Has(key []byte) bool { 97 | d.RLock() 98 | defer d.RUnlock() 99 | return d.db.Has(key) 100 | } 101 | 102 | // Get retrieves the value associated with a key and decompresses it if necessary. 103 | func (d *DB) Get(key []byte) ([]byte, error) { 104 | d.RLock() 105 | value, err := d.db.Get(key) 106 | d.RUnlock() 107 | 108 | if err != nil { 109 | // Check if the error is KeyNotFound 110 | if errors.Is(err, bitcask.ErrKeyNotFound) { 111 | return nil, ErrNotFound // Return our specific package error 112 | } 113 | return nil, fmt.Errorf("error getting key %s: %w", string(key), err) 114 | } 115 | 116 | // Check for gzip header and decompress if necessary 117 | return decompressIfGzipped(value) 118 | } 119 | 120 | // Put compresses and stores a key-value pair in the database. 
121 | func (d *DB) Put(key []byte, value []byte) error { 122 | compressedValue, err := compressGzip(value, gzip.BestCompression) // Level 9 123 | if err != nil { 124 | return fmt.Errorf("error compressing value for key %s: %w", string(key), err) 125 | } 126 | 127 | // Store the compressed value 128 | d.Lock() 129 | err = d.db.Put(key, compressedValue) 130 | d.Unlock() 131 | if err != nil { 132 | return fmt.Errorf("error putting compressed key %s: %w", string(key), err) 133 | } 134 | return nil 135 | } 136 | 137 | // Delete removes a key from the database. 138 | func (d *DB) Delete(key []byte) error { 139 | d.Lock() 140 | err := d.db.Delete(key) 141 | d.Unlock() // Unlock *after* potential error check 142 | if err != nil { 143 | // Wrap error, check for KeyNotFound if deletion of non-existent key is an error 144 | if errors.Is(err, bitcask.ErrKeyNotFound) { 145 | // Consider returning nil here if deleting a non-existent key is acceptable 146 | return ErrNotFound // Return our specific package error here too 147 | } 148 | return fmt.Errorf("error deleting key %s: %w", string(key), err) 149 | } 150 | return nil 151 | } 152 | 153 | // Fold iterates over all key-value pairs, decompresses the value, 154 | // and calls the provided function. 155 | func (d *DB) Fold(fn func(key []byte, value []byte) error) error { 156 | d.RLock() 157 | defer d.RUnlock() 158 | 159 | err := d.db.Fold(func(key []byte) error { 160 | // Need to get the value inside the Fold callback 161 | // Important: Keep the main read lock for the duration of Fold 162 | rawValue, err := d.db.Get(key) // Get raw value (no extra locking needed) 163 | if err != nil { 164 | // Log or handle error getting value during fold? 
165 | log.WithError(err).Warnf("Fold: Error getting value for key %s", string(key)) 166 | return nil // Decide if errors should stop the fold 167 | } 168 | 169 | // Decompress the value 170 | value, err := decompressIfGzipped(rawValue) 171 | if err != nil { 172 | log.WithError(err).Warnf("Fold: Error decompressing value for key %s", string(key)) 173 | return nil // Skip this key if decompression fails 174 | } 175 | 176 | // Call the user-provided function with the decompressed value 177 | return fn(key, value) 178 | }) 179 | 180 | return err 181 | } 182 | 183 | // Keys returns a channel of all keys in the database. 184 | // Read from the channel until it is closed. 185 | // Ensure the database is not closed while iterating. 186 | // Note: This acquires a read lock during iteration. 187 | func (d *DB) Keys() <-chan []byte { 188 | d.RLock() // Acquire read lock on the DB wrapper mutex 189 | // Use a goroutine to handle unlocking after the channel is fully consumed or closed 190 | keysChan := d.db.Keys() 191 | monitoredChan := make(chan []byte) 192 | 193 | go func() { 194 | defer d.RUnlock() // Ensure wrapper mutex unlock happens when this goroutine exits 195 | for key := range keysChan { 196 | monitoredChan <- key 197 | } 198 | close(monitoredChan) // Close our channel when the original closes 199 | }() 200 | 201 | return monitoredChan 202 | } 203 | 204 | /* 205 | // --- State Management Helpers --- 206 | 207 | // StoreModelInfo serializes and stores the DatabaseEntry using the file's CRC32 hash as the key. 208 | // NOTE: This is now handled directly in download.go using version ID keys. 
209 | func (d *DB) StoreModelInfo(entry models.DatabaseEntry) error { 210 | dbKey := strings.ToUpper(entry.File.Hashes.CRC32) 211 | if dbKey == "" { 212 | return errors.New("cannot store model info: file CRC32 hash is empty") 213 | } 214 | 215 | dataBytes, err := json.Marshal(entry) 216 | if err != nil { 217 | return fmt.Errorf("error marshalling database entry for %s: %w", entry.Filename, err) 218 | } 219 | 220 | log.Debugf("Storing model info with key %s", dbKey) 221 | return d.Put([]byte(dbKey), dataBytes) 222 | } 223 | 224 | // CheckModelExists checks if a model file (identified by CRC32 hash) exists in the database. 225 | // NOTE: Key strategy changed to version ID. 226 | func (d *DB) CheckModelExists(crc32Hash string) bool { 227 | key := strings.ToUpper(crc32Hash) 228 | if key == "" { 229 | return false 230 | } 231 | return d.Has([]byte(key)) 232 | } 233 | */ 234 | 235 | // --- Compression Helpers --- 236 | 237 | // decompressIfGzipped decompresses the value if it is gzipped. 238 | func decompressIfGzipped(value []byte) ([]byte, error) { 239 | // Check for gzip header and decompress if present 240 | if bytes.HasPrefix(value, gzipMagicBytes) { 241 | bReader := bytes.NewReader(value) 242 | gReader, err := gzip.NewReader(bReader) 243 | if err != nil { 244 | log.WithError(err).Warnf("Error creating gzip reader for value, returning raw data.") 245 | return value, nil // Return raw data on decompression error 246 | } 247 | defer gReader.Close() 248 | 249 | decompressedValue, err := io.ReadAll(gReader) 250 | if err != nil { 251 | log.WithError(err).Warnf("Error decompressing value, returning raw data.") 252 | return value, nil // Return raw data on decompression error 253 | } 254 | return decompressedValue, nil 255 | } 256 | 257 | // If no gzip header, return the value as is 258 | return value, nil 259 | } 260 | 261 | // compressGzip compresses the value using gzip with the specified compression level. 
262 | func compressGzip(value []byte, level int) ([]byte, error) { 263 | var buf bytes.Buffer 264 | // Use BestCompression (Level 9) 265 | gWriter, err := gzip.NewWriterLevel(&buf, level) 266 | if err != nil { 267 | // Should generally not happen with a bytes.Buffer 268 | return nil, fmt.Errorf("error creating gzip writer for value: %w", err) 269 | } 270 | _, err = gWriter.Write(value) 271 | if err != nil { 272 | _ = gWriter.Close() // Attempt to close writer even on error 273 | return nil, fmt.Errorf("error writing compressed data for value: %w", err) 274 | } 275 | err = gWriter.Close() // Close *must* be called to flush buffers 276 | if err != nil { 277 | return nil, fmt.Errorf("error closing gzip writer for value: %w", err) 278 | } 279 | 280 | compressedValue := buf.Bytes() 281 | return compressedValue, nil 282 | } 283 | 284 | // --- Specific Helpers (Can be expanded for CLI features) --- 285 | 286 | // GetPageState retrieves the saved page number for a given query hash. 287 | func (d *DB) GetPageState(queryHash string) (int, error) { 288 | key := []byte("current_page_" + queryHash) 289 | pageBytes, err := d.Get(key) 290 | if err != nil { 291 | if err == bitcask.ErrKeyNotFound { 292 | return 1, nil // Default to page 1 if not found 293 | } 294 | return 0, fmt.Errorf("error reading page state for %s: %w", queryHash, err) 295 | } 296 | 297 | page, err := strconv.Atoi(string(pageBytes)) 298 | if err != nil { 299 | return 0, fmt.Errorf("error parsing saved page number '%s': %w", string(pageBytes), err) 300 | } 301 | log.WithField("queryHash", queryHash).Debugf("Retrieved page state: %d", page) 302 | return page, nil 303 | } 304 | 305 | // SetPageState saves the next page number for a given query hash. 
306 | func (d *DB) SetPageState(queryHash string, nextPage int) error { 307 | key := []byte("current_page_" + queryHash) 308 | value := []byte(strconv.Itoa(nextPage)) 309 | err := d.Put(key, value) 310 | if err != nil { 311 | return err // Put already wraps error 312 | } 313 | log.WithField("queryHash", queryHash).Debugf("Set page state to: %d", nextPage) 314 | return nil 315 | } 316 | 317 | // DeletePageState removes the saved page number for a given query hash. 318 | func (d *DB) DeletePageState(queryHash string) error { 319 | key := []byte("current_page_" + queryHash) 320 | err := d.Delete(key) 321 | if err != nil && err != bitcask.ErrKeyNotFound { 322 | return fmt.Errorf("error deleting page state for %s: %w", queryHash, err) 323 | } 324 | log.WithField("queryHash", queryHash).Info("Deleted page state") 325 | return nil // Treat KeyNotFound as success 326 | } 327 | 328 | // TODO: Add functions for CLI features like ListModels, GetModelInfo, etc. 329 | -------------------------------------------------------------------------------- /internal/downloader/downloader.go: -------------------------------------------------------------------------------- 1 | package downloader 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "io" 7 | "mime" 8 | "net/http" 9 | "os" 10 | "path/filepath" 11 | "strconv" 12 | "strings" 13 | "time" 14 | 15 | "go-civitai-download/internal/helpers" 16 | "go-civitai-download/internal/models" 17 | 18 | log "github.com/sirupsen/logrus" 19 | ) 20 | 21 | // Custom Downloader Errors 22 | var ( 23 | ErrHashMismatch = errors.New("downloaded file hash mismatch") 24 | ErrHttpStatus = errors.New("unexpected HTTP status code") 25 | ErrFileSystem = errors.New("filesystem error") // Covers create, remove, rename 26 | ErrHttpRequest = errors.New("HTTP request creation/execution error") 27 | ) 28 | 29 | // Downloader handles downloading files with progress and hash checks. 
// Downloader wraps an http.Client and an optional Civitai API key used as a
// Bearer token on download requests.
type Downloader struct {
	client *http.Client
	apiKey string // Add field to store API key
}

// NewDownloader creates a new Downloader instance.
// A nil client is replaced with a default client using a 15-minute timeout
// (large model files can take a long time to transfer).
func NewDownloader(client *http.Client, apiKey string) *Downloader {
	if client == nil {
		// Provide a default client if none is passed
		client = &http.Client{
			Timeout: 15 * time.Minute,
		}
	}
	return &Downloader{
		client: client,
		apiKey: apiKey, // Store the API key
	}
}

// Helper function to check for existing file by base name and hash.
// Now requires the expected file extension to avoid checking hashes on mismatched file types (e.g., .json vs .safetensors).
// Base-name comparison is case-insensitive. When no hashes are provided
// (e.g., images), a base-name match alone is accepted as valid.
func findExistingFileWithMatchingBaseAndHash(dirPath string, baseNameWithoutExt string, expectedExt string, hashes models.Hashes) (foundPath string, exists bool, err error) {
	entries, err := os.ReadDir(dirPath)
	if err != nil {
		if os.IsNotExist(err) {
			return "", false, nil // Directory doesn't exist, so file doesn't exist
		}
		return "", false, fmt.Errorf("reading directory %s: %w", dirPath, err)
	}

	log.Debugf("Scanning directory %s for base name '%s' with expected extension '%s' and matching hash...", dirPath, baseNameWithoutExt, expectedExt)
	for _, entry := range entries {
		if entry.IsDir() {
			continue // Skip directories
		}
		entryName := entry.Name()
		ext := filepath.Ext(entryName)
		entryBaseName := strings.TrimSuffix(entryName, ext)

		if strings.EqualFold(entryBaseName, baseNameWithoutExt) {
			// Base names match
			fullPath := filepath.Join(dirPath, entryName)

			hashesProvided := hashes.SHA256 != "" || hashes.BLAKE3 != "" || hashes.CRC32 != "" || hashes.AutoV2 != ""

			if !hashesProvided {
				// No standard hashes provided (likely an image), base name match is enough
				log.Debugf("Base name match found and no standard hashes provided. Assuming valid existing file: %s", fullPath)
				return fullPath, true, nil
			} else {
				// Hashes ARE provided. Check if the extension ALSO matches before checking hash.
				if strings.EqualFold(ext, expectedExt) {
					log.Debugf("Base name and extension match found: %s. Checking hash...", fullPath)
					if helpers.CheckHash(fullPath, hashes) {
						log.Debugf("Hash match successful for existing file: %s", fullPath)
						return fullPath, true, nil // Found a valid match!
					} else {
						log.Debugf("Hash mismatch for existing file with matching base name and extension: %s", fullPath)
						// Continue checking other files in case of duplicates with different content but same name/ext
					}
				} else {
					log.Debugf("Base name match found (%s), but extension '%s' does not match expected '%s'. Skipping hash check.", fullPath, ext, expectedExt)
				}
			}
		}
	}

	log.Debugf("No valid existing file found matching base name '%s' and extension '%s' in %s", baseNameWithoutExt, expectedExt, dirPath)
	return "", false, nil // No matching file found
}

// DownloadFile downloads a file from the specified URL to the target filepath.
// It checks for existing files, verifies hashes, and attempts to use the
// Content-Disposition header for the filename.
// It also now accepts a modelVersionID to prepend to the final filename.
// Returns the final filepath used (or empty string on failure) and an error if one occurred.
func (d *Downloader) DownloadFile(targetFilepath string, url string, hashes models.Hashes, modelVersionID int) (string, error) {
	initialFinalFilepath := targetFilepath // Store the initially constructed path
	targetDir := filepath.Dir(initialFinalFilepath)
	initialBaseName := filepath.Base(initialFinalFilepath)
	initialExt := filepath.Ext(initialBaseName)
	initialBaseNameWithoutExt := strings.TrimSuffix(initialBaseName, initialExt)

	log.Debugf("Checking for existing file based on initial path: Dir=%s, BaseName=%s, Ext=%s", targetDir, initialBaseNameWithoutExt, initialExt)
	// --- Initial Check for Existing File (using new helper with expected extension) ---
	foundPath, exists, errCheck := findExistingFileWithMatchingBaseAndHash(targetDir, initialBaseNameWithoutExt, initialExt, hashes)
	if errCheck != nil {
		log.WithError(errCheck).Errorf("Error during initial check for existing file matching %s%s in %s", initialBaseNameWithoutExt, initialExt, targetDir)
		return "", fmt.Errorf("%w: initial check for existing file: %v", ErrFileSystem, errCheck)
	}
	if exists {
		log.Infof("Found valid existing file matching base name '%s' and extension '%s': %s. Skipping download.", initialBaseNameWithoutExt, initialExt, foundPath)
		return foundPath, nil // Success, return the path of the valid existing file
	}
	log.Infof("No valid file matching base name '%s' and extension '%s' found initially. Proceeding with download process.", initialBaseNameWithoutExt, initialExt)
	// --- End Initial Check ---

	// Ensure target directory exists before creating temp file
	if !helpers.CheckAndMakeDir(targetDir) {
		return "", fmt.Errorf("%w: failed to create target directory %s", ErrFileSystem, targetDir)
	}

	// Create a temporary file in the target directory
	baseName := filepath.Base(targetFilepath)
	tempFile, err := os.CreateTemp(targetDir, baseName+".*.tmp") // Use targetDir here
	if err != nil {
		return "", fmt.Errorf("%w: creating temporary file %s: %w", ErrFileSystem, targetFilepath, err)
	}
	// Use a flag to track if we should remove the temp file on error exit
	shouldCleanupTemp := true
	defer func() {
		if shouldCleanupTemp {
			// If tempFile wasn't closed explicitly due to an early error *before* the explicit close,
			// or if it was closed but we still need to cleanup (e.g., hash mismatch),
			// we might need to close it here, but the explicit close should handle most cases.
			// The main goal here is the os.Remove.
			// NOTE(review): early returns between CreateTemp and the explicit
			// Close below remove the temp file without closing its descriptor;
			// harmless on Unix but Remove would fail on Windows — confirm.
			log.Debugf("Cleaning up temporary file via defer: %s", tempFile.Name())
			if removeErr := os.Remove(tempFile.Name()); removeErr != nil {
				log.WithError(removeErr).Warnf("Failed to remove temporary file %s during defer cleanup", tempFile.Name())
			}
		}
	}()

	log.Info("Starting download process...") // Log before creating temp file

	log.Infof("Attempting to download from URL: %s", url)

	// Create request
	req, err := http.NewRequest("GET", url, nil)
	if err != nil {
		return "", fmt.Errorf("%w: creating download request for %s: %w", ErrHttpRequest, url, err)
	}

	// Add authentication header if API key is present
	log.Debugf("Downloader stored API Key: %s", d.apiKey) // Added Debug Log
	if d.apiKey != "" {
		log.Debug("Adding Authorization header to download request.") // Added Debug Log
		req.Header.Set("Authorization", "Bearer "+d.apiKey)
	} else {
		log.Debug("No API Key found, skipping Authorization header for download.") // Added Debug Log
	}

	resp, err := d.client.Do(req)
	if err != nil {
		log.WithError(err).Errorf("Error performing download request from %s", url)
		return "", fmt.Errorf("%w: performing request for %s: %v", ErrHttpRequest, url, err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		log.Errorf("Error downloading file: Received status code %d from %s", resp.StatusCode, url)
		return "", fmt.Errorf("%w: received status %d from %s", ErrHttpStatus, resp.StatusCode, url)
	}

	// --- Filename Handling from Content-Disposition ---
	// Recalculate finalFilepath based on header
	contentDisposition := resp.Header.Get("Content-Disposition")
	potentialApiFilename := "" // Store potential filename from header
	if contentDisposition != "" {
		_, params, err := mime.ParseMediaType(contentDisposition)
		if err == nil && params["filename"] != "" {
			potentialApiFilename = params["filename"]
			log.Infof("Received filename from Content-Disposition: %s", potentialApiFilename)
		} else {
			// If the disposition is 'inline' and has no filename, it's expected, log as debug.
			if strings.HasPrefix(contentDisposition, "inline") && params["filename"] == "" {
				log.Debugf("Content-Disposition is '%s' (no filename), using constructed filename.", contentDisposition)
			} else {
				// Log other parsing issues as warnings.
				log.WithError(err).Warnf("Could not parse Content-Disposition header: %s", contentDisposition)
			}
		}
	} else {
		log.Warn("Warning: No Content-Disposition header found, will use constructed filename.")
	}

	// Determine the base filename to use (API provided or original)
	var baseFilenameToUse string
	if potentialApiFilename != "" {
		baseFilenameToUse = potentialApiFilename
	} else {
		baseFilenameToUse = filepath.Base(targetFilepath) // Use original base filename if API doesn't provide one
	}
	// Construct the path *before* prepending ID
	pathBeforeId := filepath.Join(filepath.Dir(targetFilepath), baseFilenameToUse)

	var finalFilepath string // Declare finalFilepath here
	// --- Prepend Model Version ID to Filename ---
	if modelVersionID > 0 { // Only prepend if ID is valid
		finalFilepath = filepath.Join(filepath.Dir(pathBeforeId), fmt.Sprintf("%d_%s", modelVersionID, baseFilenameToUse))
		log.Debugf("Prepended model version ID, final target path: %s", finalFilepath)
	} else {
		finalFilepath = pathBeforeId // Use the path without ID if ID is 0
		log.Debugf("Model version ID is 0, final target path: %s", finalFilepath)
	}

	// --- Check Existence of FINAL Path (with potential API name and ID, using new helper) ---
	// NOTE(review): the HTTP response body is already open here; when a valid
	// final file is found the download is abandoned without being read.
	finalTargetDir := filepath.Dir(finalFilepath)
	finalBaseName := filepath.Base(finalFilepath)
	finalExt := filepath.Ext(finalBaseName) // Get extension from the FINAL path
	finalBaseNameWithoutExt := strings.TrimSuffix(finalBaseName, finalExt)

	log.Debugf("Checking for existing file based on determined final path: Dir=%s, BaseName=%s, Ext=%s", finalTargetDir, finalBaseNameWithoutExt, finalExt)
	foundPathFinal, existsFinal, errCheckFinal := findExistingFileWithMatchingBaseAndHash(finalTargetDir, finalBaseNameWithoutExt, finalExt, hashes)
	if errCheckFinal != nil {
		log.WithError(errCheckFinal).Errorf("Error during final check for existing file matching %s%s in %s", finalBaseNameWithoutExt, finalExt, finalTargetDir)
		return "", fmt.Errorf("%w: final check for existing file: %v", ErrFileSystem, errCheckFinal)
	}
	if existsFinal {
		log.Infof("Found valid existing file matching final base name '%s' and extension '%s': %s. Download not needed.", finalBaseNameWithoutExt, finalExt, foundPathFinal)
		shouldCleanupTemp = true // Ensure any temp file created before this check is removed
		return foundPathFinal, nil // Success, return the path of the valid existing file
	}
	log.Debugf("Final target file base name '%s' with extension '%s' does not exist with valid hash. Proceeding with network download to temp file.", finalBaseNameWithoutExt, finalExt)
	// --- End Final Path Check ---

	// Get the size of the file (0 if the header is missing/unparsable; used
	// only for progress display, so the parse error is deliberately ignored)
	size, _ := strconv.ParseUint(resp.Header.Get("Content-Length"), 10, 64)

	// Create a CounterWriter
	counter := &helpers.CounterWriter{
		Writer: tempFile,
		Total:  0,
	}

	// Write the body to temporary file, showing progress
	log.Infof("Downloading to %s (Target: %s, Size: %s)...", tempFile.Name(), finalFilepath, helpers.BytesToSize(size))
	_, err = io.Copy(counter, resp.Body)
	if err != nil {
		log.WithError(err).Errorf("Error writing temporary file %s", tempFile.Name())
		return "", fmt.Errorf("%w: writing temporary file %s: %v", ErrFileSystem, tempFile.Name(), err)
	}
	log.Infof("Finished writing %s.", tempFile.Name())

	// --- Explicitly close the file BEFORE hash check and rename ---
	if err := tempFile.Close(); err != nil {
		// Log the error, but try to continue with hash check/rename if closing failed?
		// Or maybe return error here? Returning error seems safer.
		log.WithError(err).Errorf("Failed to explicitly close temp file %s before hash/rename", tempFile.Name())
		return "", fmt.Errorf("%w: closing temp file %s: %w", ErrFileSystem, tempFile.Name(), err)
	}

	// Verify the hash of the downloaded temporary file ONLY if hashes were provided
	hashesProvided := hashes.SHA256 != "" || hashes.BLAKE3 != "" || hashes.CRC32 != "" || hashes.AutoV2 != ""
	if hashesProvided {
		log.Debugf("Verifying hash for temp file: %s", tempFile.Name())
		if !helpers.CheckHash(tempFile.Name(), hashes) {
			log.Errorf("Hash mismatch for downloaded file: %s", tempFile.Name())
			return "", ErrHashMismatch
		}
		log.Infof("Hash verified for %s.", tempFile.Name())
	} else {
		log.Debugf("Skipping hash verification for %s (no expected hashes provided).", tempFile.Name())
	}

	// Rename the temporary file to the final path
	log.Debugf("Renaming temp file %s to %s", tempFile.Name(), finalFilepath)
	if err = os.Rename(tempFile.Name(), finalFilepath); err != nil {
		log.WithError(err).Errorf("Error renaming temporary file %s to %s", tempFile.Name(), finalFilepath)
		return "", fmt.Errorf("%w: renaming temporary file %s to %s: %v", ErrFileSystem, tempFile.Name(), finalFilepath, err)
	}

	// If rename was successful, we don't want the defer to remove the temp file (which is now the final file)
	shouldCleanupTemp = false
	log.Infof("Successfully downloaded and verified %s", finalFilepath)

	return finalFilepath, nil
}
--------------------------------------------------------------------------------
/internal/helpers/helpers.go:
--------------------------------------------------------------------------------
package helpers

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"hash"
	"hash/crc32"
	"io"
	"math"
	"net/http"
	"os"
	"path/filepath"
	"strings"

	"go-civitai-download/internal/models" // Import the models package

	log "github.com/sirupsen/logrus"
	"github.com/zeebo/blake3"
)

// CheckHash verifies the hash of a file against expected values.
// Returns true if ANY of the provided hashes match the calculated ones.
// Checks in the order: BLAKE3, SHA256, CRC32, AutoV2.
// Comparisons are case-insensitive; a hash that fails to compute (e.g. the
// file cannot be read) is logged and skipped rather than treated as a match.
func CheckHash(filePath string, hashes models.Hashes) bool {
	// Check BLAKE3 (Prioritized for speed)
	if hashes.BLAKE3 != "" {
		hasher := blake3.New()
		calculatedHash, err := calculateHash(filePath, hasher)
		if err != nil {
			log.WithError(err).Errorf("Failed to calculate BLAKE3 for %s, skipping check.", filePath)
		} else {
			if strings.EqualFold(calculatedHash, hashes.BLAKE3) {
				log.Debugf("BLAKE3 match for %s", filePath)
				return true // Match found!
			} else {
				log.Warnf("BLAKE3 mismatch for %s: Expected %s, Got %s", filePath, hashes.BLAKE3, calculatedHash)
			}
		}
	}

	// Check SHA256
	if hashes.SHA256 != "" {
		hasher := sha256.New()
		calculatedHash, err := calculateHash(filePath, hasher)
		if err != nil {
			log.WithError(err).Errorf("Failed to calculate SHA256 for %s, skipping check.", filePath)
		} else {
			if strings.EqualFold(calculatedHash, hashes.SHA256) {
				log.Debugf("SHA256 match for %s", filePath)
				return true // Match found!
			} else {
				log.Warnf("SHA256 mismatch for %s: Expected %s, Got %s", filePath, hashes.SHA256, calculatedHash)
			}
		}
	}

	// Check CRC32 (using Castagnoli polynomial)
	if hashes.CRC32 != "" {
		table := crc32.MakeTable(crc32.Castagnoli)
		hasher := crc32.New(table)
		calculatedHash, err := calculateHash(filePath, hasher)
		if err != nil {
			log.WithError(err).Errorf("Failed to calculate CRC32 for %s, skipping check.", filePath)
		} else {
			if strings.EqualFold(calculatedHash, hashes.CRC32) {
				log.Debugf("CRC32 match for %s", filePath)
				return true // Match found!
			} else {
				log.Warnf("CRC32 mismatch for %s: Expected %s, Got %s", filePath, hashes.CRC32, calculatedHash)
			}
		}
	}

	// Check AutoV2 (derived from SHA256)
	if hashes.AutoV2 != "" {
		hasher := sha256.New() // Still need SHA256 calculation for AutoV2
		calculatedSha256Hash, err := calculateHash(filePath, hasher)
		if err != nil {
			log.WithError(err).Errorf("Failed to calculate hash (for AutoV2 check) for %s, skipping check.", filePath)
		} else {
			// Civitai AutoV2 hashes seem to be the first 10 chars of SHA256
			if len(calculatedSha256Hash) >= 10 && strings.EqualFold(calculatedSha256Hash[:10], hashes.AutoV2) {
				log.Debugf("AutoV2 match for %s", filePath)
				return true // Match found!
			} else {
				log.Warnf("AutoV2 mismatch for %s: Expected %s, Got %s (derived from SHA256: %s)", filePath, hashes.AutoV2, calculatedSha256Hash[:min(10, len(calculatedSha256Hash))], calculatedSha256Hash)
			}
		}
	}

	// If we reached here, none of the provided hashes matched.
	log.Warnf("No matching hash found for %s after checking all provided types.", filePath)
	return false
}

// CounterWriter tracks the number of bytes written to the underlying writer.
// It's used to display download progress.
99 | // Note: Consider moving this to the 'downloader' package later. 100 | type CounterWriter struct { 101 | Total uint64 102 | Writer io.Writer 103 | } 104 | 105 | // Write implements the io.Writer interface for CounterWriter. 106 | func (cw *CounterWriter) Write(p []byte) (n int, err error) { 107 | n, err = cw.Writer.Write(p) 108 | // Only add to total if write was successful and n is positive 109 | if err == nil && n > 0 { 110 | cw.Total += uint64(n) 111 | } 112 | // Progress reporting might be handled differently in CLI context 113 | // fmt.Printf("\rDownloaded %s", BytesToSize(cw.Total)) 114 | return n, err 115 | } 116 | 117 | // BytesToSize converts a byte count into a human-readable string (KB, MB, GB, etc.). 118 | func BytesToSize(bytes uint64) string { 119 | sizes := []string{"B", "KB", "MB", "GB", "TB"} 120 | if bytes == 0 { 121 | return "0B" 122 | } 123 | i := int(math.Floor(math.Log(float64(bytes)) / math.Log(1024))) 124 | if i >= len(sizes) { 125 | i = len(sizes) - 1 // Handle very large sizes 126 | } 127 | return fmt.Sprintf("%.2f%s", float64(bytes)/math.Pow(1024, float64(i)), sizes[i]) 128 | } 129 | 130 | // ConvertToSlug converts a string into a filesystem-friendly slug. 
131 | func ConvertToSlug(str string) string { 132 | str = strings.ReplaceAll(str, " ", "_") 133 | str = strings.ReplaceAll(str, ":", "-") 134 | str = strings.ToLower(str) 135 | 136 | allowedChars := "0123456789abcdefghijklmnopqrstuvwxyz._-" 137 | 138 | var filteredDescription strings.Builder 139 | for _, ch := range str { 140 | if strings.ContainsRune(allowedChars, ch) { 141 | filteredDescription.WriteRune(ch) 142 | } 143 | } 144 | str = filteredDescription.String() 145 | 146 | // Simplify repeated separators 147 | for strings.Contains(str, "--") { 148 | str = strings.ReplaceAll(str, "--", "-") 149 | } 150 | for strings.Contains(str, "__") { 151 | str = strings.ReplaceAll(str, "__", "_") 152 | } 153 | str = strings.ReplaceAll(str, "-_", "-") 154 | str = strings.ReplaceAll(str, "_-", "-") 155 | 156 | // Remove leading/trailing separators 157 | str = strings.Trim(str, "_-") 158 | 159 | return str 160 | } 161 | 162 | // CheckAndMakeDir ensures a directory exists, creating it if necessary. 163 | // Uses standard directory permissions (0700). 164 | func CheckAndMakeDir(dir string) bool { 165 | // Use MkdirAll to create parent directories if they don't exist 166 | err := os.MkdirAll(dir, 0700) 167 | if err != nil { 168 | log.WithError(err).Errorf("Error creating directory %s", dir) // Use logrus 169 | return false 170 | } 171 | return true 172 | } 173 | 174 | // CorrectPathBasedOnImageType checks the MIME type of a file and corrects the extension 175 | // in the final path if it doesn't match the detected image type. 176 | // It only corrects for common image types (jpg, png, gif, webp). 177 | // Returns the corrected final path and an error if reading fails. 
func CorrectPathBasedOnImageType(tempFilePath, finalFilePath string) (string, error) {
	originalExt := strings.ToLower(filepath.Ext(finalFilePath))
	// Map of known image MIME types to extensions
	mimeToExt := map[string]string{
		"image/jpeg": ".jpg",
		"image/png":  ".png",
		"image/gif":  ".gif",
		"image/webp": ".webp",
	}

	fRead, errRead := os.Open(tempFilePath)
	if errRead != nil {
		log.WithError(errRead).Warnf("Could not open file %s for MIME type detection, using original extension.", tempFilePath)
		// Return original path, don't treat this as a fatal error for the caller
		return finalFilePath, nil
	}
	defer fRead.Close()

	// Read the first 512 bytes for MIME detection
	// NOTE(review): a single Read may return fewer than 512 bytes; presumably
	// enough for http.DetectContentType on image headers — confirm.
	buff := make([]byte, 512)
	n, errReadBytes := fRead.Read(buff)
	if errReadBytes != nil && errReadBytes != io.EOF {
		log.WithError(errReadBytes).Warnf("Could not read from file %s for MIME type detection, using original extension.", tempFilePath)
		// Return original path
		return finalFilePath, nil
	}

	// Only use the bytes actually read
	mimeType := http.DetectContentType(buff[:n])
	// Extract the main type (e.g., "image/jpeg")
	mainMimeType := strings.Split(mimeType, ";")[0]

	correctedFinalPath := finalFilePath // Default to original

	if detectedExt, ok := mimeToExt[mainMimeType]; ok {
		log.Debugf("Detected MIME type: %s -> Extension: %s for %s", mimeType, detectedExt, tempFilePath)

		// Check for mismatch, BUT allow .jpeg for image/jpeg
		mismatch := originalExt != detectedExt
		if mismatch && mainMimeType == "image/jpeg" && originalExt == ".jpeg" {
			mismatch = false // Allow .jpeg extension for jpeg content
			log.Debugf("Original extension '.jpeg' is valid for detected type 'image/jpeg'. No correction needed.")
		}

		if mismatch { // Correct only if it's a real mismatch
			correctedFinalPath = strings.TrimSuffix(finalFilePath, originalExt) + detectedExt
			log.Warnf("Original extension '%s' differs from detected image type '%s'. Correcting final path to: %s", originalExt, detectedExt, correctedFinalPath)
		} else if originalExt == detectedExt { // Log if it matched exactly
			log.Debugf("Original extension '%s' matches detected image type '%s'. No path correction needed.", originalExt, detectedExt)
		}
	} else {
		log.Debugf("Detected MIME type '%s' for %s is not in the recognized image map. Using original extension '%s'.", mimeType, tempFilePath, originalExt)
	}

	return correctedFinalPath, nil
}

// TODO: Move loadConfig function to internal/config/config.go

// -- Hashing Helper --
// calculateHash streams the file at filePath through hashAlgo and returns the
// lowercase hex digest, or an error if the file cannot be opened or read.
func calculateHash(filePath string, hashAlgo hash.Hash) (string, error) {
	file, err := os.Open(filePath)
	if err != nil {
		return "", fmt.Errorf("opening file %s for hashing: %w", filePath, err)
	}
	defer file.Close()

	if _, err := io.Copy(hashAlgo, file); err != nil {
		return "", fmt.Errorf("hashing file %s: %w", filePath, err)
	}

	return hex.EncodeToString(hashAlgo.Sum(nil)), nil
}
--------------------------------------------------------------------------------
/internal/helpers/helpers_test.go:
--------------------------------------------------------------------------------
package helpers

import (
	"os"
	"path/filepath"
	"strings"
	"testing"

	"go-civitai-download/internal/models" // For models.Hashes
)

// TestConvertToSlug covers separator replacement, lowercasing, character
// filtering, and separator collapsing/trimming.
func TestConvertToSlug(t *testing.T) {
	tests := []struct {
		name  string
		input string
		want  string
	}{
		{"Empty string", "", ""},
		{"Simple string", "Simple Test", "simple_test"},
		{"With colon", "Test: Colon", "test-colon"},
		{"With numbers", "Model V1.5", "model_v1.5"},
		{"Mixed case", "MixedCase Slug", "mixedcase_slug"},
		{"Invalid characters", "File*Name?Is\"Bad!", "filenameisbad"},
		{"Repeated dashes", "double--dash", "double-dash"},
		{"Repeated underscores", "double__underscore", "double_underscore"},
		{"Mixed repeated separators", "mixed-_-separator--test", "mixed-separator-test"},
		{"Leading/trailing spaces (handled by Trim)", " Leading Trailing ", "leading_trailing"},
		{"Leading/trailing separators", "-_Leading Trailing_-_", "leading_trailing"},
		{"Already valid", "valid-slug_1.0", "valid-slug_1.0"},
		{"All invalid", "!@#$%^&*()+", ""},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := ConvertToSlug(tt.input)
			if got != tt.want {
				t.Errorf("ConvertToSlug(%q) = %q, want %q", tt.input, got, tt.want)
			}
		})
	}
}

// TestBytesToSize checks human-readable formatting at unit boundaries and
// fractional values (1024-based units, two decimal places).
func TestBytesToSize(t *testing.T) {
	tests := []struct {
		name  string
		bytes uint64
		want  string
	}{
		{"Zero bytes", 0, "0B"},
		{"Bytes", 500, "500.00B"},
		{"Kilobytes", 1024, "1.00KB"},
		{"Kilobytes fractional", 1536, "1.50KB"},
		{"Megabytes", 1024 * 1024, "1.00MB"},
		{"Megabytes fractional", 1024*1024 + 512*1024, "1.50MB"},
		{"Gigabytes", 1024 * 1024 * 1024, "1.00GB"},
		{"Terabytes", 1024 * 1024 * 1024 * 1024, "1.00TB"},
		{"Large Terabytes", 1536 * 1024 * 1024 * 1024, "1.50TB"},
		// Add edge cases if necessary
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := BytesToSize(tt.bytes)
			if got != tt.want {
				t.Errorf("BytesToSize(%d) = %q, want %q", tt.bytes, got, tt.want)
			}
		})
	}
}

// TestCheckHash exercises CheckHash against a fixture file with known
// BLAKE3/CRC32/SHA256 digests, including mismatch and no-hash cases.
func TestCheckHash(t *testing.T) {
	// Create a temporary directory for test files
	tempDir := t.TempDir()

	// Test file content and its known hashes
	testContent := []byte("this is test content for hashing")
	// Calculate expected hashes (replace with actual known values if preferred)
	expectedBlake3 := "B3C004D66E2A918576F44266A57BBCF854B79ED13D068A6A0EF5156C3CF41B74"
	expectedCRC32 := "4c6b15d9"
	expectedSHA256 := "f7b8f3f1c4c7c3f1d7f1e4e1e5f3f7f9a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0a0" // Placeholder - Recalculate this!
	// Note: You might want to pre-calculate these using external tools (like sha256sum, crc32, b3sum)
	// For SHA256: echo -n "this is test content for hashing" | sha256sum -> e41e304c0e53a1561616a4871f64707701a38342665599694bb3774519a867e7
	expectedSHA256 = "e41e304c0e53a1561616a4871f64707701a38342665599694bb3774519a867e7" // Corrected

	// Create the test file
	testFilePath := filepath.Join(tempDir, "test_hash_file.txt")
	err := os.WriteFile(testFilePath, testContent, 0644)
	if err != nil {
		t.Fatalf("Failed to create test file: %v", err)
	}

	// --- Test Cases ---
	tests := []struct {
		name       string
		filepath   string
		hashes     models.Hashes
		wantResult bool
	}{
		{
			name:       "No file exists",
			filepath:   filepath.Join(tempDir, "nonexistent_file.txt"),
			hashes:     models.Hashes{BLAKE3: expectedBlake3},
			wantResult: false,
		},
		{
			name:       "File exists, BLAKE3 match",
			filepath:   testFilePath,
			hashes:     models.Hashes{BLAKE3: expectedBlake3},
			wantResult: true,
		},
		{
			name:       "File exists, CRC32 match (lowercase api)",
			filepath:   testFilePath,
			hashes:     models.Hashes{CRC32: expectedCRC32}, // Function handles case diff
			wantResult: true,
		},
		{
			name:       "File exists, SHA256 match (uppercase api)",
			filepath:   testFilePath,
			hashes:     models.Hashes{SHA256: strings.ToUpper(expectedSHA256)}, // Function handles case diff
			wantResult: true,
		},
		{
			name:       "File exists, multiple hashes match",
			filepath:   testFilePath,
			hashes:     models.Hashes{BLAKE3: expectedBlake3, CRC32: expectedCRC32, SHA256: expectedSHA256},
			wantResult: true,
		},
		{
			name:       "File exists, one hash mismatch, one match",
			filepath:   testFilePath,
			hashes:     models.Hashes{BLAKE3: "incorrecthash", CRC32: expectedCRC32},
			wantResult: true, // Should return true if any hash matches
		},
		{
			name:       "File exists, all hashes mismatch",
			filepath:   testFilePath,
			hashes:     models.Hashes{BLAKE3: "incorrect1", CRC32: "incorrect2", SHA256: "incorrect3"},
			wantResult: false,
		},
		{
			name:       "File exists, no hashes provided",
			filepath:   testFilePath,
			hashes:     models.Hashes{},
			wantResult: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			gotResult := CheckHash(tt.filepath, tt.hashes)
			if gotResult != tt.wantResult {
				t.Errorf("CheckHash(%q, %+v) = %v, want %v", tt.filepath, tt.hashes, gotResult, tt.wantResult)
			}
		})
	}
}

// TestCheckAndMakeDir verifies directory creation (simple and nested) and the
// failure path when the target name already exists as a regular file.
func TestCheckAndMakeDir(t *testing.T) {
	// Create a base temporary directory for this test
	baseTempDir := t.TempDir()

	tests := []struct {
		name       string
		dirToMake  string // Relative to baseTempDir
		wantResult bool
		wantExists bool // Check if the directory should actually exist afterwards
	}{
		{
			name:       "Create simple directory",
			dirToMake:  "new_dir",
			wantResult: true,
			wantExists: true,
		},
		{
			name:       "Create nested directory",
			dirToMake:  filepath.Join("nested", "dir", "to", "create"),
			wantResult: true,
			wantExists: true,
		},
		{
			name:       "Attempt to create directory that is a file",
			dirToMake:  "existing_file.txt",
			wantResult: false, // Should fail because it's a file
			wantExists: false, // Directory should not exist
		},
		{
			name:       "Directory already exists",
			dirToMake:  "already_exists",
			wantResult: true, // Should
succeed even if it exists 191 | wantExists: true, 192 | }, 193 | } 194 | 195 | // Pre-create structures needed for certain tests 196 | preExistingDir := filepath.Join(baseTempDir, "already_exists") 197 | if err := os.Mkdir(preExistingDir, 0755); err != nil { 198 | t.Fatalf("Failed to pre-create directory %s: %v", preExistingDir, err) 199 | } 200 | preExistingFile := filepath.Join(baseTempDir, "existing_file.txt") 201 | if _, err := os.Create(preExistingFile); err != nil { 202 | t.Fatalf("Failed to pre-create file %s: %v", preExistingFile, err) 203 | } 204 | 205 | for _, tt := range tests { 206 | t.Run(tt.name, func(t *testing.T) { 207 | fullPathToMake := filepath.Join(baseTempDir, tt.dirToMake) 208 | gotResult := CheckAndMakeDir(fullPathToMake) 209 | 210 | if gotResult != tt.wantResult { 211 | t.Errorf("CheckAndMakeDir(%q) = %v, want %v", fullPathToMake, gotResult, tt.wantResult) 212 | } 213 | 214 | // Verify if the directory actually exists or not 215 | _, err := os.Stat(fullPathToMake) 216 | gotExists := err == nil 217 | 218 | if gotExists != tt.wantExists { 219 | if tt.wantExists { 220 | t.Errorf("CheckAndMakeDir(%q) succeeded (%v) but directory does not exist", fullPathToMake, gotResult) 221 | } else { 222 | t.Errorf("CheckAndMakeDir(%q) failed (%v) but directory unexpectedly exists", fullPathToMake, gotResult) 223 | } 224 | } 225 | 226 | // Double-check if it's actually a directory (if it should exist) 227 | if tt.wantExists && gotExists { 228 | info, _ := os.Stat(fullPathToMake) 229 | if !info.IsDir() { 230 | t.Errorf("CheckAndMakeDir(%q) created something, but it's not a directory", fullPathToMake) 231 | } 232 | } 233 | }) 234 | } 235 | } 236 | 237 | // TODO: Add tests for CheckAndMakeDir (might need filesystem mocking or cleanup) 238 | -------------------------------------------------------------------------------- /internal/models/models.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 
	"net/url"
	"strconv"
)

type (
	// Config mirrors config.toml and controls a download run end to end:
	// authentication, output paths, model/file filtering, API query
	// behavior, and downloader behavior. Field comments note renames from
	// earlier config versions.
	Config struct {
		// Connection/Auth
		ApiKey string `toml:"ApiKey"`

		// Paths
		SavePath       string `toml:"SavePath"`
		DatabasePath   string `toml:"DatabasePath"`
		BleveIndexPath string `toml:"BleveIndexPath"` // New field for Bleve index path

		// Filtering - Model/Version Level
		Query               string   `toml:"Query"`
		Tag                 string   `toml:"Tag"`
		Username            string   `toml:"Username"`
		ModelTypes          []string `toml:"ModelTypes"` // Renamed from Types
		BaseModels          []string `toml:"BaseModels"`
		IgnoreBaseModels    []string `toml:"IgnoreBaseModels"`
		Nsfw                bool     `toml:"Nsfw"`                // Renamed from GetNsfw
		ModelVersionID      int      `toml:"ModelVersionID"`      // New
		DownloadAllVersions bool     `toml:"DownloadAllVersions"` // New

		// Filtering - File Level
		PrimaryOnly           bool     `toml:"PrimaryOnly"` // Renamed from GetOnlyPrimaryModel
		Pruned                bool     `toml:"Pruned"`      // Renamed from GetPruned
		Fp16                  bool     `toml:"Fp16"`        // Renamed from GetFp16
		IgnoreFileNameStrings []string `toml:"IgnoreFileNameStrings"`

		// API Query Behavior
		Sort     string `toml:"Sort"`
		Period   string `toml:"Period"`
		Limit    int    `toml:"Limit"`
		MaxPages int    `toml:"MaxPages"` // New

		// Downloader Behavior
		Concurrency         int  `toml:"Concurrency"` // Renamed from DefaultConcurrency
		SaveMetadata        bool `toml:"SaveMetadata"`
		DownloadMetaOnly    bool `toml:"DownloadMetaOnly"` // New
		SaveModelInfo       bool `toml:"SaveModelInfo"`    // New
		SaveVersionImages   bool `toml:"SaveVersionImages"` // New
		SaveModelImages     bool `toml:"SaveModelImages"`   // New
		SkipConfirmation    bool `toml:"SkipConfirmation"`  // New (for --yes flag)
		ApiDelayMs          int  `toml:"ApiDelayMs"`
		ApiClientTimeoutSec int  `toml:"ApiClientTimeoutSec"`

		// Other
		LogApiRequests bool `toml:"LogApiRequests"`
	}

	// QueryParameters holds every query-string knob accepted by the
	// /api/v1/models endpoint; ConstructApiUrl serializes it.
	QueryParameters struct {
		Limit                  int      `json:"limit"`
		Page                   int      `json:"page,omitempty"`
		Query                  string   `json:"query,omitempty"`
		Tag                    string   `json:"tag,omitempty"`
		Username               string   `json:"username,omitempty"`
		Types                  []string `json:"types,omitempty"`
		Sort                   string   `json:"sort"`
		Period                 string   `json:"period"`
		PrimaryFileOnly        bool     `json:"primaryFileOnly,omitempty"`
		AllowNoCredit          bool     `json:"allowNoCredit,omitempty"`
		AllowDerivatives       bool     `json:"allowDerivatives,omitempty"`
		AllowDifferentLicenses bool     `json:"allowDifferentLicenses,omitempty"`
		AllowCommercialUse     string   `json:"allowCommercialUse,omitempty"`
		Nsfw                   bool     `json:"nsfw"`
		BaseModels             []string `json:"baseModels,omitempty"`
		Cursor                 string   `json:"cursor,omitempty"`
	}

	// Model is a single model item from the /api/v1/models response.
	Model struct {
		ID                    int            `json:"id"`
		Name                  string         `json:"name"`
		Description           string         `json:"description"`
		Type                  string         `json:"type"`
		Poi                   bool           `json:"poi"`
		Nsfw                  bool           `json:"nsfw"`
		AllowNoCredit         bool           `json:"allowNoCredit"`
		AllowCommercialUse    []string       `json:"allowCommercialUse"`
		AllowDerivatives      bool           `json:"allowDerivatives"`
		AllowDifferentLicense bool           `json:"allowDifferentLicense"`
		Stats                 Stats          `json:"stats"`
		Creator               Creator        `json:"creator"`
		Tags                  []string       `json:"tags"`
		ModelVersions         []ModelVersion `json:"modelVersions"`
		Meta                  interface{}    `json:"meta"` // Meta can be null or an object, so we use interface{}
	}

	// Stats holds engagement counters shared by models and versions.
	Stats struct {
		DownloadCount int     `json:"downloadCount"`
		FavoriteCount int     `json:"favoriteCount"`
		CommentCount  int     `json:"commentCount"`
		RatingCount   int     `json:"ratingCount"`
		Rating        float64 `json:"rating"`
	}

	// Creator identifies the uploading user.
	Creator struct {
		Username string `json:"username"`
		Image    string `json:"image"`
	}

	// --- NEW: Struct for nested 'model' field in /model-versions/{id} response ---
	BaseModelInfo struct {
		Name string `json:"name"`
		Type string `json:"type"`
		Nsfw bool   `json:"nsfw"`
		Poi  bool   `json:"poi"`
		Mode string `json:"mode"` // Can be null, "Archived", "TakenDown"
	}

	// ModelVersion is one published version of a Model, with its files and
	// preview images.
	ModelVersion struct {
		ID                   int          `json:"id"`
		ModelId              int          `json:"modelId"`
		Name                 string       `json:"name"`
		PublishedAt          string       `json:"publishedAt"`
		UpdatedAt            string       `json:"updatedAt"`
		TrainedWords         []string     `json:"trainedWords"`
		BaseModel            string       `json:"baseModel"`
		EarlyAccessTimeFrame int          `json:"earlyAccessTimeFrame"`
		Description          string       `json:"description"`
		Stats                Stats        `json:"stats"`
		Files                []File       `json:"files"`
		Images               []ModelImage `json:"images"`
		DownloadUrl          string       `json:"downloadUrl"`
		// --- ADDED: Nested model info from /model-versions/{id} endpoint ---
		Model BaseModelInfo `json:"model"`
	}

	// File is a downloadable artifact attached to a ModelVersion.
	File struct {
		Name              string   `json:"name"`
		ID                int      `json:"id"`
		SizeKB            float64  `json:"sizeKB"`
		Type              string   `json:"type"`
		Metadata          Metadata `json:"metadata"`
		PickleScanResult  string   `json:"pickleScanResult"`
		PickleScanMessage string   `json:"pickleScanMessage"`
		VirusScanResult   string   `json:"virusScanResult"`
		ScannedAt         string   `json:"scannedAt"`
		Hashes            Hashes   `json:"hashes"`
		DownloadUrl       string   `json:"downloadUrl"`
		Primary           bool     `json:"primary"`
	}

	// Metadata describes a file's precision/size/format variant.
	Metadata struct {
		Fp     string `json:"fp"`
		Size   string `json:"size"`
		Format string `json:"format"`
	}

	// Hashes holds the digests the API publishes for a file, keyed by
	// algorithm; empty string means the API did not supply that digest.
	Hashes struct {
		AutoV2 string `json:"AutoV2"`
		SHA256 string `json:"SHA256"`
		CRC32  string `json:"CRC32"`
		BLAKE3 string `json:"BLAKE3"`
	}

	// ModelImage is a preview image attached to a ModelVersion.
	ModelImage struct {
		ID        int         `json:"id"`
		URL       string      `json:"url"`
		Hash      string      `json:"hash"` // Blurhash
		Width     int         `json:"width"`
		Height    int         `json:"height"`
		Nsfw      bool        `json:"nsfw"`      // Keep boolean for simplicity, align with Model struct Nsfw
		NsfwLevel interface{} `json:"nsfwLevel"` // Changed to interface{} to handle number OR string from API
		CreatedAt string      `json:"createdAt"` // Consider parsing to time.Time if needed
		PostID    *int        `json:"postId"`    // Use pointer for optional field
		Stats     ImageStats  `json:"stats"`
		Meta      interface{} `json:"meta"` // Often unstructured JSON, use interface{}
		Username  string      `json:"username"`
	}

	// ImageStats holds per-image reaction counters.
	ImageStats struct {
		CryCount     int `json:"cryCount"`
		LaughCount   int `json:"laughCount"`
		LikeCount    int `json:"likeCount"`
		HeartCount   int `json:"heartCount"`
		CommentCount int `json:"commentCount"`
	}

	// ApiResponse is the top-level /api/v1/models payload.
	ApiResponse struct { // Renamed from Response
		Items    []Model            `json:"items"`
		Metadata PaginationMetadata `json:"metadata"` // Added field for pagination
	}

	// PaginationMetadata describes pagination state per the API docs; the
	// models endpoint uses NextPage/NextCursor, older responses use pages.
	PaginationMetadata struct {
		TotalItems  int    `json:"totalItems"`
		CurrentPage int    `json:"currentPage"`
		PageSize    int    `json:"pageSize"`
		TotalPages  int    `json:"totalPages"`
		NextPage    string `json:"nextPage"`
		PrevPage    string `json:"prevPage"`   // Added based on API docs
		NextCursor  string `json:"nextCursor"` // Added based on API docs (for images endpoint mainly)
	}

	// DatabaseEntry is the internal per-file record persisted to the local
	// database; Status takes one of the Status* constants below.
	DatabaseEntry struct {
		ModelName    string       `json:"modelName"`
		ModelType    string       `json:"modelType"`
		Version      ModelVersion `json:"version"`
		File         File         `json:"file"`
		Timestamp    int64        `json:"timestamp"`
		Creator      Creator      `json:"creator"`
		Filename     string       `json:"filename"`
		Folder       string       `json:"folder"`
		Status       string       `json:"status"`
		ErrorDetails string       `json:"errorDetails,omitempty"`
	}

	// --- Start: /api/v1/images Endpoint Structures ---

	// ImageApiResponse represents the structure of the response from the /api/v1/images endpoint.
	ImageApiResponse struct {
		Items    []ImageApiItem   `json:"items"` // Renamed Image -> ImageApiItem to avoid conflict
		Metadata MetadataNextPage `json:"metadata"`
	}

	// ImageApiItem represents a single image item specifically from the /api/v1/images response.
	ImageApiItem struct {
		ID        int         `json:"id"`
		URL       string      `json:"url"`
		Hash      string      `json:"hash"` // Blurhash
		Width     int         `json:"width"`
		Height    int         `json:"height"`
		Nsfw      bool        `json:"nsfw"`      // Keep boolean for simplicity
		NsfwLevel string      `json:"nsfwLevel"` // None, Soft, Mature, X
		CreatedAt string      `json:"createdAt"`
		PostID    *int        `json:"postId"`
		Stats     ImageStats  `json:"stats"`
		Meta      interface{} `json:"meta"`
		Username  string      `json:"username"`
		BaseModel string      `json:"baseModel"`
	}

	// MetadataNextPage is used when the API returns metadata with a `nextPage` URL.
	MetadataNextPage struct {
		TotalItems   int    `json:"totalItems,omitempty"`
		CurrentPage  int    `json:"currentPage,omitempty"`
		PageSize     int    `json:"pageSize,omitempty"`
		NextCursor   string `json:"nextCursor,omitempty"`
		NextPage     string `json:"nextPage,omitempty"`
		PreviousPage string `json:"previousPage,omitempty"`
	}
	// --- End: /api/v1/images Endpoint Structures ---
)

// Database Status Constants — lifecycle states stored in DatabaseEntry.Status.
const (
	StatusPending    = "Pending"
	StatusDownloaded = "Downloaded"
	StatusError      = "Error"
)

// ConstructApiUrl builds the Civitai API URL from query parameters.
260 | func ConstructApiUrl(params QueryParameters) string { 261 | base := "https://civitai.com/api/v1/models" 262 | values := url.Values{} 263 | 264 | // Add parameters if they have non-default values 265 | if params.Limit > 0 && params.Limit <= 100 { // Use API default if not set or invalid 266 | values.Set("limit", strconv.Itoa(params.Limit)) 267 | } else { 268 | // Let the API use its default limit (usually 100) 269 | } 270 | 271 | if params.Page > 0 { // Page is only relevant for non-cursor pagination (less common now) 272 | // values.Set("page", strconv.Itoa(params.Page)) 273 | // Generally, Cursor should be preferred over Page. 274 | } 275 | 276 | if params.Query != "" { 277 | values.Set("query", params.Query) 278 | } 279 | 280 | if params.Tag != "" { 281 | values.Set("tag", params.Tag) 282 | } 283 | 284 | if params.Username != "" { 285 | values.Set("username", params.Username) 286 | } 287 | 288 | for _, t := range params.Types { 289 | values.Add("types", t) 290 | } 291 | 292 | if params.Sort != "" { 293 | values.Set("sort", params.Sort) 294 | } 295 | 296 | if params.Period != "" { 297 | values.Set("period", params.Period) 298 | } 299 | 300 | if !params.AllowNoCredit { // Default is true, so only add if false 301 | values.Set("allowNoCredit", "false") 302 | } 303 | 304 | if !params.AllowDerivatives { // Default is true 305 | values.Set("allowDerivatives", "false") 306 | } 307 | 308 | if !params.AllowDifferentLicenses { // Default is true 309 | values.Set("allowDifferentLicense", "false") // API uses singular 'License' 310 | } 311 | 312 | if params.AllowCommercialUse != "Any" && params.AllowCommercialUse != "" { // Default is Any 313 | values.Set("allowCommercialUse", params.AllowCommercialUse) 314 | } 315 | 316 | // Only add nsfw param if true 317 | if params.Nsfw { 318 | values.Set("nsfw", "true") 319 | } 320 | 321 | for _, bm := range params.BaseModels { 322 | values.Add("baseModels", bm) // API uses camelCase 323 | } 324 | 325 | if params.Cursor != "" { 
326 | values.Set("cursor", params.Cursor) 327 | } 328 | 329 | queryString := values.Encode() 330 | if queryString != "" { 331 | return base + "?" + queryString 332 | } 333 | return base 334 | } 335 | --------------------------------------------------------------------------------