├── .github ├── dependabot.yaml └── workflows │ ├── ci.yml │ └── verify-registry-push-pull.yml ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── assets ├── dummy.gguf └── license.txt ├── builder └── builder.go ├── cmd └── mdltool │ ├── main.go │ └── main_test.go ├── distribution ├── client.go ├── client_test.go ├── delete_test.go ├── ecr_test.go ├── errors.go └── gar_test.go ├── go.mod ├── go.sum ├── internal ├── gguf │ ├── create.go │ ├── metadata.go │ ├── model.go │ └── model_test.go ├── mutate │ ├── model.go │ ├── mutate.go │ └── mutate_test.go ├── partial │ ├── layer.go │ └── partial.go ├── progress │ ├── reporter.go │ └── reporter_test.go ├── store │ ├── blobs.go │ ├── blobs_test.go │ ├── errors.go │ ├── index.go │ ├── index_test.go │ ├── layout.go │ ├── manifests.go │ ├── model.go │ ├── store.go │ ├── store_test.go │ └── testdata │ │ ├── dummy.gguf │ │ └── license.txt └── utils │ └── utils.go ├── registry ├── artifact.go ├── client.go └── errors.go ├── scripts └── model-push │ ├── README.md │ ├── llama-converter │ ├── Dockerfile │ └── entrypoint.sh │ └── push-model.sh └── types ├── config.go └── model.go /.github/dependabot.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: gomod 4 | directory: / 5 | schedule: 6 | interval: weekly 7 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | 8 | jobs: 9 | build: 10 | name: Build and Test 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - uses: actions/checkout@v4 15 | 16 | - name: Set up Go 17 | uses: actions/setup-go@v4 18 | with: 19 | go-version-file: go.mod 20 | cache: true 21 | 22 | - name: Run all checks 23 | run: make all 24 | -------------------------------------------------------------------------------- /.github/workflows/verify-registry-push-pull.yml: -------------------------------------------------------------------------------- 1 | name: Verify Registry Push Pull 2 | 3 | on: 4 | workflow_dispatch: 5 | 6 | permissions: 7 | contents: read 8 | id-token: write # Required for OIDC authentication 9 | 10 | env: 11 | AWS_ACCOUNT_ID: 676043725699 12 | AWS_REGION: us-east-1 13 | ECR_REPOSITORY: images/model-distribution 14 | GAR_LOCATION: us-east4-docker.pkg.dev 15 | GAR_REGION: us-east4 16 | GAR_REPOSITORY: docker-model-distribution 17 | MODEL_NAME: test-model 18 | MODEL_VERSION: latest 19 | PROJECT_ID: sandbox-298914 20 | 21 | jobs: 22 | verify-gar: 23 | runs-on: ubuntu-latest 24 | steps: 25 | - name: Checkout code 26 | uses: actions/checkout@v4 27 | 28 | - name: Set up Go 29 | uses: actions/setup-go@v4 30 | with: 31 | go-version: '1.21' 32 | 33 | - name: Authenticate to Google Cloud 34 | uses: google-github-actions/auth@v2 35 | with: 36 | project_id: ${{ env.PROJECT_ID }} 37 | workload_identity_provider: 'projects/981855438795/locations/global/workloadIdentityPools/model-distribution-pool/providers/model-distribution-github' 38 | create_credentials_file: true 39 | 40 | - name: Configure Docker for GAR 41 | run: | 42 | gcloud auth configure-docker ${{ env.GAR_LOCATION }} --quiet 43 | 44 | - name: Run tests with GAR integration 45 | run: | 46 | # Set environment variables for the test 47 | export TEST_GAR_ENABLED=true 48 | 49 | # Set the full tag directly (preferred method) 50 | export TEST_GAR_TAG="${{ env.GAR_LOCATION }}/${{ 
env.PROJECT_ID }}/${{ env.GAR_REPOSITORY }}/${{ env.MODEL_NAME }}:${{ env.MODEL_VERSION }}" 51 | # GOOGLE_APPLICATION_CREDENTIALS is automatically set by the auth action 52 | echo "Using credentials file at: ${GOOGLE_APPLICATION_CREDENTIALS}" 53 | echo "Using GAR tag: ${TEST_GAR_TAG}" 54 | 55 | # Run the tests 56 | go test -v ./pkg/distribution -run TestGARIntegration 57 | 58 | verify-ecr: 59 | runs-on: ubuntu-latest 60 | steps: 61 | - name: Checkout code 62 | uses: actions/checkout@v4 63 | 64 | - name: Set up Go 65 | uses: actions/setup-go@v4 66 | with: 67 | go-version: '1.21' 68 | 69 | - name: Configure AWS Credentials 70 | id: assume-role 71 | uses: aws-actions/configure-aws-credentials@v4 72 | with: 73 | role-to-assume: arn:aws:iam::${{ env.AWS_ACCOUNT_ID }}:role/release-model-distribution 74 | role-session-name: gha-build-push-image-ecr 75 | aws-region: ${{ env.AWS_REGION }} 76 | 77 | - name: Create ECR Repository 78 | run: | 79 | # Check if repository exists, create if it doesn't 80 | aws ecr describe-repositories --repository-names ${{ env.ECR_REPOSITORY }}/${{ env.MODEL_NAME }} || \ 81 | aws ecr create-repository --repository-name ${{ env.ECR_REPOSITORY }}/${{ env.MODEL_NAME }} 82 | 83 | - name: Configure Docker for ECR 84 | run: | 85 | aws ecr get-login-password --region ${{ env.AWS_REGION }} | docker login --username AWS --password-stdin ${{ env.AWS_ACCOUNT_ID }}.dkr.ecr.${{ env.AWS_REGION }}.amazonaws.com 86 | 87 | - name: Run tests with ECR integration 88 | run: | 89 | # Set environment variables for the test 90 | export TEST_ECR_ENABLED=true 91 | 92 | # Set the full tag directly (preferred method) 93 | export TEST_ECR_TAG="${{ env.AWS_ACCOUNT_ID }}.dkr.ecr.${{ env.AWS_REGION }}.amazonaws.com/${{ env.ECR_REPOSITORY }}/${{ env.MODEL_NAME }}:${{ env.MODEL_VERSION }}" 94 | echo "Using ECR tag: ${TEST_ECR_TAG}" 95 | 96 | # Run the tests 97 | go test -v ./pkg/distribution -run TestECRIntegration 98 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Environment variables 2 | .env 3 | 4 | # Binaries for programs and plugins 5 | *.exe 6 | *.exe~ 7 | *.dll 8 | *.so 9 | *.dylib 10 | model-distribution-tool 11 | 12 | # Test binary, built with `go test -c` 13 | *.test 14 | 15 | # Output of the go coverage tool, specifically when used with LiteIDE 16 | *.out 17 | 18 | # Dependency directories (remove the comment below to include it) 19 | vendor/ 20 | 21 | # Go workspace file 22 | go.work 23 | 24 | # IDE specific files 25 | .idea 26 | .vscode 27 | *.swp 28 | *.swo 29 | 30 | # Test artifacts 31 | test/artifacts/ 32 | 33 | # Project specific 34 | /bin/ 35 | /tmp/ 36 | /models/ 37 | 38 | # OS specific 39 | .DS_Store 40 | Thumbs.db 41 | model-store 42 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | https://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 
15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | Copyright 2025 Docker, Inc. 180 | 181 | Licensed under the Apache License, Version 2.0 (the "License"); 182 | you may not use this file except in compliance with the License. 183 | You may obtain a copy of the License at 184 | 185 | https://www.apache.org/licenses/LICENSE-2.0 186 | 187 | Unless required by applicable law or agreed to in writing, software 188 | distributed under the License is distributed on an "AS IS" BASIS, 189 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 190 | See the License for the specific language governing permissions and 191 | limitations under the License. 
192 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: all build test clean lint run 2 | 3 | # Import env file if it exists 4 | -include .env 5 | 6 | # Build variables 7 | BINARY_NAME=model-distribution-tool 8 | VERSION?=0.1.0 9 | 10 | # Go related variables 11 | GOBASE=$(shell pwd) 12 | GOBIN=$(GOBASE)/bin 13 | 14 | # Run configuration 15 | SOURCE?= 16 | TAG?= 17 | STORE_PATH?=./model-store 18 | 19 | # Use linker flags to provide version/build information 20 | LDFLAGS=-ldflags "-X main.Version=${VERSION}" 21 | 22 | all: clean lint build test 23 | 24 | build: 25 | @echo "Building ${BINARY_NAME}..." 26 | @mkdir -p ${GOBIN} 27 | @go build ${LDFLAGS} -o ${GOBIN}/${BINARY_NAME} github.com/docker/model-distribution/cmd/mdltool 28 | 29 | test: 30 | @echo "Running unit tests..." 31 | @go test -v ./... 32 | 33 | clean: 34 | @echo "Cleaning..." 35 | @rm -rf ${GOBIN} 36 | @rm -f ${BINARY_NAME} 37 | @rm -f *.test 38 | @rm -rf test/artifacts/* 39 | 40 | lint: 41 | @echo "Running linters..." 42 | @gofmt -s -l . | tee /dev/stderr | xargs -r false 43 | @go vet ./... 44 | 45 | run-pull: 46 | @echo "Pulling model from ${TAG}..." 47 | @${GOBIN}/${BINARY_NAME} --store-path ${STORE_PATH} pull ${TAG} 48 | 49 | run-package: 50 | @echo "Pushing model ${SOURCE} to ${TAG}..." 51 | @${GOBIN}/${BINARY_NAME} --store-path ${STORE_PATH} package ${SOURCE} ${TAG} ${LICENSE:+--license ${LICENSE}} 52 | 53 | run-list: 54 | @echo "Listing models..." 55 | @${GOBIN}/${BINARY_NAME} --store-path ${STORE_PATH} list 56 | 57 | run-get: 58 | @echo "Getting model ${TAG}..." 59 | @${GOBIN}/${BINARY_NAME} --store-path ${STORE_PATH} get ${TAG} 60 | 61 | run-get-path: 62 | @echo "Getting path for model ${TAG}..." 63 | @${GOBIN}/${BINARY_NAME} --store-path ${STORE_PATH} get-path ${TAG} 64 | 65 | run-rm: 66 | @echo "Removing model ${TAG}..." 67 | @${GOBIN}/${BINARY_NAME} --store-path ${STORE_PATH} rm ${TAG} 68 | 69 | run-tag: 70 | @echo "Tagging model ${SOURCE} as ${TAG}..." 71 | @${GOBIN}/${BINARY_NAME} --store-path ${STORE_PATH} tag ${SOURCE} ${TAG} 72 | 73 | help: 74 | @echo "Available targets:" 75 | @echo " all - Clean, build, and test" 76 | @echo " build - Build the binary" 77 | @echo " test - Run unit tests" 78 | @echo " clean - Clean build artifacts" 79 | @echo " run-pull - Pull a model (TAG=registry/model:tag)" 80 | @echo " run-package - Package and push a model (SOURCE=path/to/model.gguf TAG=registry/model:tag LICENSE=path/to/license.txt)" 81 | @echo " run-list - List all models" 82 | @echo " run-get - Get model info (TAG=registry/model:tag)" 83 | @echo " run-get-path - Get model path (TAG=registry/model:tag)" 84 | @echo " run-rm - Remove a model (TAG=registry/model:tag)" 85 | @echo " run-tag - Tag a model (SOURCE=registry/model:tag TAG=registry/model:newtag)" 86 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Model Distribution 2 | 3 | A library and CLI tool for distributing models using container registries. 4 | 5 | ## Overview 6 | 7 | Model Distribution is a Go library and CLI tool that allows you to package, push, pull, and manage models using container registries. It provides a simple API and command-line interface for working with models in GGUF format. 
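For example, the `builder` and `registry` packages can be combined to package a local GGUF file and push it straight to a registry. The sketch below is illustrative only: it assumes the same wiring that `cmd/mdltool package` uses, and the file paths, registry reference, and user-agent string are placeholders.

```go
package main

import (
	"context"
	"log"
	"os"

	"github.com/docker/model-distribution/builder"
	"github.com/docker/model-distribution/registry"
)

func main() {
	// Build a model artifact from a local GGUF file.
	b, err := builder.FromGGUF("./model.gguf")
	if err != nil {
		log.Fatal(err)
	}

	// Optionally attach a license file as an extra layer.
	b, err = b.WithLicense("./license.txt")
	if err != nil {
		log.Fatal(err)
	}

	// Resolve the push target from a reference and write the artifact,
	// streaming progress output to stdout.
	target, err := registry.NewClient(
		registry.WithUserAgent("example-tool/0.1.0"),
	).NewTarget("registry.example.com/models/llama:v1.0")
	if err != nil {
		log.Fatal(err)
	}
	if err := b.Build(context.Background(), target, os.Stdout); err != nil {
		log.Fatal(err)
	}
}
```

The `distribution.Client` shown under "As a Library" below covers the pull, list, tag, and delete side of the same workflow.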
8 |
9 | ## Features
10 |
11 | - Push models to container registries
12 | - Pull models from container registries
13 | - Local model storage
14 | - Model metadata management
15 | - Command-line interface for all operations
16 |
17 | ## Usage
18 |
19 | ### As a CLI Tool
20 |
21 | ```bash
22 | # Build the CLI tool
23 | make build
24 |
25 | # Pull a model from a registry
26 | ./bin/model-distribution-tool pull registry.example.com/models/llama:v1.0
27 |
28 | # Package a model and push to a registry
29 | ./bin/model-distribution-tool package ./model.gguf registry.example.com/models/llama:v1.0
30 |
31 | # Package a model with license files and push to a registry
32 | ./bin/model-distribution-tool package --licenses license1.txt --licenses license2.txt ./model.gguf registry.example.com/models/llama:v1.0
33 |
34 | # Push a model from the content store to the registry
35 | ./bin/model-distribution-tool push registry.example.com/models/llama:v1.0
36 |
37 | # List all models in the local store
38 | ./bin/model-distribution-tool list
39 |
40 | # Get information about a model
41 | ./bin/model-distribution-tool get registry.example.com/models/llama:v1.0
42 |
43 | # Get the local file path for a model
44 | ./bin/model-distribution-tool get-path registry.example.com/models/llama:v1.0
45 |
46 | # Remove a model from the local store (will untag w/o deleting if there are multiple tags)
47 | ./bin/model-distribution-tool rm registry.example.com/models/llama:v1.0
48 |
49 | # Force removal of a model from the local store, even when there are multiple referring tags
50 | ./bin/model-distribution-tool rm --force sha256:0b329b335467cccf7aa219e8f5e1bd65e59b6dfa81cfa42fba2f8881268fbf82
51 |
52 | # Tag a model with an additional reference
53 | ./bin/model-distribution-tool tag registry.example.com/models/llama:v1.0 registry.example.com/models/llama:latest
54 | ```
55 |
56 | For more information about the CLI tool, run:
57 |
58 | ```bash
59 | ./bin/model-distribution-tool --help
60 | ```
61 |
62 | ### As a Library
63 |
64 | ```go
65 | import (
66 | 	"context"
67 | 	"github.com/docker/model-distribution/distribution"
68 | )
69 |
70 | // Create a new client
71 | client, err := distribution.NewClient(distribution.WithStoreRootPath("/path/to/cache"))
72 | if err != nil {
73 | 	// Handle error
74 | }
75 |
76 | // Pull a model
77 | err = client.PullModel(context.Background(), "registry.example.com/models/llama:v1.0", os.Stdout)
78 | if err != nil {
79 | 	// Handle error
80 | }
81 |
82 | // Get a model
83 | model, err := client.GetModel("registry.example.com/models/llama:v1.0")
84 | if err != nil {
85 | 	// Handle error
86 | }
87 |
88 | // Get the GGUF file path
89 | modelPath, err := model.GGUFPath()
90 | if err != nil {
91 | 	// Handle error
92 | }
93 |
94 | fmt.Println("Model path:", modelPath)
95 |
96 | // List all models
97 | models, err := client.ListModels()
98 | if err != nil {
99 | 	// Handle error
100 | }
101 |
102 | // Delete a model
103 | err = client.DeleteModel("registry.example.com/models/llama:v1.0", false)
104 | if err != nil {
105 | 	// Handle error
106 | }
107 |
108 | // Tag a model
109 | err = client.Tag("registry.example.com/models/llama:v1.0", "registry.example.com/models/llama:latest")
110 | if err != nil {
111 | 	// Handle error
112 | }
113 |
114 | // Push a model
115 | err = client.PushModel(context.Background(), "registry.example.com/models/llama:v1.0", os.Stdout)
116 | if err != nil {
117 | 	// Handle error
118 | }
119 | ```
120 |
--------------------------------------------------------------------------------
/assets/dummy.gguf:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/docker/model-distribution/b377026db94a863fceaec1a3088f924e80825888/assets/dummy.gguf -------------------------------------------------------------------------------- /assets/license.txt: -------------------------------------------------------------------------------- 1 | FAKE LICENSE 2 | -------------------------------------------------------------------------------- /builder/builder.go: -------------------------------------------------------------------------------- 1 | package builder 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | 8 | "github.com/docker/model-distribution/internal/gguf" 9 | "github.com/docker/model-distribution/internal/mutate" 10 | "github.com/docker/model-distribution/internal/partial" 11 | "github.com/docker/model-distribution/types" 12 | ) 13 | 14 | // Builder builds a model artifact 15 | type Builder struct { 16 | model types.ModelArtifact 17 | } 18 | 19 | // FromGGUF returns a *Builder that builds a model artifacts from a GGUF file 20 | func FromGGUF(path string) (*Builder, error) { 21 | mdl, err := gguf.NewModel(path) 22 | if err != nil { 23 | return nil, err 24 | } 25 | return &Builder{ 26 | model: mdl, 27 | }, nil 28 | } 29 | 30 | // WithLicense adds a license file to the artifact 31 | func (b *Builder) WithLicense(path string) (*Builder, error) { 32 | licenseLayer, err := partial.NewLayer(path, types.MediaTypeLicense) 33 | if err != nil { 34 | return nil, fmt.Errorf("license layer from %q: %w", path, err) 35 | } 36 | return &Builder{ 37 | model: mutate.AppendLayers(b.model, licenseLayer), 38 | }, nil 39 | } 40 | 41 | // Target represents a build target 42 | type Target interface { 43 | Write(context.Context, types.ModelArtifact, io.Writer) error 44 | } 45 | 46 | // Build finalizes the artifact and writes it to the given target, reporting progress to the given writer 47 | func (b *Builder) Build(ctx context.Context, target Target, pw io.Writer) error { 48 | return target.Write(ctx, b.model, pw) 49 | } 50 | -------------------------------------------------------------------------------- /cmd/mdltool/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "flag" 6 | "fmt" 7 | "os" 8 | "path/filepath" 9 | "strings" 10 | 11 | "github.com/docker/model-distribution/builder" 12 | "github.com/docker/model-distribution/distribution" 13 | "github.com/docker/model-distribution/registry" 14 | ) 15 | 16 | // stringSliceFlag is a flag that can be specified multiple times to collect multiple string values 17 | type stringSliceFlag []string 18 | 19 | func (s *stringSliceFlag) String() string { 20 | return strings.Join(*s, ", ") 21 | } 22 | 23 | func (s *stringSliceFlag) Set(value string) error { 24 | *s = append(*s, value) 25 | return nil 26 | } 27 | 28 | const ( 29 | defaultStorePath = "./model-store" 30 | version = "0.1.0" 31 | ) 32 | 33 | var ( 34 | storePath string 35 | showHelp bool 36 | showVer bool 37 | ) 38 | 39 | func init() { 40 | flag.StringVar(&storePath, "store-path", defaultStorePath, "Path to the model store") 41 | flag.BoolVar(&showHelp, "help", false, "Show help") 42 | flag.BoolVar(&showVer, "version", false, "Show version") 43 | } 44 | 45 | func main() { 46 | flag.Parse() 47 | 48 | if showVer { 49 | fmt.Printf("model-distribution-tool version %s\n", version) 50 | return 51 | } 52 | 53 | if showHelp || flag.NArg() == 0 { 54 | printUsage() 55 | return 56 | } 
57 | 58 | // Create absolute path for store 59 | absStorePath, err := filepath.Abs(storePath) 60 | if err != nil { 61 | fmt.Fprintf(os.Stderr, "Error resolving store path: %v\n", err) 62 | os.Exit(1) 63 | } 64 | 65 | // Create the client 66 | client, err := distribution.NewClient( 67 | distribution.WithStoreRootPath(absStorePath), 68 | distribution.WithUserAgent("model-distribution-tool/"+version), 69 | ) 70 | 71 | if err != nil { 72 | fmt.Fprintf(os.Stderr, "Error creating client: %v\n", err) 73 | os.Exit(1) 74 | } 75 | 76 | // Get the command and arguments 77 | command := flag.Arg(0) 78 | args := flag.Args()[1:] 79 | 80 | // Execute the command 81 | exitCode := 0 82 | switch command { 83 | case "pull": 84 | exitCode = cmdPull(client, args) 85 | case "package": 86 | exitCode = cmdPackage(args) 87 | case "push": 88 | exitCode = cmdPush(client, args) 89 | case "list": 90 | exitCode = cmdList(client, args) 91 | case "get": 92 | exitCode = cmdGet(client, args) 93 | case "get-path": 94 | exitCode = cmdGetPath(client, args) 95 | case "rm": 96 | exitCode = cmdRm(client, args) 97 | case "tag": 98 | exitCode = cmdTag(client, args) 99 | default: 100 | fmt.Fprintf(os.Stderr, "Unknown command: %s\n", command) 101 | printUsage() 102 | exitCode = 1 103 | } 104 | 105 | os.Exit(exitCode) 106 | } 107 | 108 | func printUsage() { 109 | fmt.Println("Usage: model-distribution-tool [options] [arguments]") 110 | fmt.Println("\nOptions:") 111 | flag.PrintDefaults() 112 | fmt.Println("\nCommands:") 113 | fmt.Println(" pull Pull a model from a registry") 114 | fmt.Println(" package Package a model file as an OCI artifact and push it to a registry (use --licenses to add license files)") 115 | fmt.Println(" push Push a model from the content store to the registry") 116 | fmt.Println(" list List all models") 117 | fmt.Println(" get Get a model by reference") 118 | fmt.Println(" get-path Get the local file path for a model") 119 | fmt.Println(" rm Remove a model by reference") 120 | fmt.Println("\nExamples:") 121 | fmt.Println(" model-distribution-tool --store-path ./models pull registry.example.com/models/llama:v1.0") 122 | fmt.Println(" model-distribution-tool package ./model.gguf registry.example.com/models/llama:v1.0 --licenses ./license1.txt --licenses ./license2.txt") 123 | fmt.Println(" model-distribution-tool push registry.example.com/models/llama:v1.0") 124 | fmt.Println(" model-distribution-tool list") 125 | fmt.Println(" model-distribution-tool rm registry.example.com/models/llama:v1.0") 126 | } 127 | 128 | func cmdPull(client *distribution.Client, args []string) int { 129 | if len(args) < 1 { 130 | fmt.Fprintf(os.Stderr, "Error: missing reference argument\n") 131 | fmt.Fprintf(os.Stderr, "Usage: model-distribution-tool pull \n") 132 | return 1 133 | } 134 | 135 | reference := args[0] 136 | ctx := context.Background() 137 | 138 | if err := client.PullModel(ctx, reference, os.Stdout); err != nil { 139 | fmt.Fprintf(os.Stderr, "Error pulling model: %v\n", err) 140 | return 1 141 | } 142 | 143 | fmt.Printf("Successfully pulled model: %s\n", reference) 144 | return 0 145 | } 146 | 147 | func cmdPackage(args []string) int { 148 | fs := flag.NewFlagSet("push", flag.ExitOnError) 149 | var licensePaths stringSliceFlag 150 | fs.Var(&licensePaths, "licenses", "Paths to license files (can be specified multiple times)") 151 | if err := fs.Parse(args); err != nil { 152 | fmt.Fprintf(os.Stderr, "Error parsing flags: %v\n", err) 153 | return 1 154 | } 155 | args = fs.Args() 156 | 157 | if len(args) < 2 { 158 | 
fmt.Fprintf(os.Stderr, "Error: missing arguments\n") 159 | fmt.Fprintf(os.Stderr, "Usage: model-distribution-tool push [--licenses --licenses ...]\n") 160 | return 1 161 | } 162 | 163 | source := args[0] 164 | reference := args[1] 165 | ctx := context.Background() 166 | 167 | // Check if source file exists 168 | if _, err := os.Stat(source); os.IsNotExist(err) { 169 | fmt.Fprintf(os.Stderr, "Error: source file does not exist: %s\n", source) 170 | return 1 171 | } 172 | 173 | // Check if source file is a GGUF file 174 | if !strings.HasSuffix(strings.ToLower(source), ".gguf") { 175 | fmt.Fprintf(os.Stderr, "Warning: source file does not have .gguf extension: %s\n", source) 176 | fmt.Fprintf(os.Stderr, "Continuing anyway, but this may cause issues.\n") 177 | } 178 | 179 | // Parse the reference 180 | target, err := registry.NewClient( 181 | registry.WithUserAgent("model-distribution-tool/" + version), 182 | ).NewTarget(reference) 183 | if err != nil { 184 | fmt.Fprintf(os.Stderr, "Error parsing reference: %v\n", err) 185 | return 1 186 | } 187 | 188 | // Create image with layer 189 | builder, err := builder.FromGGUF(source) 190 | if err != nil { 191 | fmt.Fprintf(os.Stderr, "Error creating model from gguf: %v\n", err) 192 | return 1 193 | } 194 | 195 | // Add all license files as layers 196 | for _, path := range licensePaths { 197 | fmt.Println("Adding license file:", path) 198 | builder, err = builder.WithLicense(path) 199 | if err != nil { 200 | fmt.Fprintf(os.Stderr, "Error adding license layer for %s: %v\n", path, err) 201 | return 1 202 | } 203 | } 204 | 205 | // Push the image 206 | if err := builder.Build(ctx, target, os.Stdout); err != nil { 207 | fmt.Fprintf(os.Stderr, "Error writing model %q to registry: %v\n", reference, err) 208 | return 1 209 | } 210 | return 0 211 | } 212 | 213 | func cmdPush(client *distribution.Client, args []string) int { 214 | if len(args) < 1 { 215 | fmt.Fprintf(os.Stderr, "Error: missing tag argument\n") 216 | fmt.Fprintf(os.Stderr, "Usage: model-distribution-tool push \n") 217 | return 1 218 | } 219 | 220 | tag := args[0] 221 | ctx := context.Background() 222 | 223 | if err := client.PushModel(ctx, tag, os.Stdout); err != nil { 224 | fmt.Fprintf(os.Stderr, "Error pushing model: %v\n", err) 225 | return 1 226 | } 227 | 228 | fmt.Printf("Successfully pushed model: %s\n", tag) 229 | return 0 230 | } 231 | 232 | func cmdList(client *distribution.Client, args []string) int { 233 | models, err := client.ListModels() 234 | if err != nil { 235 | fmt.Fprintf(os.Stderr, "Error listing models: %v\n", err) 236 | return 1 237 | } 238 | 239 | if len(models) == 0 { 240 | fmt.Println("No models found") 241 | return 0 242 | } 243 | 244 | fmt.Println("Models:") 245 | for i, model := range models { 246 | id, err := model.ID() 247 | if err != nil { 248 | fmt.Fprintf(os.Stderr, "Error getting model ID: %v\n", err) 249 | continue 250 | } 251 | fmt.Printf("%d. 
ID: %s\n", i+1, id) 252 | fmt.Printf(" Tags: %s\n", strings.Join(model.Tags(), ", ")) 253 | 254 | ggufPath, err := model.GGUFPath() 255 | if err == nil { 256 | fmt.Printf(" GGUF Path: %s\n", ggufPath) 257 | } 258 | } 259 | return 0 260 | } 261 | 262 | func cmdGet(client *distribution.Client, args []string) int { 263 | if len(args) < 1 { 264 | fmt.Fprintf(os.Stderr, "Error: missing reference argument\n") 265 | fmt.Fprintf(os.Stderr, "Usage: model-distribution-tool get \n") 266 | return 1 267 | } 268 | 269 | reference := args[0] 270 | 271 | model, err := client.GetModel(reference) 272 | if err != nil { 273 | fmt.Fprintf(os.Stderr, "Error getting model: %v\n", err) 274 | return 1 275 | } 276 | 277 | fmt.Printf("Model: %s\n", reference) 278 | 279 | id, err := model.ID() 280 | if err != nil { 281 | fmt.Fprintf(os.Stderr, "Error getting model ID %v\n", err) 282 | return 1 283 | } 284 | fmt.Printf("ID: %s\n", id) 285 | 286 | ggufPath, err := model.GGUFPath() 287 | if err != nil { 288 | fmt.Fprintf(os.Stderr, "Error getting gguf path %v\n", err) 289 | return 1 290 | } 291 | fmt.Printf("GGUF Path: %s\n", ggufPath) 292 | 293 | cfg, err := model.Config() 294 | if err != nil { 295 | fmt.Fprintf(os.Stderr, "Error reading model config: %v\n", err) 296 | return 1 297 | } 298 | fmt.Printf("Format: %s\n", cfg.Format) 299 | fmt.Printf("Architecture: %s\n", cfg.Architecture) 300 | fmt.Printf("Parameters: %s\n", cfg.Parameters) 301 | fmt.Printf("Quantization: %s\n", cfg.Quantization) 302 | return 0 303 | } 304 | 305 | func cmdGetPath(client *distribution.Client, args []string) int { 306 | if len(args) < 1 { 307 | fmt.Fprintf(os.Stderr, "Error: missing reference argument\n") 308 | fmt.Fprintf(os.Stderr, "Usage: model-distribution-tool get-path \n") 309 | return 1 310 | } 311 | 312 | reference := args[0] 313 | 314 | model, err := client.GetModel(reference) 315 | if err != nil { 316 | fmt.Fprintf(os.Stderr, "Failed to get model: %v\n", err) 317 | return 1 318 | } 319 | 320 | modelPath, err := model.GGUFPath() 321 | if err != nil { 322 | fmt.Fprintf(os.Stderr, "Error getting model path: %v\n", err) 323 | return 1 324 | } 325 | 326 | fmt.Println(modelPath) 327 | return 0 328 | } 329 | 330 | func cmdRm(client *distribution.Client, args []string) int { 331 | var force bool 332 | fs := flag.NewFlagSet("rm", flag.ExitOnError) 333 | fs.BoolVar(&force, "force", false, "Force remove the model") 334 | 335 | if err := fs.Parse(args); err != nil { 336 | fmt.Fprintf(os.Stderr, "Error parsing flags: %v\n", err) 337 | return 1 338 | } 339 | args = fs.Args() 340 | 341 | if len(args) < 1 { 342 | fmt.Fprintf(os.Stderr, "Error: missing reference argument\n") 343 | fmt.Fprintf(os.Stderr, "Usage: model-distribution-tool rm [--force] \n") 344 | return 1 345 | } 346 | 347 | reference := args[0] 348 | 349 | if err := client.DeleteModel(reference, force); err != nil { 350 | fmt.Fprintf(os.Stderr, "Error removing model: %v\n", err) 351 | return 1 352 | } 353 | 354 | fmt.Printf("Successfully removed model: %s\n", reference) 355 | return 0 356 | } 357 | 358 | func cmdTag(client *distribution.Client, args []string) int { 359 | if len(args) != 2 { 360 | fmt.Fprintf(os.Stderr, "Usage: model-distribution-tool tag \n") 361 | return 1 362 | } 363 | 364 | source := args[0] 365 | target := args[1] 366 | 367 | if err := client.Tag(source, target); err != nil { 368 | fmt.Fprintf(os.Stderr, "Error tagging model: %v\n", err) 369 | return 1 370 | } 371 | 372 | fmt.Printf("Successfully applied tag %s to model: %s\n", target, source) 373 | return 0 374 | 
} 375 | -------------------------------------------------------------------------------- /cmd/mdltool/main_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "os" 5 | "os/exec" 6 | "path/filepath" 7 | "strings" 8 | "testing" 9 | 10 | "github.com/docker/model-distribution/distribution" 11 | ) 12 | 13 | // TestMainHelp tests the help command 14 | func TestMainHelp(t *testing.T) { 15 | cmd := exec.Command("go", "run", "main.go", "--help") 16 | output, err := cmd.CombinedOutput() 17 | if err != nil { 18 | t.Fatalf("Failed to run help command: %v\nOutput: %s", err, output) 19 | } 20 | 21 | // Check that the output contains the usage information 22 | if !strings.Contains(string(output), "Usage:") { 23 | t.Errorf("Help output does not contain usage information") 24 | } 25 | 26 | // Check that the output contains the commands 27 | commands := []string{"pull", "package", "list", "get", "get-path"} 28 | for _, cmd := range commands { 29 | if !strings.Contains(string(output), cmd) { 30 | t.Errorf("Help output does not contain command: %s", cmd) 31 | } 32 | } 33 | } 34 | 35 | // TestMainVersion tests the version command 36 | func TestMainVersion(t *testing.T) { 37 | cmd := exec.Command("go", "run", "main.go", "--version") 38 | output, err := cmd.CombinedOutput() 39 | if err != nil { 40 | t.Fatalf("Failed to run version command: %v\nOutput: %s", err, output) 41 | } 42 | 43 | // Check that the output contains the version information 44 | if !strings.Contains(string(output), "version") { 45 | t.Errorf("Version output does not contain version information") 46 | } 47 | } 48 | 49 | // TestMainPull tests the pull command 50 | func TestMainPull(t *testing.T) { 51 | // Create a temporary directory for the test 52 | tempDir, err := os.MkdirTemp("", "model-distribution-test-*") 53 | if err != nil { 54 | t.Fatalf("Failed to create temp directory: %v", err) 55 | } 56 | defer os.RemoveAll(tempDir) 57 | 58 | // Create a model store directory 59 | storeDir := filepath.Join(tempDir, "model-store") 60 | if err := os.MkdirAll(storeDir, 0755); err != nil { 61 | t.Fatalf("Failed to create model store directory: %v", err) 62 | } 63 | 64 | // Create a client for testing 65 | client, err := distribution.NewClient(distribution.WithStoreRootPath(storeDir)) 66 | if err != nil { 67 | t.Fatalf("Failed to create client: %v", err) 68 | } 69 | 70 | // Test the pull command with invalid arguments 71 | exitCode := cmdPull(client, []string{}) 72 | if exitCode != 1 { 73 | t.Errorf("Pull command with invalid arguments should fail") 74 | } 75 | } 76 | 77 | // TestMainPackage tests the package command 78 | func TestMainPackage(t *testing.T) { 79 | // Create a temporary directory for the test 80 | tempDir, err := os.MkdirTemp("", "model-distribution-test-*") 81 | if err != nil { 82 | t.Fatalf("Failed to create temp directory: %v", err) 83 | } 84 | defer os.RemoveAll(tempDir) 85 | 86 | // Test the package command with invalid arguments 87 | exitCode := cmdPackage([]string{}) 88 | if exitCode != 1 { 89 | t.Errorf("Push command with invalid arguments should fail") 90 | } 91 | } 92 | 93 | // TestMainList tests the list command 94 | func TestMainList(t *testing.T) { 95 | // Create a temporary directory for the test 96 | tempDir, err := os.MkdirTemp("", "model-distribution-test-*") 97 | if err != nil { 98 | t.Fatalf("Failed to create temp directory: %v", err) 99 | } 100 | defer os.RemoveAll(tempDir) 101 | 102 | // Create a client for testing 103 | client, err := 
distribution.NewClient(distribution.WithStoreRootPath(tempDir)) 104 | if err != nil { 105 | t.Fatalf("Failed to create client: %v", err) 106 | } 107 | 108 | // Test the list command 109 | exitCode := cmdList(client, []string{}) 110 | if exitCode != 0 { 111 | t.Errorf("List command failed with exit code: %d", exitCode) 112 | } 113 | } 114 | 115 | // TestMainGet tests the get command 116 | func TestMainGet(t *testing.T) { 117 | // Create a temporary directory for the test 118 | tempDir, err := os.MkdirTemp("", "model-distribution-test-*") 119 | if err != nil { 120 | t.Fatalf("Failed to create temp directory: %v", err) 121 | } 122 | defer os.RemoveAll(tempDir) 123 | 124 | // Create a client for testing 125 | client, err := distribution.NewClient(distribution.WithStoreRootPath(tempDir)) 126 | if err != nil { 127 | t.Fatalf("Failed to create client: %v", err) 128 | } 129 | 130 | // Test the get command with invalid arguments 131 | exitCode := cmdGet(client, []string{}) 132 | if exitCode != 1 { 133 | t.Errorf("Get command with invalid arguments should fail") 134 | } 135 | } 136 | 137 | // TestMainGetPath tests the get-path command 138 | func TestMainGetPath(t *testing.T) { 139 | // Create a temporary directory for the test 140 | tempDir, err := os.MkdirTemp("", "model-distribution-test-*") 141 | if err != nil { 142 | t.Fatalf("Failed to create temp directory: %v", err) 143 | } 144 | defer os.RemoveAll(tempDir) 145 | 146 | // Create a client for testing 147 | client, err := distribution.NewClient(distribution.WithStoreRootPath(tempDir)) 148 | if err != nil { 149 | t.Fatalf("Failed to create client: %v", err) 150 | } 151 | 152 | // Test the get-path command with invalid arguments 153 | exitCode := cmdGetPath(client, []string{}) 154 | if exitCode != 1 { 155 | t.Errorf("Get-path command with invalid arguments should fail") 156 | } 157 | } 158 | 159 | // TestMainPush tests the push command 160 | func TestMainPush(t *testing.T) { 161 | // Create a temporary directory for the test 162 | tempDir, err := os.MkdirTemp("", "model-distribution-test-*") 163 | if err != nil { 164 | t.Fatalf("Failed to create temp directory: %v", err) 165 | } 166 | defer os.RemoveAll(tempDir) 167 | 168 | // Create a client for testing 169 | client, err := distribution.NewClient(distribution.WithStoreRootPath(tempDir)) 170 | if err != nil { 171 | t.Fatalf("Failed to create client: %v", err) 172 | } 173 | 174 | // Test the push command with invalid arguments 175 | exitCode := cmdPush(client, []string{}) 176 | if exitCode != 1 { 177 | t.Errorf("Push command with invalid arguments should fail") 178 | } 179 | } 180 | -------------------------------------------------------------------------------- /distribution/client.go: -------------------------------------------------------------------------------- 1 | package distribution 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | "os" 9 | 10 | "github.com/sirupsen/logrus" 11 | 12 | "github.com/docker/model-distribution/internal/progress" 13 | "github.com/docker/model-distribution/internal/store" 14 | "github.com/docker/model-distribution/registry" 15 | "github.com/docker/model-distribution/types" 16 | ) 17 | 18 | // Client provides model distribution functionality 19 | type Client struct { 20 | store *store.LocalStore 21 | log *logrus.Entry 22 | registry *registry.Client 23 | } 24 | 25 | // GetStorePath returns the root path where models are stored 26 | func (c *Client) GetStorePath() string { 27 | return c.store.RootPath() 28 | } 29 | 30 | // Option represents an 
option for creating a new Client 31 | type Option func(*options) 32 | 33 | // options holds the configuration for a new Client 34 | type options struct { 35 | storeRootPath string 36 | logger *logrus.Entry 37 | transport http.RoundTripper 38 | userAgent string 39 | } 40 | 41 | // WithStoreRootPath sets the store root path 42 | func WithStoreRootPath(path string) Option { 43 | return func(o *options) { 44 | if path != "" { 45 | o.storeRootPath = path 46 | } 47 | } 48 | } 49 | 50 | // WithLogger sets the logger 51 | func WithLogger(logger *logrus.Entry) Option { 52 | return func(o *options) { 53 | if logger != nil { 54 | o.logger = logger 55 | } 56 | } 57 | } 58 | 59 | // WithTransport sets the HTTP transport to use when pulling and pushing models. 60 | func WithTransport(transport http.RoundTripper) Option { 61 | return func(o *options) { 62 | if transport != nil { 63 | o.transport = transport 64 | } 65 | } 66 | } 67 | 68 | // WithUserAgent sets the User-Agent header to use when pulling and pushing models. 69 | func WithUserAgent(ua string) Option { 70 | return func(o *options) { 71 | if ua != "" { 72 | o.userAgent = ua 73 | } 74 | } 75 | } 76 | 77 | func defaultOptions() *options { 78 | return &options{ 79 | logger: logrus.NewEntry(logrus.StandardLogger()), 80 | transport: registry.DefaultTransport, 81 | userAgent: registry.DefaultUserAgent, 82 | } 83 | } 84 | 85 | // NewClient creates a new distribution client 86 | func NewClient(opts ...Option) (*Client, error) { 87 | options := defaultOptions() 88 | for _, opt := range opts { 89 | opt(options) 90 | } 91 | 92 | if options.storeRootPath == "" { 93 | return nil, fmt.Errorf("store root path is required") 94 | } 95 | 96 | s, err := store.New(store.Options{ 97 | RootPath: options.storeRootPath, 98 | }) 99 | if err != nil { 100 | return nil, fmt.Errorf("initializing store: %w", err) 101 | } 102 | 103 | options.logger.Infoln("Successfully initialized store") 104 | return &Client{ 105 | store: s, 106 | log: options.logger, 107 | registry: registry.NewClient( 108 | registry.WithTransport(options.transport), 109 | registry.WithUserAgent(options.userAgent), 110 | ), 111 | }, nil 112 | } 113 | 114 | // PullModel pulls a model from a registry and returns the local file path 115 | func (c *Client) PullModel(ctx context.Context, reference string, progressWriter io.Writer) error { 116 | c.log.Infoln("Starting model pull:", reference) 117 | 118 | remoteModel, err := c.registry.Model(ctx, reference) 119 | if err != nil { 120 | return fmt.Errorf("reading model from registry: %w", err) 121 | } 122 | 123 | // Check for supported type 124 | if err := checkCompat(remoteModel); err != nil { 125 | return err 126 | } 127 | 128 | // Get the remote image digest 129 | remoteDigest, err := remoteModel.Digest() 130 | if err != nil { 131 | c.log.Errorln("Failed to get remote image digest:", err) 132 | return fmt.Errorf("getting remote image digest: %w", err) 133 | } 134 | c.log.Infoln("Remote model digest:", remoteDigest.String()) 135 | 136 | // Check if model exists in local store 137 | localModel, err := c.store.Read(remoteDigest.String()) 138 | if err == nil { 139 | c.log.Infoln("Model found in local store:", reference) 140 | ggufPath, err := localModel.GGUFPath() 141 | if err != nil { 142 | return fmt.Errorf("getting gguf path: %w", err) 143 | } 144 | 145 | // Get file size for progress reporting 146 | fileInfo, err := os.Stat(ggufPath) 147 | if err != nil { 148 | return fmt.Errorf("getting file info: %w", err) 149 | } 150 | 151 | // Report progress for local model 
152 | size := fileInfo.Size() 153 | err = progress.WriteSuccess(progressWriter, fmt.Sprintf("Using cached model: %.2f MB", float64(size)/1024/1024)) 154 | if err != nil { 155 | c.log.Warnf("Writing progress: %v", err) 156 | // If we fail to write progress, don't try again 157 | progressWriter = nil 158 | } 159 | 160 | // Ensure model has the correct tag 161 | if err := c.store.AddTags(remoteDigest.String(), []string{reference}); err != nil { 162 | return fmt.Errorf("tagging modle: %w", err) 163 | } 164 | return nil 165 | } else { 166 | c.log.Infoln("Model not found in local store, pulling from remote:", reference) 167 | } 168 | 169 | // Model doesn't exist in local store or digests don't match, pull from remote 170 | 171 | if err = c.store.Write(remoteModel, []string{reference}, progressWriter); err != nil { 172 | if writeErr := progress.WriteError(progressWriter, fmt.Sprintf("Error: %s", err.Error())); writeErr != nil { 173 | c.log.Warnf("Failed to write error message: %v", writeErr) 174 | // If we fail to write error message, don't try again 175 | progressWriter = nil 176 | } 177 | return fmt.Errorf("writing image to store: %w", err) 178 | } 179 | 180 | if err := progress.WriteSuccess(progressWriter, "Model pulled successfully"); err != nil { 181 | c.log.Warnf("Failed to write success message: %v", err) 182 | // If we fail to write success message, don't try again 183 | progressWriter = nil 184 | } 185 | 186 | return nil 187 | } 188 | 189 | // ListModels returns all available models 190 | func (c *Client) ListModels() ([]types.Model, error) { 191 | c.log.Infoln("Listing available models") 192 | modelInfos, err := c.store.List() 193 | if err != nil { 194 | c.log.Errorln("Failed to list models:", err) 195 | return nil, fmt.Errorf("listing models: %w", err) 196 | } 197 | 198 | result := make([]types.Model, 0, len(modelInfos)) 199 | for _, modelInfo := range modelInfos { 200 | // Read the models 201 | model, err := c.store.Read(modelInfo.ID) 202 | if err != nil { 203 | c.log.Warnf("Failed to read model with ID %s: %v", modelInfo.ID, err) 204 | continue 205 | } 206 | result = append(result, model) 207 | } 208 | 209 | c.log.Infoln("Successfully listed models, count:", len(result)) 210 | return result, nil 211 | } 212 | 213 | // GetModel returns a model by reference 214 | func (c *Client) GetModel(reference string) (types.Model, error) { 215 | c.log.Infoln("Getting model by reference:", reference) 216 | model, err := c.store.Read(reference) 217 | if err != nil { 218 | c.log.Errorln("Failed to get model:", err, "reference:", reference) 219 | return nil, fmt.Errorf("get model '%q': %w", reference, err) 220 | } 221 | 222 | return model, nil 223 | } 224 | 225 | // DeleteModel deletes a model 226 | func (c *Client) DeleteModel(reference string, force bool) error { 227 | mdl, err := c.store.Read(reference) 228 | if err != nil { 229 | return err 230 | } 231 | id, err := mdl.ID() 232 | if err != nil { 233 | return fmt.Errorf("getting model ID: %w", err) 234 | } 235 | isTag := id != reference 236 | 237 | if isTag { 238 | c.log.Infoln("Untagging model:", reference) 239 | if err := c.store.RemoveTags([]string{reference}); err != nil { 240 | c.log.Errorln("Failed to untag model:", err, "tag:", reference) 241 | return fmt.Errorf("untagging model: %w", err) 242 | } 243 | } 244 | 245 | if len(mdl.Tags()) > 1 { 246 | if isTag { 247 | return nil // we are done after untagging 248 | } else if !force { 249 | // if the reference is not a tag and there are multiple tags, return an error unless forced 250 | return 
fmt.Errorf( 251 | "unable to delete %q (must be forced) due to multiple tag references: %w", 252 | reference, ErrConflict, 253 | ) 254 | } 255 | } 256 | 257 | c.log.Infoln("Deleting model:", id) 258 | if err := c.store.Delete(id); err != nil { 259 | c.log.Errorln("Failed to delete model:", err, "tag:", reference) 260 | return fmt.Errorf("deleting model: %w", err) 261 | } 262 | c.log.Infoln("Successfully deleted model:", reference) 263 | return nil 264 | } 265 | 266 | // Tag adds a tag to a model 267 | func (c *Client) Tag(source string, target string) error { 268 | c.log.Infoln("Tagging model, source:", source, "target:", target) 269 | return c.store.AddTags(source, []string{target}) 270 | } 271 | 272 | // PushModel pushes a tagged model from the content store to the registry. 273 | func (c *Client) PushModel(ctx context.Context, tag string, progressWriter io.Writer) (err error) { 274 | // Parse the tag 275 | target, err := c.registry.NewTarget(tag) 276 | if err != nil { 277 | return fmt.Errorf("new tag: %w", err) 278 | } 279 | 280 | // Get the model from the store 281 | mdl, err := c.store.Read(tag) 282 | if err != nil { 283 | return fmt.Errorf("reading model: %w", err) 284 | } 285 | 286 | // Push the model 287 | c.log.Infoln("Pushing model:", tag) 288 | if err := target.Write(ctx, mdl, progressWriter); err != nil { 289 | c.log.Errorln("Failed to push image:", err, "reference:", tag) 290 | if writeErr := progress.WriteError(progressWriter, fmt.Sprintf("Error: %s", err.Error())); writeErr != nil { 291 | c.log.Warnf("Failed to write error message: %v", writeErr) 292 | } 293 | return fmt.Errorf("pushing image: %w", err) 294 | } 295 | 296 | c.log.Infoln("Successfully pushed model:", tag) 297 | if err := progress.WriteSuccess(progressWriter, "Model pushed successfully"); err != nil { 298 | c.log.Warnf("Failed to write success message: %v", err) 299 | } 300 | 301 | return nil 302 | } 303 | 304 | func (c *Client) ResetStore() error { 305 | c.log.Infoln("Resetting store") 306 | if err := c.store.Reset(); err != nil { 307 | c.log.Errorln("Failed to reset store:", err) 308 | return fmt.Errorf("resetting store: %w", err) 309 | } 310 | return nil 311 | } 312 | 313 | func checkCompat(image types.ModelArtifact) error { 314 | manifest, err := image.Manifest() 315 | if err != nil { 316 | return err 317 | } 318 | if manifest.Config.MediaType != types.MediaTypeModelConfigV01 { 319 | return fmt.Errorf("config type %q is unsupported: %w", manifest.Config.MediaType, ErrUnsupportedMediaType) 320 | } 321 | return nil 322 | } 323 | -------------------------------------------------------------------------------- /distribution/client_test.go: -------------------------------------------------------------------------------- 1 | package distribution 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "context" 7 | "crypto/rand" 8 | "encoding/json" 9 | "errors" 10 | "fmt" 11 | "io" 12 | "net/http/httptest" 13 | "net/url" 14 | "os" 15 | "path/filepath" 16 | "strings" 17 | "testing" 18 | 19 | "github.com/docker/model-distribution/internal/progress" 20 | "github.com/google/go-containerregistry/pkg/name" 21 | "github.com/google/go-containerregistry/pkg/registry" 22 | "github.com/google/go-containerregistry/pkg/v1/remote" 23 | "github.com/sirupsen/logrus" 24 | 25 | "github.com/docker/model-distribution/internal/gguf" 26 | "github.com/docker/model-distribution/internal/mutate" 27 | mdregistry "github.com/docker/model-distribution/registry" 28 | ) 29 | 30 | var ( 31 | testGGUFFile = filepath.Join("..", "assets", "dummy.gguf") 32 | 
) 33 | 34 | func TestClientPullModel(t *testing.T) { 35 | // Set up test registry 36 | server := httptest.NewServer(registry.New()) 37 | defer server.Close() 38 | registryURL, err := url.Parse(server.URL) 39 | if err != nil { 40 | t.Fatalf("Failed to parse registry URL: %v", err) 41 | } 42 | registry := registryURL.Host 43 | 44 | // Create temp directory for store 45 | tempDir, err := os.MkdirTemp("", "model-distribution-test-*") 46 | if err != nil { 47 | t.Fatalf("Failed to create temp directory: %v", err) 48 | } 49 | defer os.RemoveAll(tempDir) 50 | 51 | // Create client 52 | client, err := NewClient(WithStoreRootPath(tempDir)) 53 | if err != nil { 54 | t.Fatalf("Failed to create client: %v", err) 55 | } 56 | 57 | // Read model content for verification later 58 | modelContent, err := os.ReadFile(testGGUFFile) 59 | if err != nil { 60 | t.Fatalf("Failed to read test model file: %v", err) 61 | } 62 | 63 | model, err := gguf.NewModel(testGGUFFile) 64 | if err != nil { 65 | t.Fatalf("Failed to create model: %v", err) 66 | } 67 | tag := registry + "/testmodel:v1.0.0" 68 | ref, err := name.ParseReference(tag) 69 | if err != nil { 70 | t.Fatalf("Failed to parse reference: %v", err) 71 | } 72 | if err := remote.Write(ref, model); err != nil { 73 | t.Fatalf("Failed to push model: %v", err) 74 | } 75 | 76 | t.Run("pull without progress writer", func(t *testing.T) { 77 | // Pull model from registry without progress writer 78 | err := client.PullModel(context.Background(), tag, nil) 79 | if err != nil { 80 | t.Fatalf("Failed to pull model: %v", err) 81 | } 82 | 83 | model, err := client.GetModel(tag) 84 | if err != nil { 85 | t.Fatalf("Failed to get model: %v", err) 86 | } 87 | 88 | modelPath, err := model.GGUFPath() 89 | if err != nil { 90 | t.Fatalf("Failed to get model path: %v", err) 91 | } 92 | // Verify model content 93 | pulledContent, err := os.ReadFile(modelPath) 94 | if err != nil { 95 | t.Fatalf("Failed to read pulled model: %v", err) 96 | } 97 | 98 | if string(pulledContent) != string(modelContent) { 99 | t.Errorf("Pulled model content doesn't match original: got %q, want %q", pulledContent, modelContent) 100 | } 101 | }) 102 | 103 | t.Run("pull with progress writer", func(t *testing.T) { 104 | // Create a buffer to capture progress output 105 | var progressBuffer bytes.Buffer 106 | 107 | // Pull model from registry with progress writer 108 | if err := client.PullModel(context.Background(), tag, &progressBuffer); err != nil { 109 | t.Fatalf("Failed to pull model: %v", err) 110 | } 111 | 112 | // Verify progress output 113 | progressOutput := progressBuffer.String() 114 | if !strings.Contains(progressOutput, "Using cached model") && !strings.Contains(progressOutput, "Downloading") { 115 | t.Errorf("Progress output doesn't contain expected text: got %q", progressOutput) 116 | } 117 | 118 | model, err := client.GetModel(tag) 119 | if err != nil { 120 | t.Fatalf("Failed to get model: %v", err) 121 | } 122 | 123 | modelPath, err := model.GGUFPath() 124 | if err != nil { 125 | t.Fatalf("Failed to get model path: %v", err) 126 | } 127 | 128 | // Verify model content 129 | pulledContent, err := os.ReadFile(modelPath) 130 | if err != nil { 131 | t.Fatalf("Failed to read pulled model: %v", err) 132 | } 133 | 134 | if string(pulledContent) != string(modelContent) { 135 | t.Errorf("Pulled model content doesn't match original: got %q, want %q", pulledContent, modelContent) 136 | } 137 | }) 138 | 139 | t.Run("pull non-existent model", func(t *testing.T) { 140 | // Create temp directory for store 141 | 
tempDir, err := os.MkdirTemp("", "model-distribution-test-*") 142 | if err != nil { 143 | t.Fatalf("Failed to create temp directory: %v", err) 144 | } 145 | defer os.RemoveAll(tempDir) 146 | 147 | // Create client 148 | client, err := NewClient(WithStoreRootPath(tempDir)) 149 | if err != nil { 150 | t.Fatalf("Failed to create client: %v", err) 151 | } 152 | 153 | // Create a buffer to capture progress output 154 | var progressBuffer bytes.Buffer 155 | 156 | // Test with non-existent repository 157 | nonExistentRef := registry + "/nonexistent/model:v1.0.0" 158 | err = client.PullModel(context.Background(), nonExistentRef, &progressBuffer) 159 | if err == nil { 160 | t.Fatal("Expected error for non-existent model, got nil") 161 | } 162 | 163 | // Verify it's a registry.Error 164 | var pullErr *mdregistry.Error 165 | ok := errors.As(err, &pullErr) 166 | if !ok { 167 | t.Fatalf("Expected PullError, got %T", err) 168 | } 169 | 170 | // Verify error fields 171 | if pullErr.Reference != nonExistentRef { 172 | t.Errorf("Expected reference %q, got %q", nonExistentRef, pullErr.Reference) 173 | } 174 | if pullErr.Code != "NAME_UNKNOWN" { 175 | t.Errorf("Expected error code NAME_UNKNOWN, got %q", pullErr.Code) 176 | } 177 | if pullErr.Message != "Repository not found" { 178 | t.Errorf("Expected message %q, got %q", "Repository not found", pullErr.Message) 179 | } 180 | if pullErr.Err == nil { 181 | t.Error("Expected underlying error to be non-nil") 182 | } 183 | if !errors.Is(pullErr, mdregistry.ErrModelNotFound) { 184 | t.Errorf("Expected underlying error to match ErrModelNotFound, got %v", pullErr.Err) 185 | } 186 | }) 187 | 188 | t.Run("pull with incomplete files", func(t *testing.T) { 189 | // Create temp directory for store 190 | tempDir, err := os.MkdirTemp("", "model-distribution-incomplete-test-*") 191 | if err != nil { 192 | t.Fatalf("Failed to create temp directory: %v", err) 193 | } 194 | defer os.RemoveAll(tempDir) 195 | 196 | // Create client 197 | client, err := NewClient(WithStoreRootPath(tempDir)) 198 | if err != nil { 199 | t.Fatalf("Failed to create client: %v", err) 200 | } 201 | 202 | // Use the dummy.gguf file from assets directory 203 | mdl, err := gguf.NewModel(testGGUFFile) 204 | if err != nil { 205 | t.Fatalf("Failed to create model: %v", err) 206 | } 207 | 208 | // Push model to local store 209 | tag := registry + "/incomplete-test/model:v1.0.0" 210 | if err := client.store.Write(mdl, []string{tag}, nil); err != nil { 211 | t.Fatalf("Failed to push model to store: %v", err) 212 | } 213 | 214 | // Push model to registry 215 | if err := client.PushModel(context.Background(), tag, nil); err != nil { 216 | t.Fatalf("Failed to push model: %v", err) 217 | } 218 | 219 | // Get the model to find the GGUF path 220 | model, err := client.GetModel(tag) 221 | if err != nil { 222 | t.Fatalf("Failed to get model: %v", err) 223 | } 224 | 225 | ggufPath, err := model.GGUFPath() 226 | if err != nil { 227 | t.Fatalf("Failed to get GGUF path: %v", err) 228 | } 229 | 230 | // Create an incomplete file by copying the GGUF file and adding .incomplete suffix 231 | incompletePath := ggufPath + ".incomplete" 232 | originalContent, err := os.ReadFile(ggufPath) 233 | if err != nil { 234 | t.Fatalf("Failed to read GGUF file: %v", err) 235 | } 236 | 237 | // Write partial content to simulate an incomplete download 238 | partialContent := originalContent[:len(originalContent)/2] 239 | if err := os.WriteFile(incompletePath, partialContent, 0644); err != nil { 240 | t.Fatalf("Failed to create incomplete 
file: %v", err) 241 | } 242 | 243 | // Verify the incomplete file exists 244 | if _, err := os.Stat(incompletePath); os.IsNotExist(err) { 245 | t.Fatalf("Failed to create incomplete file: %v", err) 246 | } 247 | 248 | // Delete the local model to force a pull 249 | if err := client.DeleteModel(tag, false); err != nil { 250 | t.Fatalf("Failed to delete model: %v", err) 251 | } 252 | 253 | // Create a buffer to capture progress output 254 | var progressBuffer bytes.Buffer 255 | 256 | // Pull the model again - this should detect the incomplete file and pull again 257 | if err := client.PullModel(context.Background(), tag, &progressBuffer); err != nil { 258 | t.Fatalf("Failed to pull model: %v", err) 259 | } 260 | 261 | // Verify progress output indicates a new download, not using cached model 262 | progressOutput := progressBuffer.String() 263 | if strings.Contains(progressOutput, "Using cached model") { 264 | t.Errorf("Expected to pull model again due to incomplete file, but used cached model") 265 | } 266 | 267 | // Verify the incomplete file no longer exists 268 | if _, err := os.Stat(incompletePath); !os.IsNotExist(err) { 269 | t.Errorf("Incomplete file still exists after successful pull: %s", incompletePath) 270 | } 271 | 272 | // Verify the complete file exists 273 | if _, err := os.Stat(ggufPath); os.IsNotExist(err) { 274 | t.Errorf("GGUF file doesn't exist after pull: %s", ggufPath) 275 | } 276 | 277 | // Verify the content of the pulled file matches the original 278 | pulledContent, err := os.ReadFile(ggufPath) 279 | if err != nil { 280 | t.Fatalf("Failed to read pulled GGUF file: %v", err) 281 | } 282 | 283 | if !bytes.Equal(pulledContent, originalContent) { 284 | t.Errorf("Pulled content doesn't match original content") 285 | } 286 | }) 287 | 288 | t.Run("pull updated model with same tag", func(t *testing.T) { 289 | // Create temp directory for store 290 | tempDir, err := os.MkdirTemp("", "model-distribution-update-test-*") 291 | if err != nil { 292 | t.Fatalf("Failed to create temp directory: %v", err) 293 | } 294 | defer os.RemoveAll(tempDir) 295 | 296 | // Create client 297 | client, err := NewClient(WithStoreRootPath(tempDir)) 298 | if err != nil { 299 | t.Fatalf("Failed to create client: %v", err) 300 | } 301 | 302 | // Read model content for verification later 303 | modelContent, err := os.ReadFile(testGGUFFile) 304 | if err != nil { 305 | t.Fatalf("Failed to read test model file: %v", err) 306 | } 307 | 308 | // Push first version of model to registry 309 | tag := registry + "/update-test:v1.0.0" 310 | if err := writeToRegistry(testGGUFFile, tag); err != nil { 311 | t.Fatalf("Failed to push first version of model: %v", err) 312 | } 313 | 314 | // Pull first version of model 315 | if err := client.PullModel(context.Background(), tag, nil); err != nil { 316 | t.Fatalf("Failed to pull first version of model: %v", err) 317 | } 318 | 319 | // Verify first version is in local store 320 | model, err := client.GetModel(tag) 321 | if err != nil { 322 | t.Fatalf("Failed to get first version of model: %v", err) 323 | } 324 | 325 | modelPath, err := model.GGUFPath() 326 | if err != nil { 327 | t.Fatalf("Failed to get model path: %v", err) 328 | } 329 | 330 | // Verify first version content 331 | pulledContent, err := os.ReadFile(modelPath) 332 | if err != nil { 333 | t.Fatalf("Failed to read pulled model: %v", err) 334 | } 335 | 336 | if string(pulledContent) != string(modelContent) { 337 | t.Errorf("Pulled model content doesn't match original: got %q, want %q", pulledContent, 
modelContent) 338 | } 339 | 340 | // Create a modified version of the model 341 | updatedModelFile := filepath.Join(tempDir, "updated-dummy.gguf") 342 | updatedContent := append(modelContent, []byte("UPDATED CONTENT")...) 343 | if err := os.WriteFile(updatedModelFile, updatedContent, 0644); err != nil { 344 | t.Fatalf("Failed to create updated model file: %v", err) 345 | } 346 | 347 | // Push updated model with same tag 348 | if err := writeToRegistry(updatedModelFile, tag); err != nil { 349 | t.Fatalf("Failed to push updated model: %v", err) 350 | } 351 | 352 | // Create a buffer to capture progress output 353 | var progressBuffer bytes.Buffer 354 | 355 | // Pull model again - should get the updated version 356 | if err := client.PullModel(context.Background(), tag, &progressBuffer); err != nil { 357 | t.Fatalf("Failed to pull updated model: %v", err) 358 | } 359 | 360 | // Verify progress output indicates a new download, not using cached model 361 | progressOutput := progressBuffer.String() 362 | if strings.Contains(progressOutput, "Using cached model") { 363 | t.Errorf("Expected to pull updated model, but used cached model") 364 | } 365 | 366 | // Get the model again to verify it's the updated version 367 | updatedModel, err := client.GetModel(tag) 368 | if err != nil { 369 | t.Fatalf("Failed to get updated model: %v", err) 370 | } 371 | 372 | updatedModelPath, err := updatedModel.GGUFPath() 373 | if err != nil { 374 | t.Fatalf("Failed to get updated model path: %v", err) 375 | } 376 | 377 | // Verify updated content 378 | updatedPulledContent, err := os.ReadFile(updatedModelPath) 379 | if err != nil { 380 | t.Fatalf("Failed to read updated pulled model: %v", err) 381 | } 382 | 383 | if string(updatedPulledContent) != string(updatedContent) { 384 | t.Errorf("Updated pulled model content doesn't match: got %q, want %q", updatedPulledContent, updatedContent) 385 | } 386 | }) 387 | 388 | t.Run("pull unsupported (newer) version", func(t *testing.T) { 389 | newMdl := mutate.ConfigMediaType(model, "application/vnd.docker.ai.model.config.v0.2+json") 390 | // Push model to local store 391 | tag := registry + "/unsupported-test/model:v1.0.0" 392 | ref, err := name.ParseReference(tag) 393 | if err != nil { 394 | t.Fatalf("Failed to parse reference: %v", err) 395 | } 396 | if err := remote.Write(ref, newMdl); err != nil { 397 | t.Fatalf("Failed to push model: %v", err) 398 | } 399 | if err := client.PullModel(context.Background(), tag, nil); err == nil || !errors.Is(err, ErrUnsupportedMediaType) { 400 | t.Fatalf("Expected artifact version error, got %v", err) 401 | } 402 | }) 403 | 404 | t.Run("pull with JSON progress messages", func(t *testing.T) { 405 | // Create temp directory for store 406 | tempDir, err := os.MkdirTemp("", "model-distribution-json-test-*") 407 | if err != nil { 408 | t.Fatalf("Failed to create temp directory: %v", err) 409 | } 410 | defer os.RemoveAll(tempDir) 411 | 412 | // Create client 413 | client, err := NewClient(WithStoreRootPath(tempDir)) 414 | if err != nil { 415 | t.Fatalf("Failed to create client: %v", err) 416 | } 417 | 418 | // Create a buffer to capture progress output 419 | var progressBuffer bytes.Buffer 420 | 421 | // Pull model from registry with progress writer 422 | if err := client.PullModel(context.Background(), tag, &progressBuffer); err != nil { 423 | t.Fatalf("Failed to pull model: %v", err) 424 | } 425 | 426 | // Parse progress output as JSON 427 | var messages []progress.Message 428 | scanner := bufio.NewScanner(&progressBuffer) 429 | for 
scanner.Scan() { 430 | line := scanner.Text() 431 | var msg progress.Message 432 | if err := json.Unmarshal([]byte(line), &msg); err != nil { 433 | t.Fatalf("Failed to parse JSON progress message: %v, line: %s", err, line) 434 | } 435 | messages = append(messages, msg) 436 | } 437 | 438 | if err := scanner.Err(); err != nil { 439 | t.Fatalf("Error reading progress output: %v", err) 440 | } 441 | 442 | // Verify we got some messages 443 | if len(messages) == 0 { 444 | t.Fatal("No progress messages received") 445 | } 446 | 447 | // Check the last message is a success message 448 | lastMsg := messages[len(messages)-1] 449 | if lastMsg.Type != "success" { 450 | t.Errorf("Expected last message to be success, got type: %s, message: %s", lastMsg.Type, lastMsg.Message) 451 | } 452 | 453 | // Verify model was pulled correctly 454 | model, err := client.GetModel(tag) 455 | if err != nil { 456 | t.Fatalf("Failed to get model: %v", err) 457 | } 458 | 459 | modelPath, err := model.GGUFPath() 460 | if err != nil { 461 | t.Fatalf("Failed to get model path: %v", err) 462 | } 463 | 464 | // Verify model content 465 | pulledContent, err := os.ReadFile(modelPath) 466 | if err != nil { 467 | t.Fatalf("Failed to read pulled model: %v", err) 468 | } 469 | 470 | if string(pulledContent) != string(modelContent) { 471 | t.Errorf("Pulled model content doesn't match original") 472 | } 473 | }) 474 | 475 | t.Run("pull with error and JSON progress messages", func(t *testing.T) { 476 | // Create temp directory for store 477 | tempDir, err := os.MkdirTemp("", "model-distribution-json-error-test-*") 478 | if err != nil { 479 | t.Fatalf("Failed to create temp directory: %v", err) 480 | } 481 | defer os.RemoveAll(tempDir) 482 | 483 | // Create client 484 | client, err := NewClient(WithStoreRootPath(tempDir)) 485 | if err != nil { 486 | t.Fatalf("Failed to create client: %v", err) 487 | } 488 | 489 | // Create a buffer to capture progress output 490 | var progressBuffer bytes.Buffer 491 | 492 | // Test with non-existent model 493 | nonExistentRef := registry + "/nonexistent/model:v1.0.0" 494 | err = client.PullModel(context.Background(), nonExistentRef, &progressBuffer) 495 | 496 | // Expect an error 497 | if err == nil { 498 | t.Fatal("Expected error for non-existent model, got nil") 499 | } 500 | 501 | // Verify it matches registry.ErrModelNotFound 502 | if !errors.Is(err, mdregistry.ErrModelNotFound) { 503 | t.Fatalf("Expected registry.ErrModelNotFound, got %T", err) 504 | } 505 | 506 | // No JSON messages should be in the buffer for this error case 507 | // since the error happens before we start streaming progress 508 | }) 509 | } 510 | 511 | func TestClientGetModel(t *testing.T) { 512 | // Create temp directory for store 513 | tempDir, err := os.MkdirTemp("", "model-distribution-test-*") 514 | if err != nil { 515 | t.Fatalf("Failed to create temp directory: %v", err) 516 | } 517 | defer os.RemoveAll(tempDir) 518 | 519 | // Create client 520 | client, err := NewClient(WithStoreRootPath(tempDir)) 521 | if err != nil { 522 | t.Fatalf("Failed to create client: %v", err) 523 | } 524 | 525 | // Create model from test GGUF file 526 | model, err := gguf.NewModel(testGGUFFile) 527 | if err != nil { 528 | t.Fatalf("Failed to create model: %v", err) 529 | } 530 | 531 | // Push model to local store 532 | tag := "test/model:v1.0.0" 533 | if err := client.store.Write(model, []string{tag}, nil); err != nil { 534 | t.Fatalf("Failed to push model to store: %v", err) 535 | } 536 | 537 | // Get model 538 | mi, err := 
client.GetModel(tag) 539 | if err != nil { 540 | t.Fatalf("Failed to get model: %v", err) 541 | } 542 | 543 | // Verify model 544 | if len(mi.Tags()) == 0 || mi.Tags()[0] != tag { 545 | t.Errorf("Model tags don't match: got %v, want [%s]", mi.Tags(), tag) 546 | } 547 | } 548 | 549 | func TestClientGetModelNotFound(t *testing.T) { 550 | // Create temp directory for store 551 | tempDir, err := os.MkdirTemp("", "model-distribution-test-*") 552 | if err != nil { 553 | t.Fatalf("Failed to create temp directory: %v", err) 554 | } 555 | defer os.RemoveAll(tempDir) 556 | 557 | // Create client 558 | client, err := NewClient(WithStoreRootPath(tempDir)) 559 | if err != nil { 560 | t.Fatalf("Failed to create client: %v", err) 561 | } 562 | 563 | // Get non-existent model 564 | _, err = client.GetModel("nonexistent/model:v1.0.0") 565 | if !errors.Is(err, ErrModelNotFound) { 566 | t.Errorf("Expected ErrModelNotFound, got %v", err) 567 | } 568 | } 569 | 570 | func TestClientListModels(t *testing.T) { 571 | // Create temp directory for store 572 | tempDir, err := os.MkdirTemp("", "model-distribution-test-*") 573 | if err != nil { 574 | t.Fatalf("Failed to create temp directory: %v", err) 575 | } 576 | defer os.RemoveAll(tempDir) 577 | 578 | // Create client 579 | client, err := NewClient(WithStoreRootPath(tempDir)) 580 | if err != nil { 581 | t.Fatalf("Failed to create client: %v", err) 582 | } 583 | 584 | // Create test model file 585 | modelContent := []byte("test model content") 586 | modelFile := filepath.Join(tempDir, "test-model.gguf") 587 | if err := os.WriteFile(modelFile, modelContent, 0644); err != nil { 588 | t.Fatalf("Failed to write test model file: %v", err) 589 | } 590 | 591 | mdl, err := gguf.NewModel(modelFile) 592 | if err != nil { 593 | t.Fatalf("Failed to create model: %v", err) 594 | } 595 | 596 | // Push models to local store with different manifest digests 597 | // First model 598 | tag1 := "test/model1:v1.0.0" 599 | if err := client.store.Write(mdl, []string{tag1}, nil); err != nil { 600 | t.Fatalf("Failed to push model to store: %v", err) 601 | } 602 | 603 | // Create a slightly different model file for the second model 604 | modelContent2 := []byte("test model content 2") 605 | modelFile2 := filepath.Join(tempDir, "test-model2.gguf") 606 | if err := os.WriteFile(modelFile2, modelContent2, 0644); err != nil { 607 | t.Fatalf("Failed to write test model file: %v", err) 608 | } 609 | mdl2, err := gguf.NewModel(modelFile2) 610 | if err != nil { 611 | t.Fatalf("Failed to create model: %v", err) 612 | } 613 | 614 | // Second model 615 | tag2 := "test/model2:v1.0.0" 616 | if err := client.store.Write(mdl2, []string{tag2}, nil); err != nil { 617 | t.Fatalf("Failed to push model to store: %v", err) 618 | } 619 | 620 | // Tags for verification 621 | tags := []string{tag1, tag2} 622 | 623 | // List models 624 | models, err := client.ListModels() 625 | if err != nil { 626 | t.Fatalf("Failed to list models: %v", err) 627 | } 628 | 629 | // Verify models 630 | if len(models) != len(tags) { 631 | t.Errorf("Expected %d models, got %d", len(tags), len(models)) 632 | } 633 | 634 | // Check if all tags are present 635 | tagMap := make(map[string]bool) 636 | for _, model := range models { 637 | for _, tag := range model.Tags() { 638 | tagMap[tag] = true 639 | } 640 | } 641 | 642 | for _, tag := range tags { 643 | if !tagMap[tag] { 644 | t.Errorf("Tag %s not found in models", tag) 645 | } 646 | } 647 | } 648 | 649 | func TestClientGetStorePath(t *testing.T) { 650 | // Create temp directory for store 
651 | tempDir, err := os.MkdirTemp("", "model-distribution-test-*") 652 | if err != nil { 653 | t.Fatalf("Failed to create temp directory: %v", err) 654 | } 655 | defer os.RemoveAll(tempDir) 656 | 657 | // Create client 658 | client, err := NewClient(WithStoreRootPath(tempDir)) 659 | if err != nil { 660 | t.Fatalf("Failed to create client: %v", err) 661 | } 662 | 663 | // Get store path 664 | storePath := client.GetStorePath() 665 | 666 | // Verify store path matches the temp directory 667 | if storePath != tempDir { 668 | t.Errorf("Store path doesn't match: got %s, want %s", storePath, tempDir) 669 | } 670 | 671 | // Verify the store directory exists 672 | if _, err := os.Stat(storePath); os.IsNotExist(err) { 673 | t.Errorf("Store directory does not exist: %s", storePath) 674 | } 675 | } 676 | 677 | func TestClientDefaultLogger(t *testing.T) { 678 | // Create temp directory for store 679 | tempDir, err := os.MkdirTemp("", "model-distribution-test-*") 680 | if err != nil { 681 | t.Fatalf("Failed to create temp directory: %v", err) 682 | } 683 | defer os.RemoveAll(tempDir) 684 | 685 | // Create client without specifying logger 686 | client, err := NewClient(WithStoreRootPath(tempDir)) 687 | if err != nil { 688 | t.Fatalf("Failed to create client: %v", err) 689 | } 690 | 691 | // Verify that logger is not nil 692 | if client.log == nil { 693 | t.Error("Default logger should not be nil") 694 | } 695 | 696 | // Create client with custom logger 697 | customLogger := logrus.NewEntry(logrus.New()) 698 | client, err = NewClient( 699 | WithStoreRootPath(tempDir), 700 | WithLogger(customLogger), 701 | ) 702 | if err != nil { 703 | t.Fatalf("Failed to create client: %v", err) 704 | } 705 | 706 | // Verify that custom logger is used 707 | if client.log != customLogger { 708 | t.Error("Custom logger should be used when specified") 709 | } 710 | } 711 | 712 | func TestWithFunctionsNilChecks(t *testing.T) { 713 | // Create temp directory for store 714 | tempDir, err := os.MkdirTemp("", "model-distribution-test-*") 715 | if err != nil { 716 | t.Fatalf("Failed to create temp directory: %v", err) 717 | } 718 | defer os.RemoveAll(tempDir) 719 | 720 | // Test WithStoreRootPath with empty string 721 | t.Run("WithStoreRootPath empty string", func(t *testing.T) { 722 | // Create options with a valid path first 723 | opts := defaultOptions() 724 | WithStoreRootPath(tempDir)(opts) 725 | 726 | // Then try to override with empty string 727 | WithStoreRootPath("")(opts) 728 | 729 | // Verify the path wasn't changed to empty 730 | if opts.storeRootPath != tempDir { 731 | t.Errorf("WithStoreRootPath with empty string changed the path: got %q, want %q", 732 | opts.storeRootPath, tempDir) 733 | } 734 | }) 735 | 736 | // Test WithLogger with nil 737 | t.Run("WithLogger nil", func(t *testing.T) { 738 | // Create options with default logger 739 | opts := defaultOptions() 740 | defaultLogger := opts.logger 741 | 742 | // Try to override with nil 743 | WithLogger(nil)(opts) 744 | 745 | // Verify the logger wasn't changed to nil 746 | if opts.logger == nil { 747 | t.Error("WithLogger with nil changed logger to nil") 748 | } 749 | 750 | // Verify it's still the default logger 751 | if opts.logger != defaultLogger { 752 | t.Error("WithLogger with nil changed the logger") 753 | } 754 | }) 755 | 756 | // Test WithTransport with nil 757 | t.Run("WithTransport nil", func(t *testing.T) { 758 | // Create options with default transport 759 | opts := defaultOptions() 760 | defaultTransport := opts.transport 761 | 762 | // Try to 
override with nil 763 | WithTransport(nil)(opts) 764 | 765 | // Verify the transport wasn't changed to nil 766 | if opts.transport == nil { 767 | t.Error("WithTransport with nil changed transport to nil") 768 | } 769 | 770 | // Verify it's still the default transport 771 | if opts.transport != defaultTransport { 772 | t.Error("WithTransport with nil changed the transport") 773 | } 774 | }) 775 | 776 | // Test WithUserAgent with empty string 777 | t.Run("WithUserAgent empty string", func(t *testing.T) { 778 | // Create options with default user agent 779 | opts := defaultOptions() 780 | defaultUA := opts.userAgent 781 | 782 | // Try to override with empty string 783 | WithUserAgent("")(opts) 784 | 785 | // Verify the user agent wasn't changed to empty 786 | if opts.userAgent == "" { 787 | t.Error("WithUserAgent with empty string changed user agent to empty") 788 | } 789 | 790 | // Verify it's still the default user agent 791 | if opts.userAgent != defaultUA { 792 | t.Errorf("WithUserAgent with empty string changed the user agent: got %q, want %q", 793 | opts.userAgent, defaultUA) 794 | } 795 | }) 796 | } 797 | 798 | func TestNewReferenceError(t *testing.T) { 799 | // Create temp directory for store 800 | tempDir, err := os.MkdirTemp("", "model-distribution-test-*") 801 | if err != nil { 802 | t.Fatalf("Failed to create temp directory: %v", err) 803 | } 804 | defer os.RemoveAll(tempDir) 805 | 806 | // Create client 807 | client, err := NewClient(WithStoreRootPath(tempDir)) 808 | if err != nil { 809 | t.Fatalf("Failed to create client: %v", err) 810 | } 811 | 812 | // Test with invalid reference 813 | invalidRef := "invalid:reference:format" 814 | err = client.PullModel(context.Background(), invalidRef, nil) 815 | if err == nil { 816 | t.Fatal("Expected error for invalid reference, got nil") 817 | } 818 | 819 | if !errors.Is(err, ErrInvalidReference) { 820 | t.Fatalf("Expected error to match sentinel invalid reference error, got %v", err) 821 | } 822 | } 823 | 824 | func TestPush(t *testing.T) { 825 | // Create temp directory for store 826 | tempDir, err := os.MkdirTemp("", "model-distribution-test-*") 827 | if err != nil { 828 | t.Fatalf("Failed to create temp directory: %v", err) 829 | } 830 | defer os.RemoveAll(tempDir) 831 | 832 | // Create client 833 | client, err := NewClient(WithStoreRootPath(tempDir)) 834 | if err != nil { 835 | t.Fatalf("Failed to create client: %v", err) 836 | } 837 | 838 | // Create a test registry 839 | server := httptest.NewServer(registry.New()) 840 | defer server.Close() 841 | 842 | // Create a tag for the model 843 | uri, err := url.Parse(server.URL) 844 | if err != nil { 845 | t.Fatalf("Failed to parse registry URL: %v", err) 846 | } 847 | tag := uri.Host + "/incomplete-test/model:v1.0.0" 848 | 849 | // Write a test model to the store with the given tag 850 | mdl, err := gguf.NewModel(testGGUFFile) 851 | if err != nil { 852 | t.Fatalf("Failed to create model: %v", err) 853 | } 854 | digest, err := mdl.ID() 855 | if err != nil { 856 | t.Fatalf("Failed to get digest of original model: %v", err) 857 | } 858 | 859 | if err := client.store.Write(mdl, []string{tag}, nil); err != nil { 860 | t.Fatalf("Failed to push model to store: %v", err) 861 | } 862 | 863 | // Push the model to the registry 864 | if err := client.PushModel(context.Background(), tag, nil); err != nil { 865 | t.Fatalf("Failed to push model: %v", err) 866 | } 867 | 868 | // Delete local copy (so we can test pulling) 869 | if err := client.DeleteModel(tag, false); err != nil { 870 | t.Fatalf("Failed 
to delete model: %v", err) 871 | } 872 | 873 | // Test that model can be pulled successfully 874 | if err := client.PullModel(context.Background(), tag, nil); err != nil { 875 | t.Fatalf("Failed to pull model: %v", err) 876 | } 877 | 878 | // Test that the pulled model is the same as the original (matching digests) 879 | mdl2, err := client.GetModel(tag) 880 | if err != nil { 881 | t.Fatalf("Failed to get pulled model: %v", err) 882 | } 883 | digest2, err := mdl2.ID() 884 | if err != nil { 885 | t.Fatalf("Failed to get digest of the pulled model: %v", err) 886 | } 887 | if digest != digest2 { 888 | t.Fatalf("Digests don't match: got %s, want %s", digest2, digest) 889 | } 890 | } 891 | 892 | func TestPushProgress(t *testing.T) { 893 | // Create temp directory for store 894 | tempDir, err := os.MkdirTemp("", "model-distribution-test-*") 895 | if err != nil { 896 | t.Fatalf("Failed to create temp directory: %v", err) 897 | } 898 | defer os.RemoveAll(tempDir) 899 | 900 | // Create client 901 | client, err := NewClient(WithStoreRootPath(tempDir)) 902 | if err != nil { 903 | t.Fatalf("Failed to create client: %v", err) 904 | } 905 | 906 | // Create a test registry 907 | server := httptest.NewServer(registry.New()) 908 | defer server.Close() 909 | 910 | // Create a tag for the model 911 | uri, err := url.Parse(server.URL) 912 | if err != nil { 913 | t.Fatalf("Failed to parse registry URL: %v", err) 914 | } 915 | tag := uri.Host + "/some/model/repo:some-tag" 916 | 917 | // Create random "model" of a given size - make it large enough to ensure multiple updates 918 | // We want at least 2MB to ensure we get both time-based and byte-based updates 919 | sz := int64(progress.MinBytesForUpdate * 2) 920 | path, err := randomFile(sz) 921 | if err != nil { 922 | t.Fatalf("Failed to create temp file: %v", err) 923 | } 924 | defer os.Remove(path) 925 | 926 | mdl, err := gguf.NewModel(path) 927 | if err != nil { 928 | t.Fatalf("Failed to create model: %v", err) 929 | } 930 | 931 | if err := client.store.Write(mdl, []string{tag}, nil); err != nil { 932 | t.Fatalf("Failed to write model to store: %v", err) 933 | } 934 | 935 | // Create a pipe to capture progress output 936 | pr, pw := io.Pipe() 937 | done := make(chan error, 1) 938 | go func() { 939 | defer pw.Close() 940 | done <- client.PushModel(t.Context(), tag, pw) 941 | close(done) 942 | }() 943 | 944 | var lines []string 945 | sc := bufio.NewScanner(pr) 946 | for sc.Scan() { 947 | line := sc.Text() 948 | t.Log(line) 949 | lines = append(lines, line) 950 | } 951 | 952 | // Wait for the push to complete 953 | if err := <-done; err != nil { 954 | t.Fatalf("Failed to push model: %v", err) 955 | } 956 | 957 | // Verify we got at least 3 messages (2 progress + 1 success) 958 | if len(lines) < 3 { 959 | t.Fatalf("Expected at least 3 progress messages, got %d", len(lines)) 960 | } 961 | 962 | // Verify the last two messages 963 | lastTwo := lines[len(lines)-2:] 964 | if !strings.Contains(lastTwo[0], "Uploaded:") { 965 | t.Fatalf("Expected progress message to contain 'Uploaded: x MB', got %q", lastTwo[0]) 966 | } 967 | if !strings.Contains(lastTwo[1], "success") { 968 | t.Fatalf("Expected last progress message to contain 'success', got %q", lastTwo[1]) 969 | } 970 | } 971 | 972 | func TestTag(t *testing.T) { 973 | // Create temp directory for store 974 | tempDir, err := os.MkdirTemp("", "model-distribution-test-*") 975 | if err != nil { 976 | t.Fatalf("Failed to create temp directory: %v", err) 977 | } 978 | defer os.RemoveAll(tempDir) 979 | 980 | // 
Create client 981 | client, err := NewClient(WithStoreRootPath(tempDir)) 982 | if err != nil { 983 | t.Fatalf("Failed to create client: %v", err) 984 | } 985 | 986 | // Create a test model 987 | model, err := gguf.NewModel(testGGUFFile) 988 | if err != nil { 989 | t.Fatalf("Failed to create model: %v", err) 990 | } 991 | id, err := model.ID() 992 | if err != nil { 993 | t.Fatalf("Failed to get model ID: %v", err) 994 | } 995 | 996 | // Push the model to the store 997 | if err := client.store.Write(model, []string{"some-repo:some-tag"}, nil); err != nil { 998 | t.Fatalf("Failed to push model to store: %v", err) 999 | } 1000 | 1001 | // Tag the model by ID 1002 | if err := client.Tag(id, "other-repo:tag1"); err != nil { 1003 | t.Fatalf("Failed to tag model %q: %v", id, err) 1004 | } 1005 | 1006 | // Tag the model by tag 1007 | if err := client.Tag(id, "other-repo:tag2"); err != nil { 1008 | t.Fatalf("Failed to tag model %q: %v", id, err) 1009 | } 1010 | 1011 | // Verify the model has all 3 tags 1012 | modelInfo, err := client.GetModel("some-repo:some-tag") 1013 | if err != nil { 1014 | t.Fatalf("Failed to get model: %v", err) 1015 | } 1016 | 1017 | if len(modelInfo.Tags()) != 3 { 1018 | t.Fatalf("Expected 3 tags, got %d", len(modelInfo.Tags())) 1019 | } 1020 | 1021 | // Verify the model can be accessed by new tags 1022 | if _, err := client.GetModel("other-repo:tag1"); err != nil { 1023 | t.Fatalf("Failed to get model by tag: %v", err) 1024 | } 1025 | if _, err := client.GetModel("other-repo:tag2"); err != nil { 1026 | t.Fatalf("Failed to get model by tag: %v", err) 1027 | } 1028 | } 1029 | 1030 | func TestTagNotFound(t *testing.T) { 1031 | // Create temp directory for store 1032 | tempDir, err := os.MkdirTemp("", "model-distribution-test-*") 1033 | if err != nil { 1034 | t.Fatalf("Failed to create temp directory: %v", err) 1035 | } 1036 | defer os.RemoveAll(tempDir) 1037 | 1038 | // Create client 1039 | client, err := NewClient(WithStoreRootPath(tempDir)) 1040 | if err != nil { 1041 | t.Fatalf("Failed to create client: %v", err) 1042 | } 1043 | 1044 | // Tag the model by ID 1045 | if err := client.Tag("non-existent-model:latest", "other-repo:tag1"); !errors.Is(err, ErrModelNotFound) { 1046 | t.Fatalf("Expected ErrModelNotFound, got: %v", err) 1047 | } 1048 | } 1049 | 1050 | func TestClientPushModelNotFound(t *testing.T) { 1051 | // Create temp directory for store 1052 | tempDir, err := os.MkdirTemp("", "model-distribution-test-*") 1053 | if err != nil { 1054 | t.Fatalf("Failed to create temp directory: %v", err) 1055 | } 1056 | defer os.RemoveAll(tempDir) 1057 | 1058 | // Create client 1059 | client, err := NewClient(WithStoreRootPath(tempDir)) 1060 | if err != nil { 1061 | t.Fatalf("Failed to create client: %v", err) 1062 | } 1063 | 1064 | if err := client.PushModel(t.Context(), "non-existent-model:latest", nil); !errors.Is(err, ErrModelNotFound) { 1065 | t.Fatalf("Expected ErrModelNotFound got: %v", err) 1066 | } 1067 | } 1068 | 1069 | // writeToRegistry writes a GGUF model to a registry. 
1070 | func writeToRegistry(source, reference string) error { 1071 | 1072 | // Parse the reference 1073 | ref, err := name.ParseReference(reference) 1074 | if err != nil { 1075 | return fmt.Errorf("parse ref: %w", err) 1076 | } 1077 | 1078 | // Create image with layer 1079 | mdl, err := gguf.NewModel(source) 1080 | if err != nil { 1081 | return fmt.Errorf("new model: %w", err) 1082 | } 1083 | 1084 | // Push the image 1085 | if err := remote.Write(ref, mdl); err != nil { 1086 | return fmt.Errorf("write: %w", err) 1087 | } 1088 | 1089 | return nil 1090 | } 1091 | 1092 | func randomFile(size int64) (string, error) { 1093 | // Create a temporary "gguf" file 1094 | f, err := os.CreateTemp("", "test-*.gguf") 1095 | if err != nil { 1096 | return "", fmt.Errorf("create temp file: %w", err) 1097 | } 1098 | defer f.Close() 1099 | 1100 | // Fill with random data 1101 | if _, err := io.Copy(f, io.LimitReader(rand.Reader, size)); err != nil { 1102 | return "", fmt.Errorf("write random data: %w", err) 1103 | } 1104 | 1105 | return f.Name(), nil 1106 | } 1107 | -------------------------------------------------------------------------------- /distribution/delete_test.go: -------------------------------------------------------------------------------- 1 | package distribution 2 | 3 | import ( 4 | "errors" 5 | "os" 6 | "testing" 7 | 8 | "github.com/docker/model-distribution/internal/gguf" 9 | ) 10 | 11 | func TestDeleteModel(t *testing.T) { 12 | // Create temp directory for store 13 | tempDir, err := os.MkdirTemp("", "model-distribution-test-*") 14 | if err != nil { 15 | t.Fatalf("Failed to create temp directory: %v", err) 16 | } 17 | defer os.RemoveAll(tempDir) 18 | 19 | // Create client 20 | client, err := NewClient(WithStoreRootPath(tempDir)) 21 | if err != nil { 22 | t.Fatalf("Failed to create client: %v", err) 23 | } 24 | 25 | // Use the dummy.gguf file from assets directory 26 | mdl, err := gguf.NewModel(testGGUFFile) 27 | if err != nil { 28 | t.Fatalf("Failed to create model: %v", err) 29 | } 30 | id, err := mdl.ID() 31 | if err != nil { 32 | t.Fatalf("Failed to get model ID: %v", err) 33 | } 34 | if err := client.store.Write(mdl, []string{}, nil); err != nil { 35 | t.Fatalf("Failed to write model to store: %v", err) 36 | } 37 | 38 | type testCase struct { 39 | ref string // ref to delete by (id or tag) 40 | tags []string // applied tags 41 | force bool 42 | expectedErr error 43 | description string 44 | untagOnly bool 45 | } 46 | 47 | tcs := []testCase{ 48 | { 49 | ref: id, 50 | description: "untagged, by ID", 51 | }, 52 | { 53 | ref: id, 54 | force: true, 55 | description: "untagged, by ID, with force", 56 | }, 57 | { 58 | ref: id, 59 | tags: []string{"some-repo:some-tag"}, 60 | description: "one tag, by ID", 61 | }, 62 | { 63 | ref: "some-repo:some-tag", 64 | tags: []string{"some-repo:some-tag"}, 65 | description: "one tag, by tag", 66 | }, 67 | { 68 | ref: id, 69 | tags: []string{"some-repo:some-tag"}, 70 | force: true, 71 | description: "one tag, by ID, with force", 72 | }, 73 | { 74 | ref: "some-repo:some-tag", 75 | tags: []string{"some-repo:some-tag"}, 76 | force: true, 77 | description: "one tag, by tag, with force", 78 | }, 79 | { 80 | ref: id, 81 | tags: []string{"some-repo:some-tag", "other-repo:other-tag"}, 82 | expectedErr: ErrConflict, 83 | description: "multiple tags, by ID", 84 | }, 85 | { 86 | ref: id, 87 | tags: []string{"some-repo:some-tag", "other-repo:other-tag"}, 88 | force: true, 89 | description: "multiple tags, by ID, with force", 90 | }, 91 | { 92 | ref: 
"some-repo:some-tag", 93 | tags: []string{"some-repo:some-tag", "other-repo:other-tag"}, 94 | untagOnly: true, 95 | description: "multiple tags, by tag", 96 | }, 97 | { 98 | ref: "some-repo:some-tag", 99 | tags: []string{"some-repo:some-tag", "other-repo:other-tag"}, 100 | force: true, 101 | untagOnly: true, 102 | description: "multiple tags, by tag, with force", 103 | }, 104 | { 105 | ref: "not-existing:tag", 106 | tags: []string{}, 107 | expectedErr: ErrModelNotFound, 108 | description: "no such model", 109 | }, 110 | } 111 | 112 | for _, tc := range tcs { 113 | t.Run(tc.description, func(t *testing.T) { 114 | // Setup model with tags 115 | if err := client.store.Write(mdl, []string{}, nil); err != nil { 116 | t.Fatalf("Failed to write model to store: %v", err) 117 | } 118 | for _, tag := range tc.tags { 119 | if err := client.Tag(id, tag); err != nil { 120 | t.Fatalf("Failed to tag model: %v", err) 121 | } 122 | } 123 | 124 | // Attempt to delete the model and check for expected error 125 | if err := client.DeleteModel(tc.ref, tc.force); !errors.Is(err, tc.expectedErr) { 126 | t.Fatalf("Expected error %v, got: %v", tc.expectedErr, err) 127 | } 128 | if tc.expectedErr != nil { 129 | return 130 | } 131 | 132 | // Verify model ref unreachable by ref (untagged) 133 | _, err = client.GetModel(tc.ref) 134 | if !errors.Is(err, ErrModelNotFound) { 135 | t.Errorf("Expected ErrModelNotFound after deletion, got %v", err) 136 | } 137 | 138 | // Verify if underlying model is deleted 139 | if _, err = client.GetModel(id); !tc.untagOnly && !errors.Is(err, ErrModelNotFound) { 140 | t.Errorf("Expected ErrModelNotFound after deletion, got %v", err) 141 | } else if tc.untagOnly && err != nil { 142 | t.Errorf("Expected model to remain but was deleted") 143 | } 144 | }) 145 | } 146 | } 147 | -------------------------------------------------------------------------------- /distribution/ecr_test.go: -------------------------------------------------------------------------------- 1 | package distribution 2 | 3 | import ( 4 | "context" 5 | "os" 6 | "testing" 7 | 8 | "github.com/docker/model-distribution/internal/gguf" 9 | ) 10 | 11 | func TestECRIntegration(t *testing.T) { 12 | // Skip if ECR integration is not enabled 13 | if os.Getenv("TEST_ECR_ENABLED") != "true" { 14 | t.Skip("Skipping ECR integration test") 15 | } 16 | 17 | // Get ECR tag from environment 18 | ecrTag := os.Getenv("TEST_ECR_TAG") 19 | if ecrTag == "" { 20 | t.Fatal("TEST_ECR_TAG environment variable is required") 21 | } 22 | 23 | // Create temp directory for store 24 | tempDir, err := os.MkdirTemp("", "model-distribution-ecr-test-*") 25 | if err != nil { 26 | t.Fatalf("Failed to create temp directory: %v", err) 27 | } 28 | defer os.RemoveAll(tempDir) 29 | 30 | // Create client 31 | client, err := NewClient(WithStoreRootPath(tempDir)) 32 | if err != nil { 33 | t.Fatalf("Failed to create client: %v", err) 34 | } 35 | 36 | // Read test model file 37 | modelFile := "../assets/dummy.gguf" 38 | modelContent, err := os.ReadFile(modelFile) 39 | if err != nil { 40 | t.Fatalf("Failed to read test model file: %v", err) 41 | } 42 | 43 | t.Run("Push", func(t *testing.T) { 44 | mdl, err := gguf.NewModel(testGGUFFile) 45 | if err != nil { 46 | t.Fatalf("Failed to create model: %v", err) 47 | } 48 | if err := client.store.Write(mdl, []string{ecrTag}, nil); err != nil { 49 | t.Fatalf("Failed to write model to store: %v", err) 50 | } 51 | if err := client.PushModel(context.Background(), ecrTag, nil); err != nil { 52 | t.Fatalf("Failed to push model to ECR: 
%v", err) 53 | } 54 | if err := client.DeleteModel(ecrTag, false); err != nil { // cleanup 55 | t.Fatalf("Failed to delete model from store: %v", err) 56 | } 57 | }) 58 | 59 | // Test pull from ECR 60 | t.Run("Pull without progress", func(t *testing.T) { 61 | err := client.PullModel(context.Background(), ecrTag, nil) 62 | if err != nil { 63 | t.Fatalf("Failed to pull model from ECR: %v", err) 64 | } 65 | 66 | model, err := client.GetModel(ecrTag) 67 | if err != nil { 68 | t.Fatalf("Failed to get model: %v", err) 69 | } 70 | 71 | modelPath, err := model.GGUFPath() 72 | if err != nil { 73 | t.Fatalf("Failed to get model path: %v", err) 74 | } 75 | 76 | defer os.Remove(modelPath) 77 | 78 | // Verify model content 79 | pulledContent, err := os.ReadFile(modelPath) 80 | if err != nil { 81 | t.Fatalf("Failed to read pulled model: %v", err) 82 | } 83 | 84 | if string(pulledContent) != string(modelContent) { 85 | t.Errorf("Pulled model content doesn't match original: got %q, want %q", pulledContent, modelContent) 86 | } 87 | }) 88 | 89 | // Test get model info 90 | t.Run("GetModel", func(t *testing.T) { 91 | model, err := client.GetModel(ecrTag) 92 | if err != nil { 93 | t.Fatalf("Failed to get model info: %v", err) 94 | } 95 | 96 | if len(model.Tags()) == 0 || model.Tags()[0] != ecrTag { 97 | t.Errorf("Model tags don't match: got %v, want [%s]", model.Tags(), ecrTag) 98 | } 99 | }) 100 | } 101 | -------------------------------------------------------------------------------- /distribution/errors.go: -------------------------------------------------------------------------------- 1 | package distribution 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | 7 | "github.com/docker/model-distribution/internal/store" 8 | "github.com/docker/model-distribution/registry" 9 | "github.com/docker/model-distribution/types" 10 | ) 11 | 12 | var ( 13 | ErrInvalidReference = registry.ErrInvalidReference 14 | ErrModelNotFound = store.ErrModelNotFound // model not found in store 15 | ErrUnsupportedMediaType = errors.New(fmt.Sprintf( 16 | "client supports only models of type %q and older - try upgrading", 17 | types.MediaTypeModelConfigV01, 18 | )) 19 | ErrConflict = errors.New("resource conflict") 20 | ) 21 | 22 | // ReferenceError represents an error related to an invalid model reference 23 | type ReferenceError struct { 24 | Reference string 25 | Err error 26 | } 27 | 28 | func (e *ReferenceError) Error() string { 29 | return fmt.Sprintf("invalid model reference %q: %v", e.Reference, e.Err) 30 | } 31 | 32 | func (e *ReferenceError) Unwrap() error { 33 | return e.Err 34 | } 35 | 36 | // Is implements error matching for ReferenceError 37 | func (e *ReferenceError) Is(target error) bool { 38 | return target == ErrInvalidReference 39 | } 40 | -------------------------------------------------------------------------------- /distribution/gar_test.go: -------------------------------------------------------------------------------- 1 | package distribution 2 | 3 | import ( 4 | "context" 5 | "os" 6 | "testing" 7 | 8 | "github.com/docker/model-distribution/internal/gguf" 9 | ) 10 | 11 | func TestGARIntegration(t *testing.T) { 12 | // Skip if GAR integration is not enabled 13 | if os.Getenv("TEST_GAR_ENABLED") != "true" { 14 | t.Skip("Skipping GAR integration test") 15 | } 16 | 17 | // Get GAR tag from environment 18 | garTag := os.Getenv("TEST_GAR_TAG") 19 | if garTag == "" { 20 | t.Fatal("TEST_GAR_TAG environment variable is required") 21 | } 22 | 23 | // Create temp directory for store 24 | tempDir, err := os.MkdirTemp("", 
"model-distribution-gar-test-*") 25 | if err != nil { 26 | t.Fatalf("Failed to create temp directory: %v", err) 27 | } 28 | defer os.RemoveAll(tempDir) 29 | 30 | // Create client 31 | client, err := NewClient(WithStoreRootPath(tempDir)) 32 | if err != nil { 33 | t.Fatalf("Failed to create client: %v", err) 34 | } 35 | 36 | // Read test model file 37 | modelFile := "../assets/dummy.gguf" 38 | modelContent, err := os.ReadFile(modelFile) 39 | if err != nil { 40 | t.Fatalf("Failed to read test model file: %v", err) 41 | } 42 | 43 | // Test push to GAR 44 | t.Run("Push", func(t *testing.T) { 45 | mdl, err := gguf.NewModel(testGGUFFile) 46 | if err != nil { 47 | t.Fatalf("Failed to create model: %v", err) 48 | } 49 | if err := client.store.Write(mdl, []string{garTag}, nil); err != nil { 50 | t.Fatalf("Failed to write model to store: %v", err) 51 | } 52 | if err := client.PushModel(context.Background(), garTag, nil); err != nil { 53 | t.Fatalf("Failed to push model to ECR: %v", err) 54 | } 55 | if err := client.DeleteModel(garTag, false); err != nil { // cleanup 56 | t.Fatalf("Failed to delete model from store: %v", err) 57 | } 58 | }) 59 | 60 | // Test pull from GAR 61 | t.Run("Pull without progress", func(t *testing.T) { 62 | err := client.PullModel(context.Background(), garTag, nil) 63 | if err != nil { 64 | t.Fatalf("Failed to pull model from GAR: %v", err) 65 | } 66 | 67 | model, err := client.GetModel(garTag) 68 | if err != nil { 69 | t.Fatalf("Failed to get model: %v", err) 70 | } 71 | 72 | modelPath, err := model.GGUFPath() 73 | if err != nil { 74 | t.Fatalf("Failed to get model path: %v", err) 75 | } 76 | defer os.Remove(modelPath) 77 | 78 | // Verify model content 79 | pulledContent, err := os.ReadFile(modelPath) 80 | if err != nil { 81 | t.Fatalf("Failed to read pulled model: %v", err) 82 | } 83 | 84 | if string(pulledContent) != string(modelContent) { 85 | t.Errorf("Pulled model content doesn't match original: got %q, want %q", pulledContent, modelContent) 86 | } 87 | }) 88 | 89 | // Test get model info 90 | t.Run("GetModel", func(t *testing.T) { 91 | model, err := client.GetModel(garTag) 92 | if err != nil { 93 | t.Fatalf("Failed to get model info: %v", err) 94 | } 95 | 96 | if len(model.Tags()) == 0 || model.Tags()[0] != garTag { 97 | t.Errorf("Model tags don't match: got %v, want [%s]", model.Tags(), garTag) 98 | } 99 | }) 100 | } 101 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/docker/model-distribution 2 | 3 | go 1.23.0 4 | 5 | toolchain go1.24.1 6 | 7 | require ( 8 | github.com/google/go-containerregistry v0.20.3 9 | github.com/gpustack/gguf-parser-go v0.14.1 10 | github.com/pkg/errors v0.9.1 11 | github.com/sirupsen/logrus v1.9.3 12 | ) 13 | 14 | require ( 15 | github.com/containerd/stargz-snapshotter/estargz v0.16.3 // indirect 16 | github.com/docker/cli v27.5.0+incompatible // indirect 17 | github.com/docker/distribution v2.8.3+incompatible // indirect 18 | github.com/docker/docker-credential-helpers v0.8.2 // indirect 19 | github.com/google/go-cmp v0.7.0 // indirect 20 | github.com/henvic/httpretty v0.1.4 // indirect 21 | github.com/json-iterator/go v1.1.12 // indirect 22 | github.com/klauspost/compress v1.17.11 // indirect 23 | github.com/mitchellh/go-homedir v1.1.0 // indirect 24 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect 25 | github.com/modern-go/reflect2 v1.0.2 // indirect 26 | 
github.com/opencontainers/go-digest v1.0.0 // indirect 27 | github.com/opencontainers/image-spec v1.1.1 // indirect 28 | github.com/rs/dnscache v0.0.0-20230804202142-fc85eb664529 // indirect 29 | github.com/smallnest/ringbuffer v0.0.0-20241116012123-461381446e3d // indirect 30 | github.com/stretchr/testify v1.10.0 // indirect 31 | github.com/vbatts/tar-split v0.11.6 // indirect 32 | golang.org/x/crypto v0.35.0 // indirect 33 | golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f // indirect 34 | golang.org/x/mod v0.22.0 // indirect 35 | golang.org/x/sync v0.10.0 // indirect 36 | golang.org/x/sys v0.31.0 // indirect 37 | golang.org/x/tools v0.29.0 // indirect 38 | gonum.org/v1/gonum v0.15.1 // indirect 39 | gotest.tools/v3 v3.5.1 // indirect 40 | ) 41 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/containerd/stargz-snapshotter/estargz v0.16.3 h1:7evrXtoh1mSbGj/pfRccTampEyKpjpOnS3CyiV1Ebr8= 2 | github.com/containerd/stargz-snapshotter/estargz v0.16.3/go.mod h1:uyr4BfYfOj3G9WBVE8cOlQmXAbPN9VEQpBBeJIuOipU= 3 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 5 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 6 | github.com/docker/cli v27.5.0+incompatible h1:aMphQkcGtpHixwwhAXJT1rrK/detk2JIvDaFkLctbGM= 7 | github.com/docker/cli v27.5.0+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= 8 | github.com/docker/distribution v2.8.3+incompatible h1:AtKxIZ36LoNK51+Z6RpzLpddBirtxJnzDrHLEKxTAYk= 9 | github.com/docker/distribution v2.8.3+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w= 10 | github.com/docker/docker-credential-helpers v0.8.2 h1:bX3YxiGzFP5sOXWc3bTPEXdEaZSeVMrFgOr3T+zrFAo= 11 | github.com/docker/docker-credential-helpers v0.8.2/go.mod h1:P3ci7E3lwkZg6XiHdRKft1KckHiO9a2rNtyFbZ/ry9M= 12 | github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= 13 | github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= 14 | github.com/google/go-containerregistry v0.20.3 h1:oNx7IdTI936V8CQRveCjaxOiegWwvM7kqkbXTpyiovI= 15 | github.com/google/go-containerregistry v0.20.3/go.mod h1:w00pIgBRDVUDFM6bq+Qx8lwNWK+cxgCuX1vd3PIBDNI= 16 | github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= 17 | github.com/gpustack/gguf-parser-go v0.14.1 h1:tmz2eTnSEFfE52V10FESqo9oAUquZ6JKQFntWC/wrEg= 18 | github.com/gpustack/gguf-parser-go v0.14.1/go.mod h1:GvHh1Kvvq5ojCOsJ5UpwiJJmIjFw3Qk5cW7R+CZ3IJo= 19 | github.com/henvic/httpretty v0.1.4 h1:Jo7uwIRWVFxkqOnErcoYfH90o3ddQyVrSANeS4cxYmU= 20 | github.com/henvic/httpretty v0.1.4/go.mod h1:Dn60sQTZfbt2dYsdUSNsCljyF4AfdqnuJFDLJA1I4AM= 21 | github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= 22 | github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= 23 | github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= 24 | github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= 25 | github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= 26 | github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= 27 | github.com/modern-go/concurrent 
v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= 28 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= 29 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= 30 | github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= 31 | github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= 32 | github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= 33 | github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= 34 | github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= 35 | github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M= 36 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 37 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 38 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 39 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 40 | github.com/rs/dnscache v0.0.0-20230804202142-fc85eb664529 h1:18kd+8ZUlt/ARXhljq+14TwAoKa61q6dX8jtwOf6DH8= 41 | github.com/rs/dnscache v0.0.0-20230804202142-fc85eb664529/go.mod h1:qe5TWALJ8/a1Lqznoc5BDHpYX/8HU60Hm2AwRmqzxqA= 42 | github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= 43 | github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= 44 | github.com/smallnest/ringbuffer v0.0.0-20241116012123-461381446e3d h1:3VwvTjiRPA7cqtgOWddEL+JrcijMlXUmj99c/6YyZoY= 45 | github.com/smallnest/ringbuffer v0.0.0-20241116012123-461381446e3d/go.mod h1:tAG61zBM1DYRaGIPloumExGvScf08oHuo0kFoOqdbT0= 46 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 47 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 48 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 49 | github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= 50 | github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 51 | github.com/vbatts/tar-split v0.11.6 h1:4SjTW5+PU11n6fZenf2IPoV8/tz3AaYHMWjf23envGs= 52 | github.com/vbatts/tar-split v0.11.6/go.mod h1:dqKNtesIOr2j2Qv3W/cHjnvk9I8+G7oAkFDFN6TCBEI= 53 | golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs= 54 | golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ= 55 | golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f h1:XdNn9LlyWAhLVp6P/i8QYBW+hlyhrhei9uErw2B5GJo= 56 | golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f/go.mod h1:D5SMRVC3C2/4+F/DB1wZsLRnSNimn2Sp/NPsCrsv8ak= 57 | golang.org/x/mod v0.22.0 h1:D4nJWe9zXqHOmWqj4VMOJhvzj7bEZg4wEYa759z1pH4= 58 | golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= 59 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 60 | golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= 61 | golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 62 | golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 63 | golang.org/x/sys 
v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= 64 | golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= 65 | golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU= 66 | golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s= 67 | golang.org/x/tools v0.29.0 h1:Xx0h3TtM9rzQpQuR4dKLrdglAmCEN5Oi+P74JdhdzXE= 68 | golang.org/x/tools v0.29.0/go.mod h1:KMQVMRsVxU6nHCFXrBPhDB8XncLNLM0lIy/F14RP588= 69 | gonum.org/v1/gonum v0.15.1 h1:FNy7N6OUZVUaWG9pTiD+jlhdQ3lMP+/LcTpJ6+a8sQ0= 70 | gonum.org/v1/gonum v0.15.1/go.mod h1:eZTZuRFrzu5pcyjN5wJhcIhnUdNijYxX1T2IcrOGY0o= 71 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 72 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 73 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 74 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 75 | gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU= 76 | gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= 77 | -------------------------------------------------------------------------------- /internal/gguf/create.go: -------------------------------------------------------------------------------- 1 | package gguf 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | "time" 7 | 8 | v1 "github.com/google/go-containerregistry/pkg/v1" 9 | parser "github.com/gpustack/gguf-parser-go" 10 | 11 | "github.com/docker/model-distribution/internal/partial" 12 | "github.com/docker/model-distribution/types" 13 | ) 14 | 15 | func NewModel(path string) (*Model, error) { 16 | layer, err := partial.NewLayer(path, types.MediaTypeGGUF) 17 | if err != nil { 18 | return nil, fmt.Errorf("create gguf layer: %w", err) 19 | } 20 | diffID, err := layer.DiffID() 21 | if err != nil { 22 | return nil, fmt.Errorf("get gguf layer diffID: %w", err) 23 | } 24 | 25 | created := time.Now() 26 | return &Model{ 27 | configFile: types.ConfigFile{ 28 | Config: configFromFile(path), 29 | Descriptor: types.Descriptor{ 30 | Created: &created, 31 | }, 32 | RootFS: v1.RootFS{ 33 | Type: "rootfs", 34 | DiffIDs: []v1.Hash{ 35 | diffID, 36 | }, 37 | }, 38 | }, 39 | layers: []v1.Layer{layer}, 40 | }, nil 41 | } 42 | 43 | func configFromFile(path string) types.Config { 44 | gguf, err := parser.ParseGGUFFile(path) 45 | if err != nil { 46 | return types.Config{} // continue without metadata 47 | } 48 | return types.Config{ 49 | Format: types.FormatGGUF, 50 | Parameters: strings.TrimSpace(gguf.Metadata().Parameters.String()), 51 | Architecture: strings.TrimSpace(gguf.Metadata().Architecture), 52 | Quantization: strings.TrimSpace(gguf.Metadata().FileType.String()), 53 | Size: strings.TrimSpace(gguf.Metadata().Size.String()), 54 | GGUF: extractGGUFMetadata(&gguf.Header), 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /internal/gguf/metadata.go: -------------------------------------------------------------------------------- 1 | package gguf 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | parser "github.com/gpustack/gguf-parser-go" 8 | ) 9 | 10 | const maxArraySize = 50 11 | 12 | // extractGGUFMetadata converts the GGUF header metadata into a string map. 
13 | func extractGGUFMetadata(header *parser.GGUFHeader) map[string]string { 14 | metadata := make(map[string]string) 15 | 16 | for _, kv := range header.MetadataKV { 17 | if kv.ValueType == parser.GGUFMetadataValueTypeArray { 18 | arrayValue := kv.ValueArray() 19 | if arrayValue.Len > maxArraySize { 20 | continue 21 | } 22 | } 23 | var value string 24 | switch kv.ValueType { 25 | case parser.GGUFMetadataValueTypeUint8: 26 | value = fmt.Sprintf("%d", kv.ValueUint8()) 27 | case parser.GGUFMetadataValueTypeInt8: 28 | value = fmt.Sprintf("%d", kv.ValueInt8()) 29 | case parser.GGUFMetadataValueTypeUint16: 30 | value = fmt.Sprintf("%d", kv.ValueUint16()) 31 | case parser.GGUFMetadataValueTypeInt16: 32 | value = fmt.Sprintf("%d", kv.ValueInt16()) 33 | case parser.GGUFMetadataValueTypeUint32: 34 | value = fmt.Sprintf("%d", kv.ValueUint32()) 35 | case parser.GGUFMetadataValueTypeInt32: 36 | value = fmt.Sprintf("%d", kv.ValueInt32()) 37 | case parser.GGUFMetadataValueTypeUint64: 38 | value = fmt.Sprintf("%d", kv.ValueUint64()) 39 | case parser.GGUFMetadataValueTypeInt64: 40 | value = fmt.Sprintf("%d", kv.ValueInt64()) 41 | case parser.GGUFMetadataValueTypeFloat32: 42 | value = fmt.Sprintf("%f", kv.ValueFloat32()) 43 | case parser.GGUFMetadataValueTypeFloat64: 44 | value = fmt.Sprintf("%f", kv.ValueFloat64()) 45 | case parser.GGUFMetadataValueTypeBool: 46 | value = fmt.Sprintf("%t", kv.ValueBool()) 47 | case parser.GGUFMetadataValueTypeString: 48 | value = kv.ValueString() 49 | case parser.GGUFMetadataValueTypeArray: 50 | value = handleArray(kv.ValueArray()) 51 | default: 52 | value = fmt.Sprintf("[unknown type %d]", kv.ValueType) 53 | } 54 | metadata[kv.Key] = value 55 | } 56 | 57 | return metadata 58 | } 59 | 60 | // handleArray processes an array value and returns its string representation 61 | func handleArray(arrayValue parser.GGUFMetadataKVArrayValue) string { 62 | var values []string 63 | for _, v := range arrayValue.Array { 64 | switch arrayValue.Type { 65 | case parser.GGUFMetadataValueTypeUint8: 66 | values = append(values, fmt.Sprintf("%d", v.(uint8))) 67 | case parser.GGUFMetadataValueTypeInt8: 68 | values = append(values, fmt.Sprintf("%d", v.(int8))) 69 | case parser.GGUFMetadataValueTypeUint16: 70 | values = append(values, fmt.Sprintf("%d", v.(uint16))) 71 | case parser.GGUFMetadataValueTypeInt16: 72 | values = append(values, fmt.Sprintf("%d", v.(int16))) 73 | case parser.GGUFMetadataValueTypeUint32: 74 | values = append(values, fmt.Sprintf("%d", v.(uint32))) 75 | case parser.GGUFMetadataValueTypeInt32: 76 | values = append(values, fmt.Sprintf("%d", v.(int32))) 77 | case parser.GGUFMetadataValueTypeUint64: 78 | values = append(values, fmt.Sprintf("%d", v.(uint64))) 79 | case parser.GGUFMetadataValueTypeInt64: 80 | values = append(values, fmt.Sprintf("%d", v.(int64))) 81 | case parser.GGUFMetadataValueTypeFloat32: 82 | values = append(values, fmt.Sprintf("%f", v.(float32))) 83 | case parser.GGUFMetadataValueTypeFloat64: 84 | values = append(values, fmt.Sprintf("%f", v.(float64))) 85 | case parser.GGUFMetadataValueTypeBool: 86 | values = append(values, fmt.Sprintf("%t", v.(bool))) 87 | case parser.GGUFMetadataValueTypeString: 88 | values = append(values, v.(string)) 89 | default: 90 | // Do nothing 91 | } 92 | } 93 | 94 | return strings.Join(values, ", ") 95 | } 96 | -------------------------------------------------------------------------------- /internal/gguf/model.go: -------------------------------------------------------------------------------- 1 | package gguf 2 | 3 | import ( 4 | 
"encoding/json" 5 | "fmt" 6 | 7 | v1 "github.com/google/go-containerregistry/pkg/v1" 8 | "github.com/google/go-containerregistry/pkg/v1/partial" 9 | ggcr "github.com/google/go-containerregistry/pkg/v1/types" 10 | 11 | mdpartial "github.com/docker/model-distribution/internal/partial" 12 | "github.com/docker/model-distribution/types" 13 | ) 14 | 15 | var _ types.ModelArtifact = &Model{} 16 | 17 | type Model struct { 18 | configFile types.ConfigFile 19 | layers []v1.Layer 20 | manifest *v1.Manifest 21 | } 22 | 23 | func (m *Model) Layers() ([]v1.Layer, error) { 24 | return m.layers, nil 25 | } 26 | 27 | func (m *Model) Size() (int64, error) { 28 | return partial.Size(m) 29 | } 30 | 31 | func (m *Model) ConfigName() (v1.Hash, error) { 32 | return partial.ConfigName(m) 33 | } 34 | 35 | func (m *Model) ConfigFile() (*v1.ConfigFile, error) { 36 | return nil, fmt.Errorf("invalid for model") 37 | } 38 | 39 | func (m *Model) Digest() (v1.Hash, error) { 40 | return partial.Digest(m) 41 | } 42 | 43 | func (m *Model) Manifest() (*v1.Manifest, error) { 44 | return mdpartial.ManifestForLayers(m) 45 | } 46 | 47 | func (m *Model) LayerByDigest(hash v1.Hash) (v1.Layer, error) { 48 | for _, l := range m.layers { 49 | d, err := l.Digest() 50 | if err != nil { 51 | return nil, fmt.Errorf("get layer digest: %w", err) 52 | } 53 | if d == hash { 54 | return l, nil 55 | } 56 | } 57 | return nil, fmt.Errorf("layer not found") 58 | } 59 | 60 | func (m *Model) LayerByDiffID(hash v1.Hash) (v1.Layer, error) { 61 | for _, l := range m.layers { 62 | d, err := l.DiffID() 63 | if err != nil { 64 | return nil, fmt.Errorf("get layer digest: %w", err) 65 | } 66 | if d == hash { 67 | return l, nil 68 | } 69 | } 70 | return nil, fmt.Errorf("layer not found") 71 | } 72 | 73 | func (m *Model) RawManifest() ([]byte, error) { 74 | return partial.RawManifest(m) 75 | } 76 | 77 | func (m *Model) RawConfigFile() ([]byte, error) { 78 | return json.Marshal(m.configFile) 79 | } 80 | 81 | func (m *Model) MediaType() (ggcr.MediaType, error) { 82 | manifest, err := m.Manifest() 83 | if err != nil { 84 | return "", fmt.Errorf("compute maniest: %w", err) 85 | } 86 | return manifest.MediaType, nil 87 | } 88 | 89 | func (m *Model) ID() (string, error) { 90 | return mdpartial.ID(m) 91 | } 92 | 93 | func (m *Model) Config() (types.Config, error) { 94 | return mdpartial.Config(m) 95 | } 96 | 97 | func (m *Model) Descriptor() (types.Descriptor, error) { 98 | return mdpartial.Descriptor(m) 99 | } 100 | -------------------------------------------------------------------------------- /internal/gguf/model_test.go: -------------------------------------------------------------------------------- 1 | package gguf_test 2 | 3 | import ( 4 | "path/filepath" 5 | "testing" 6 | 7 | "github.com/docker/model-distribution/internal/gguf" 8 | "github.com/docker/model-distribution/types" 9 | ) 10 | 11 | func TestGGUF(t *testing.T) { 12 | t.Run("TestGGUFModel", func(t *testing.T) { 13 | mdl, err := gguf.NewModel(filepath.Join("..", "..", "assets", "dummy.gguf")) 14 | if err != nil { 15 | t.Fatalf("Failed to create model: %v", err) 16 | } 17 | 18 | t.Run("TestConfig", func(t *testing.T) { 19 | cfg, err := mdl.Config() 20 | if err != nil { 21 | t.Fatalf("Failed to get config: %v", err) 22 | } 23 | if cfg.Format != types.FormatGGUF { 24 | t.Fatalf("Unexpected format: got %s expected %s", cfg.Format, types.FormatGGUF) 25 | } 26 | if cfg.Parameters != "183" { 27 | t.Fatalf("Unexpected parameters: got %s expected %s", cfg.Parameters, "183") 28 | } 29 | if cfg.Architecture != 
"llama" { 30 | t.Fatalf("Unexpected architecture: got %s expected %s", cfg.Parameters, "llama") 31 | } 32 | if cfg.Quantization != "Unknown" { // todo: testdata with a real value 33 | t.Fatalf("Unexpected quantization: got %s expected %s", cfg.Quantization, "Unknown") 34 | } 35 | if cfg.Size != "864 B" { 36 | t.Fatalf("Unexpected quantization: got %s expected %s", cfg.Quantization, "Unknown") 37 | } 38 | 39 | // Test GGUF metadata 40 | if cfg.GGUF == nil { 41 | t.Fatal("Expected GGUF metadata to be present") 42 | } 43 | // Verify all expected metadata fields from the example https://github.com/ggml-org/llama.cpp/blob/44cd8d91ff2c9e4a0f2e3151f8d6f04c928e2571/examples/gguf/gguf.cpp#L24 44 | expectedParams := map[string]string{ 45 | "some.parameter.uint8": "18", // 0x12 46 | "some.parameter.int8": "-19", // -0x13 47 | "some.parameter.uint16": "4660", // 0x1234 48 | "some.parameter.int16": "-4661", // -0x1235 49 | "some.parameter.uint32": "305419896", // 0x12345678 50 | "some.parameter.int32": "-305419897", // -0x12345679 51 | "some.parameter.float32": "0.123457", // 0.123456789f 52 | "some.parameter.uint64": "1311768467463790320", // 0x123456789abcdef0 53 | "some.parameter.int64": "-1311768467463790321", // -0x123456789abcdef1 54 | "some.parameter.float64": "0.123457", // 0.1234567890123456789 55 | "some.parameter.bool": "true", 56 | "some.parameter.string": "hello world", 57 | "some.parameter.arr.i16": "1, 2, 3, 4", 58 | } 59 | 60 | for key, expectedValue := range expectedParams { 61 | actualValue, ok := cfg.GGUF[key] 62 | if !ok { 63 | t.Errorf("Expected key '%s' in GGUF metadata", key) 64 | continue 65 | } 66 | if actualValue != expectedValue { 67 | t.Errorf("For key '%s': expected value '%s', got '%s'", key, expectedValue, actualValue) 68 | } 69 | } 70 | }) 71 | 72 | t.Run("TestDescriptor", func(t *testing.T) { 73 | desc, err := mdl.Descriptor() 74 | if err != nil { 75 | t.Fatalf("Failed to get config: %v", err) 76 | } 77 | if desc.Created == nil { 78 | t.Fatal("Expected created time to be set: got ni") 79 | } 80 | }) 81 | 82 | t.Run("TestManifest", func(t *testing.T) { 83 | manifest, err := mdl.Manifest() 84 | if err != nil { 85 | t.Fatalf("Failed to get config: %v", err) 86 | } 87 | if len(manifest.Layers) != 1 { 88 | t.Fatalf("Expected 1 layer, got %d", len(manifest.Layers)) 89 | } 90 | if manifest.Layers[0].MediaType != types.MediaTypeGGUF { 91 | t.Fatalf("Expected layer with media type %s, got %s", types.MediaTypeGGUF, manifest.Layers[0].MediaType) 92 | } 93 | }) 94 | }) 95 | } 96 | -------------------------------------------------------------------------------- /internal/mutate/model.go: -------------------------------------------------------------------------------- 1 | package mutate 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | 7 | v1 "github.com/google/go-containerregistry/pkg/v1" 8 | ggcrpartial "github.com/google/go-containerregistry/pkg/v1/partial" 9 | ggcr "github.com/google/go-containerregistry/pkg/v1/types" 10 | 11 | "github.com/docker/model-distribution/internal/partial" 12 | "github.com/docker/model-distribution/types" 13 | ) 14 | 15 | type model struct { 16 | base types.ModelArtifact 17 | appended []v1.Layer 18 | configMediaType ggcr.MediaType 19 | } 20 | 21 | func (m *model) Descriptor() (types.Descriptor, error) { 22 | return partial.Descriptor(m.base) 23 | } 24 | 25 | func (m *model) ID() (string, error) { 26 | return partial.ID(m) 27 | } 28 | 29 | func (m *model) Config() (types.Config, error) { 30 | return partial.Config(m) 31 | } 32 | 33 | func (m *model) 
MediaType() (ggcr.MediaType, error) { 34 | manifest, err := m.Manifest() 35 | if err != nil { 36 | return "", fmt.Errorf("compute maniest: %w", err) 37 | } 38 | return manifest.MediaType, nil 39 | } 40 | 41 | func (m *model) Size() (int64, error) { 42 | return ggcrpartial.Size(m) 43 | } 44 | 45 | func (m *model) ConfigName() (v1.Hash, error) { 46 | return ggcrpartial.ConfigName(m) 47 | } 48 | 49 | func (m *model) ConfigFile() (*v1.ConfigFile, error) { 50 | return nil, fmt.Errorf("invalid for model") 51 | } 52 | 53 | func (m *model) Digest() (v1.Hash, error) { 54 | return ggcrpartial.Digest(m) 55 | } 56 | 57 | func (m *model) RawManifest() ([]byte, error) { 58 | return ggcrpartial.RawManifest(m) 59 | } 60 | 61 | func (m *model) LayerByDigest(hash v1.Hash) (v1.Layer, error) { 62 | ls, err := m.Layers() 63 | if err != nil { 64 | return nil, err 65 | } 66 | for _, l := range ls { 67 | d, err := l.Digest() 68 | if err != nil { 69 | return nil, fmt.Errorf("get layer digest: %w", err) 70 | } 71 | if d == hash { 72 | return l, nil 73 | } 74 | } 75 | return nil, fmt.Errorf("layer not found") 76 | } 77 | 78 | func (m *model) LayerByDiffID(hash v1.Hash) (v1.Layer, error) { 79 | ls, err := m.Layers() 80 | if err != nil { 81 | return nil, err 82 | } 83 | for _, l := range ls { 84 | d, err := l.Digest() 85 | if err != nil { 86 | return nil, fmt.Errorf("get layer digest: %w", err) 87 | } 88 | if d == hash { 89 | return l, nil 90 | } 91 | } 92 | return nil, fmt.Errorf("layer not found") 93 | } 94 | 95 | func (m *model) Layers() ([]v1.Layer, error) { 96 | ls, err := m.base.Layers() 97 | if err != nil { 98 | return nil, err 99 | } 100 | return append(ls, m.appended...), nil 101 | } 102 | 103 | func (m *model) Manifest() (*v1.Manifest, error) { 104 | manifest, err := partial.ManifestForLayers(m) 105 | if err != nil { 106 | return nil, err 107 | } 108 | if m.configMediaType != "" { 109 | manifest.Config.MediaType = m.configMediaType 110 | } 111 | return manifest, nil 112 | } 113 | 114 | func (m *model) RawConfigFile() ([]byte, error) { 115 | cf, err := partial.ConfigFile(m.base) 116 | if err != nil { 117 | return nil, err 118 | } 119 | for _, l := range m.appended { 120 | diffID, err := l.DiffID() 121 | if err != nil { 122 | return nil, err 123 | } 124 | cf.RootFS.DiffIDs = append(cf.RootFS.DiffIDs, diffID) 125 | } 126 | raw, err := json.Marshal(cf) 127 | if err != nil { 128 | return nil, err 129 | } 130 | return raw, err 131 | } 132 | -------------------------------------------------------------------------------- /internal/mutate/mutate.go: -------------------------------------------------------------------------------- 1 | package mutate 2 | 3 | import ( 4 | v1 "github.com/google/go-containerregistry/pkg/v1" 5 | ggcr "github.com/google/go-containerregistry/pkg/v1/types" 6 | 7 | "github.com/docker/model-distribution/types" 8 | ) 9 | 10 | func AppendLayers(mdl types.ModelArtifact, layers ...v1.Layer) types.ModelArtifact { 11 | return &model{ 12 | base: mdl, 13 | appended: layers, 14 | } 15 | } 16 | 17 | func ConfigMediaType(mdl types.ModelArtifact, mt ggcr.MediaType) types.ModelArtifact { 18 | return &model{ 19 | base: mdl, 20 | configMediaType: mt, 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /internal/mutate/mutate_test.go: -------------------------------------------------------------------------------- 1 | package mutate_test 2 | 3 | import ( 4 | "encoding/json" 5 | "path/filepath" 6 | "testing" 7 | 8 | 
"github.com/google/go-containerregistry/pkg/v1/static" 9 | ggcr "github.com/google/go-containerregistry/pkg/v1/types" 10 | 11 | "github.com/docker/model-distribution/internal/gguf" 12 | "github.com/docker/model-distribution/internal/mutate" 13 | "github.com/docker/model-distribution/types" 14 | ) 15 | 16 | func TestAppendLayer(t *testing.T) { 17 | mdl1, err := gguf.NewModel(filepath.Join("..", "..", "assets", "dummy.gguf")) 18 | if err != nil { 19 | t.Fatalf("Failed to create model: %v", err) 20 | } 21 | manifest1, err := mdl1.Manifest() 22 | if err != nil { 23 | t.Fatalf("Failed to create model: %v", err) 24 | } 25 | if len(manifest1.Layers) != 1 { // begin with one layer 26 | t.Fatalf("Expected 1 layer, got %d", len(manifest1.Layers)) 27 | } 28 | 29 | // Append a layer 30 | mdl2 := mutate.AppendLayers(mdl1, 31 | static.NewLayer([]byte("some layer content"), "application/vnd.example.some.media.type"), 32 | ) 33 | if err != nil { 34 | t.Fatalf("Failed to create layer: %v", err) 35 | } 36 | if mdl2 == nil { 37 | t.Fatal("Expected non-nil model") 38 | } 39 | 40 | // Check the manifest 41 | manifest2, err := mdl2.Manifest() 42 | if err != nil { 43 | t.Fatalf("Failed to create model: %v", err) 44 | } 45 | if len(manifest2.Layers) != 2 { // begin with one layer 46 | t.Fatalf("Expected 2 layers, got %d", len(manifest1.Layers)) 47 | } 48 | 49 | // Check the config file 50 | rawCfg, err := mdl2.RawConfigFile() 51 | if err != nil { 52 | t.Fatalf("Failed to get raw config file: %v", err) 53 | } 54 | var cfg types.ConfigFile 55 | if err := json.Unmarshal(rawCfg, &cfg); err != nil { 56 | t.Fatalf("Failed to unmarshal config file: %v", err) 57 | } 58 | if len(cfg.RootFS.DiffIDs) != 2 { 59 | t.Fatalf("Expected 2 diff ids in rootfs, got %d", len(cfg.RootFS.DiffIDs)) 60 | } 61 | } 62 | 63 | func TestConfigMediaTypes(t *testing.T) { 64 | mdl1, err := gguf.NewModel(filepath.Join("..", "..", "assets", "dummy.gguf")) 65 | if err != nil { 66 | t.Fatalf("Failed to create model: %v", err) 67 | } 68 | manifest1, err := mdl1.Manifest() 69 | if err != nil { 70 | t.Fatalf("Failed to create model: %v", err) 71 | } 72 | if manifest1.Config.MediaType != types.MediaTypeModelConfigV01 { 73 | t.Fatalf("Expected media type %s, got %s", types.MediaTypeModelConfigV01, manifest1.Config.MediaType) 74 | } 75 | 76 | newMediaType := ggcr.MediaType("application/vnd.example.other.type") 77 | mdl2 := mutate.ConfigMediaType(mdl1, newMediaType) 78 | manifest2, err := mdl2.Manifest() 79 | if err != nil { 80 | t.Fatalf("Failed to create model: %v", err) 81 | } 82 | if manifest2.Config.MediaType != newMediaType { 83 | t.Fatalf("Expected media type %s, got %s", newMediaType, manifest2.Config.MediaType) 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /internal/partial/layer.go: -------------------------------------------------------------------------------- 1 | package partial 2 | 3 | import ( 4 | "io" 5 | "os" 6 | 7 | "github.com/google/go-containerregistry/pkg/v1" 8 | ggcrtypes "github.com/google/go-containerregistry/pkg/v1/types" 9 | ) 10 | 11 | var _ v1.Layer = &Layer{} 12 | 13 | type Layer struct { 14 | Path string 15 | v1.Descriptor 16 | } 17 | 18 | func NewLayer(path string, mt ggcrtypes.MediaType) (*Layer, error) { 19 | f, err := os.Open(path) 20 | if err != nil { 21 | return nil, err 22 | } 23 | defer f.Close() 24 | hash, size, err := v1.SHA256(f) 25 | return &Layer{ 26 | Path: path, 27 | Descriptor: v1.Descriptor{ 28 | Size: size, 29 | Digest: hash, 30 | MediaType: mt, 31 | }, 
32 | }, nil 33 | } 34 | 35 | func (l Layer) Digest() (v1.Hash, error) { 36 | return l.DiffID() 37 | } 38 | 39 | func (l Layer) DiffID() (v1.Hash, error) { 40 | return l.Descriptor.Digest, nil 41 | } 42 | 43 | func (l Layer) Compressed() (io.ReadCloser, error) { 44 | return l.Uncompressed() 45 | } 46 | 47 | func (l Layer) Uncompressed() (io.ReadCloser, error) { 48 | return os.Open(l.Path) 49 | } 50 | 51 | func (l Layer) Size() (int64, error) { 52 | return l.Descriptor.Size, nil 53 | } 54 | 55 | func (l Layer) MediaType() (ggcrtypes.MediaType, error) { 56 | return l.Descriptor.MediaType, nil 57 | } 58 | -------------------------------------------------------------------------------- /internal/partial/partial.go: -------------------------------------------------------------------------------- 1 | package partial 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | 7 | v1 "github.com/google/go-containerregistry/pkg/v1" 8 | "github.com/google/go-containerregistry/pkg/v1/partial" 9 | ggcr "github.com/google/go-containerregistry/pkg/v1/types" 10 | "github.com/pkg/errors" 11 | 12 | "github.com/docker/model-distribution/types" 13 | ) 14 | 15 | type WithRawConfigFile interface { 16 | // RawConfigFile returns the serialized bytes of this model's config file. 17 | RawConfigFile() ([]byte, error) 18 | } 19 | 20 | func ConfigFile(i WithRawConfigFile) (*types.ConfigFile, error) { 21 | raw, err := i.RawConfigFile() 22 | if err != nil { 23 | return nil, fmt.Errorf("get raw config file: %w", err) 24 | } 25 | var cf types.ConfigFile 26 | if err := json.Unmarshal(raw, &cf); err != nil { 27 | return nil, fmt.Errorf("unmarshal : %w", err) 28 | } 29 | return &cf, nil 30 | } 31 | 32 | // Config returns the types.Config for the model. 33 | func Config(i WithRawConfigFile) (types.Config, error) { 34 | cf, err := ConfigFile(i) 35 | if err != nil { 36 | return types.Config{}, fmt.Errorf("config file: %w", err) 37 | } 38 | return cf.Config, nil 39 | } 40 | 41 | // Descriptor returns the types.Descriptor for the model. 42 | func Descriptor(i WithRawConfigFile) (types.Descriptor, error) { 43 | cf, err := ConfigFile(i) 44 | if err != nil { 45 | return types.Descriptor{}, fmt.Errorf("config file: %w", err) 46 | } 47 | return cf.Descriptor, nil 48 | } 49 | 50 | // WithRawManifest defines the subset of types.Model used by these helper methods 51 | type WithRawManifest interface { 52 | // RawManifest returns the serialized bytes of this model's manifest file. 
53 | RawManifest() ([]byte, error) 54 | } 55 | 56 | func ID(i WithRawManifest) (string, error) { 57 | digest, err := partial.Digest(i) 58 | if err != nil { 59 | return "", fmt.Errorf("get digest: %w", err) 60 | } 61 | return digest.String(), nil 62 | } 63 | 64 | type WithLayers interface { 65 | WithRawConfigFile 66 | Layers() ([]v1.Layer, error) 67 | } 68 | 69 | func GGUFPath(i WithLayers) (string, error) { 70 | layers, err := i.Layers() 71 | if err != nil { 72 | return "", fmt.Errorf("get layers: %w", err) 73 | } 74 | for _, l := range layers { 75 | mt, err := l.MediaType() 76 | if err != nil || mt != types.MediaTypeGGUF { 77 | continue 78 | } 79 | ggufLayer, ok := l.(*Layer) 80 | if !ok { 81 | return "", errors.New("gguf Layer is not available locally") 82 | } 83 | return ggufLayer.Path, nil 84 | } 85 | return "", errors.New("model does not contain a GGUF layer") 86 | } 87 | 88 | func ManifestForLayers(i WithLayers) (*v1.Manifest, error) { 89 | cfgLayer, err := partial.ConfigLayer(i) 90 | if err != nil { 91 | return nil, fmt.Errorf("get raw config file: %w", err) 92 | } 93 | cfgDsc, err := partial.Descriptor(cfgLayer) 94 | if err != nil { 95 | return nil, fmt.Errorf("get config descriptor: %w", err) 96 | } 97 | cfgDsc.MediaType = types.MediaTypeModelConfigV01 98 | 99 | ls, err := i.Layers() 100 | if err != nil { 101 | return nil, fmt.Errorf("get layers: %w", err) 102 | } 103 | 104 | var layers []v1.Descriptor 105 | for _, l := range ls { 106 | desc, err := partial.Descriptor(l) 107 | if err != nil { 108 | return nil, fmt.Errorf("get layer descriptor: %w", err) 109 | } 110 | layers = append(layers, *desc) 111 | } 112 | 113 | return &v1.Manifest{ 114 | SchemaVersion: 2, 115 | MediaType: ggcr.OCIManifestSchema1, 116 | Config: *cfgDsc, 117 | Layers: layers, 118 | }, nil 119 | } 120 | -------------------------------------------------------------------------------- /internal/progress/reporter.go: -------------------------------------------------------------------------------- 1 | package progress 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "io" 7 | "time" 8 | 9 | v1 "github.com/google/go-containerregistry/pkg/v1" 10 | ) 11 | 12 | // UpdateInterval defines how often progress updates should be sent 13 | const UpdateInterval = 100 * time.Millisecond 14 | 15 | // MinBytesForUpdate defines the minimum number of bytes that need to be transferred 16 | // before sending a progress update 17 | const MinBytesForUpdate = 1024 * 1024 // 1MB 18 | 19 | type Layer struct { 20 | ID string // Layer ID 21 | Size uint64 // Layer size 22 | Current uint64 // Current bytes transferred 23 | } 24 | 25 | // Message represents a structured message for progress reporting 26 | type Message struct { 27 | Type string `json:"type"` // "progress", "success", or "error" 28 | Message string `json:"message"` // Human-readable message 29 | Total uint64 `json:"total"` // Deprecated: use Layer.Size 30 | Pulled uint64 `json:"pulled"` // Deprecated: use Layer.Current 31 | Layer Layer `json:"layer,omitempty"` // Current layer information 32 | } 33 | 34 | type Reporter struct { 35 | progress chan v1.Update 36 | done chan struct{} 37 | err error 38 | out io.Writer 39 | format progressF 40 | layer v1.Layer 41 | TotalLayers int // Total number of layers 42 | } 43 | 44 | type progressF func(update v1.Update) string 45 | 46 | func PullMsg(update v1.Update) string { 47 | return fmt.Sprintf("Downloaded: %.2f MB", float64(update.Complete)/1024/1024) 48 | } 49 | 50 | func PushMsg(update v1.Update) string { 51 | return 
fmt.Sprintf("Uploaded: %.2f MB", float64(update.Complete)/1024/1024) 52 | } 53 | 54 | func NewProgressReporter(w io.Writer, msgF progressF, layer v1.Layer) *Reporter { 55 | return &Reporter{ 56 | out: w, 57 | progress: make(chan v1.Update, 1), 58 | done: make(chan struct{}), 59 | format: msgF, 60 | layer: layer, 61 | } 62 | } 63 | 64 | // safeUint64 converts an int64 to uint64, ensuring the value is non-negative 65 | func safeUint64(n int64) uint64 { 66 | if n < 0 { 67 | return 0 68 | } 69 | return uint64(n) 70 | } 71 | 72 | // Updates returns a channel for receiving progress Updates. It is the responsibility of the caller to close 73 | // the channel when they are done sending Updates. Should only be called once per Reporter instance. 74 | func (r *Reporter) Updates() chan<- v1.Update { 75 | go func() { 76 | var lastComplete int64 77 | var lastUpdate time.Time 78 | 79 | for p := range r.progress { 80 | if r.out == nil || r.err != nil { 81 | continue // If we fail to write progress, don't try again 82 | } 83 | now := time.Now() 84 | var total int64 85 | var layerID string 86 | if r.layer != nil { // In case of Push there is no layer yet 87 | id, err := r.layer.DiffID() 88 | if err != nil { 89 | r.err = err 90 | continue 91 | } 92 | layerID = id.String() 93 | size, err := r.layer.Size() 94 | if err != nil { 95 | r.err = err 96 | continue 97 | } 98 | total = size 99 | } else { 100 | total = p.Total 101 | } 102 | incrementalBytes := p.Complete - lastComplete 103 | 104 | // Only update if enough time has passed or enough bytes downloaded or finished 105 | if now.Sub(lastUpdate) >= UpdateInterval || 106 | incrementalBytes >= MinBytesForUpdate { 107 | if err := WriteProgress(r.out, r.format(p), safeUint64(total), safeUint64(p.Complete), layerID); err != nil { 108 | r.err = err 109 | } 110 | lastUpdate = now 111 | lastComplete = p.Complete 112 | } 113 | } 114 | close(r.done) // Close the done channel when progress is complete 115 | }() 116 | return r.progress 117 | } 118 | 119 | // Wait waits for the progress Reporter to finish and returns any error encountered. 
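// Example (illustrative sketch, not part of this file): a typical caller wires the
// Reporter around a transfer by sending v1.Update values on the channel returned by
// Updates, closing it when the transfer is done, and then calling Wait. The byte
// counts and the use of os.Stdout below are made up for illustration; passing a nil
// layer matches the push case handled above.
//
//	reporter := NewProgressReporter(os.Stdout, PushMsg, nil)
//	updates := reporter.Updates()
//	for sent := int64(0); sent <= 5*MinBytesForUpdate; sent += MinBytesForUpdate {
//		updates <- v1.Update{Complete: sent, Total: 5 * MinBytesForUpdate}
//	}
//	close(updates) // the caller closes the channel once the transfer finishes
//	if err := reporter.Wait(); err != nil {
//		log.Printf("progress reporting failed: %v", err)
//	}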
120 | func (r *Reporter) Wait() error { 121 | <-r.done 122 | return r.err 123 | } 124 | 125 | // WriteProgress writes a progress update message 126 | func WriteProgress(w io.Writer, msg string, total, current uint64, layerID string) error { 127 | return write(w, Message{ 128 | Type: "progress", 129 | Message: msg, 130 | Total: total, 131 | Pulled: current, 132 | Layer: Layer{ 133 | ID: layerID, 134 | Size: total, 135 | Current: current, 136 | }, 137 | }) 138 | } 139 | 140 | // WriteSuccess writes a success message 141 | func WriteSuccess(w io.Writer, message string) error { 142 | return write(w, Message{ 143 | Type: "success", 144 | Message: message, 145 | }) 146 | } 147 | 148 | // WriteError writes an error message 149 | func WriteError(w io.Writer, message string) error { 150 | return write(w, Message{ 151 | Type: "error", 152 | Message: message, 153 | }) 154 | } 155 | 156 | // write writes a JSON-formatted progress message to the writer 157 | func write(w io.Writer, msg Message) error { 158 | if w == nil { 159 | return nil 160 | } 161 | data, err := json.Marshal(msg) 162 | if err != nil { 163 | return err 164 | } 165 | _, err = fmt.Fprintf(w, "%s\n", data) 166 | return err 167 | } 168 | -------------------------------------------------------------------------------- /internal/progress/reporter_test.go: -------------------------------------------------------------------------------- 1 | package progress 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "io" 7 | "testing" 8 | "time" 9 | 10 | v1 "github.com/google/go-containerregistry/pkg/v1" 11 | v1types "github.com/google/go-containerregistry/pkg/v1/types" 12 | 13 | "github.com/docker/model-distribution/types" 14 | ) 15 | 16 | // mockLayer implements v1.Layer for testing 17 | type mockLayer struct { 18 | size int64 19 | diffID string 20 | mediaType v1types.MediaType 21 | } 22 | 23 | func (m *mockLayer) Digest() (v1.Hash, error) { 24 | return v1.Hash{}, nil 25 | } 26 | 27 | func (m *mockLayer) DiffID() (v1.Hash, error) { 28 | return v1.NewHash(m.diffID) 29 | } 30 | 31 | func (m *mockLayer) Compressed() (io.ReadCloser, error) { 32 | return nil, nil 33 | } 34 | 35 | func (m *mockLayer) Uncompressed() (io.ReadCloser, error) { 36 | return nil, nil 37 | } 38 | 39 | func (m *mockLayer) Size() (int64, error) { 40 | return m.size, nil 41 | } 42 | 43 | func (m *mockLayer) MediaType() (v1types.MediaType, error) { 44 | return m.mediaType, nil 45 | } 46 | 47 | func newMockLayer(size int64) *mockLayer { 48 | return &mockLayer{ 49 | size: size, 50 | diffID: "sha256:c7790a0a70161f1bfd441cf157313e9efb8fcd1f0831193101def035ead23b32", 51 | mediaType: types.MediaTypeGGUF, 52 | } 53 | } 54 | 55 | func TestMessages(t *testing.T) { 56 | t.Run("writeProgress", func(t *testing.T) { 57 | var buf bytes.Buffer 58 | update := v1.Update{ 59 | Complete: 1024 * 1024, 60 | } 61 | layer := newMockLayer(2016) 62 | size := layer.size 63 | 64 | err := WriteProgress(&buf, PullMsg(update), uint64(size), uint64(update.Complete), layer.diffID) 65 | if err != nil { 66 | t.Fatalf("Failed to write progress message: %v", err) 67 | } 68 | 69 | var msg Message 70 | if err := json.Unmarshal(buf.Bytes(), &msg); err != nil { 71 | t.Fatalf("Failed to parse JSON: %v", err) 72 | } 73 | 74 | if msg.Type != "progress" { 75 | t.Errorf("Expected type 'progress', got '%s'", msg.Type) 76 | } 77 | if msg.Message != "Downloaded: 1.00 MB" { 78 | t.Errorf("Expected message 'Downloaded: 1.00 MB', got '%s'", msg.Message) 79 | } 80 | if msg.Pulled != uint64(1024*1024) { 81 | t.Errorf("Expected 
pulled 1MB, got %d", msg.Pulled) 82 | } 83 | if msg.Layer == (Layer{}) { 84 | t.Errorf("Expected layer to be set") 85 | } 86 | if msg.Layer.ID != "sha256:c7790a0a70161f1bfd441cf157313e9efb8fcd1f0831193101def035ead23b32" { 87 | t.Errorf("Expected layer ID to be %s, got %s", "sha256:c7790a0a70161f1bfd441cf157313e9efb8fcd1f0831193101def035ead23b32", msg.Layer.ID) 88 | } 89 | if msg.Layer.Size != uint64(2016) { 90 | t.Errorf("Expected layer size to be %d, got %d", 2016, msg.Layer.Size) 91 | } 92 | if msg.Layer.Current != uint64(1048576) { 93 | t.Errorf("Expected layer current to be %d, got %d", 1048576, msg.Layer.Current) 94 | } 95 | }) 96 | 97 | t.Run("writeSuccess", func(t *testing.T) { 98 | var buf bytes.Buffer 99 | err := WriteSuccess(&buf, "Model pulled successfully") 100 | if err != nil { 101 | t.Fatalf("Failed to write success message: %v", err) 102 | } 103 | 104 | var msg Message 105 | if err := json.Unmarshal(buf.Bytes(), &msg); err != nil { 106 | t.Fatalf("Failed to parse JSON: %v", err) 107 | } 108 | 109 | if msg.Type != "success" { 110 | t.Errorf("Expected type 'success', got '%s'", msg.Type) 111 | } 112 | if msg.Message != "Model pulled successfully" { 113 | t.Errorf("Expected message 'Model pulled successfully', got '%s'", msg.Message) 114 | } 115 | }) 116 | 117 | t.Run("writeError", func(t *testing.T) { 118 | var buf bytes.Buffer 119 | err := WriteError(&buf, "Error: something went wrong") 120 | if err != nil { 121 | t.Fatalf("Failed to write error message: %v", err) 122 | } 123 | 124 | var msg Message 125 | if err := json.Unmarshal(buf.Bytes(), &msg); err != nil { 126 | t.Fatalf("Failed to parse JSON: %v", err) 127 | } 128 | 129 | if msg.Type != "error" { 130 | t.Errorf("Expected type 'error', got '%s'", msg.Type) 131 | } 132 | if msg.Message != "Error: something went wrong" { 133 | t.Errorf("Expected message 'Error: something went wrong', got '%s'", msg.Message) 134 | } 135 | }) 136 | } 137 | 138 | func TestProgressEmissionScenarios(t *testing.T) { 139 | tests := []struct { 140 | name string 141 | updates []v1.Update 142 | delays []time.Duration 143 | expectedCount int 144 | description string 145 | layerSize int64 146 | }{ 147 | { 148 | name: "time-based updates", 149 | updates: []v1.Update{ 150 | {Complete: 100}, // First update always sent 151 | {Complete: 100}, // Sent after interval 152 | {Complete: 1000}, // Sent after interval 153 | }, 154 | delays: []time.Duration{ 155 | UpdateInterval + 100*time.Millisecond, 156 | UpdateInterval + 100*time.Millisecond, 157 | }, 158 | expectedCount: 3, // First update + 2 time-based updates 159 | description: "should emit updates based on time interval", 160 | layerSize: 100, 161 | }, 162 | { 163 | name: "byte-based updates", 164 | updates: []v1.Update{ 165 | {Complete: MinBytesForUpdate}, // First update always sent 166 | {Complete: MinBytesForUpdate * 2}, // Second update with 1MB difference 167 | }, 168 | delays: []time.Duration{ 169 | 10 * time.Millisecond, // Short delay, should trigger based on bytes 170 | }, 171 | expectedCount: 2, // First update + 1 byte-based update 172 | description: "should emit update based on byte threshold", 173 | layerSize: MinBytesForUpdate + 1, 174 | }, 175 | { 176 | name: "no updates - too frequent", 177 | updates: []v1.Update{ 178 | {Complete: 100}, // First update always sent 179 | {Complete: 100}, // Too frequent, no update 180 | {Complete: 100}, // Too frequent, no update 181 | }, 182 | delays: []time.Duration{ 183 | 10 * time.Millisecond, // Too short 184 | 10 * time.Millisecond, // Too short 
185 | }, 186 | expectedCount: 1, // Only first update 187 | description: "should not emit updates if too frequent", 188 | layerSize: 100, 189 | }, 190 | { 191 | name: "no updates - too few bytes", 192 | updates: []v1.Update{ 193 | {Complete: 50}, // First update always sent 194 | {Complete: MinBytesForUpdate}, // Too few bytes 195 | {Complete: MinBytesForUpdate + 100}, // enough bytes now 196 | }, 197 | delays: []time.Duration{ 198 | 10 * time.Millisecond, 199 | }, 200 | expectedCount: 2, // First update and last update 201 | description: "should emit updates based on time even with few bytes", 202 | layerSize: 100, 203 | }, 204 | } 205 | 206 | for _, tt := range tests { 207 | t.Run(tt.name, func(t *testing.T) { 208 | var buf bytes.Buffer 209 | layer := newMockLayer(tt.layerSize) 210 | 211 | reporter := NewProgressReporter(&buf, PullMsg, layer) 212 | updates := reporter.Updates() 213 | 214 | // Send updates with delays 215 | for i, update := range tt.updates { 216 | updates <- update 217 | if i < len(tt.delays) { 218 | time.Sleep(tt.delays[i]) 219 | } 220 | } 221 | close(updates) 222 | 223 | // Wait for processing to complete 224 | if err := reporter.Wait(); err != nil { 225 | t.Fatalf("Reporter.Wait() failed: %v", err) 226 | } 227 | 228 | // Parse messages 229 | lines := bytes.Split(buf.Bytes(), []byte("\n")) 230 | var messages []Message 231 | for _, line := range lines { 232 | if len(line) == 0 { 233 | continue 234 | } 235 | var msg Message 236 | if err := json.Unmarshal(line, &msg); err != nil { 237 | t.Fatalf("Failed to parse JSON: %v", err) 238 | } 239 | messages = append(messages, msg) 240 | } 241 | 242 | if len(messages) != tt.expectedCount { 243 | t.Errorf("%s: expected %d messages, got %d", tt.description, tt.expectedCount, len(messages)) 244 | } 245 | 246 | // Verify message format for any messages received 247 | for i, msg := range messages { 248 | if msg.Type != "progress" { 249 | t.Errorf("message %d: expected type 'progress', got '%s'", i, msg.Type) 250 | } 251 | if msg.Layer.ID == "" { 252 | t.Errorf("message %d: expected layer ID to be set", i) 253 | } 254 | if msg.Layer.Size != uint64(tt.layerSize) { 255 | t.Errorf("message %d: expected layer size %d, got %d", i, tt.layerSize, msg.Layer.Size) 256 | } 257 | } 258 | }) 259 | } 260 | } 261 | -------------------------------------------------------------------------------- /internal/store/blobs.go: -------------------------------------------------------------------------------- 1 | package store 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "os" 7 | "path/filepath" 8 | 9 | v1 "github.com/google/go-containerregistry/pkg/v1" 10 | ) 11 | 12 | const ( 13 | blobsDir = "blobs" 14 | ) 15 | 16 | // blobDir returns the path to the blobs directory 17 | func (s *LocalStore) blobsDir() string { 18 | return filepath.Join(s.rootPath, blobsDir) 19 | } 20 | 21 | // blobPath returns the path to the blob for the given hash. 22 | func (s *LocalStore) blobPath(hash v1.Hash) string { 23 | return filepath.Join(s.rootPath, blobsDir, hash.Algorithm, hash.Hex) 24 | } 25 | 26 | type blob interface { 27 | DiffID() (v1.Hash, error) 28 | Uncompressed() (io.ReadCloser, error) 29 | } 30 | 31 | // writeBlob write the blob to the store, reporting progress to the given channel. 32 | // If the blob is already in the store, it is a no-op. 
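// Progress is reported by wrapping the blob reader in a ProgressReader via
// withProgress (later in this file). ProgressReader itself is defined elsewhere in
// this package; a minimal counting reader consistent with how it is constructed
// there might look like the sketch below. Only the field names come from the
// withProgress call; the running total and the non-blocking send are assumptions.
//
//	type ProgressReader struct {
//		Reader       io.Reader
//		ProgressChan chan<- v1.Update
//		complete     int64 // hypothetical running total
//	}
//
//	func (r *ProgressReader) Read(p []byte) (int, error) {
//		n, err := r.Reader.Read(p)
//		r.complete += int64(n)
//		select {
//		case r.ProgressChan <- v1.Update{Complete: r.complete}:
//		default: // drop the update rather than block the copy
//		}
//		return n, err
//	}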
33 | func (s *LocalStore) writeBlob(layer blob, progress chan<- v1.Update) error { 34 | hash, err := layer.DiffID() 35 | if err != nil { 36 | return fmt.Errorf("get file hash: %w", err) 37 | } 38 | if s.hasBlob(hash) { 39 | // todo: write something to the progress channel (we probably need to redo progress reporting a little bit) 40 | return nil 41 | } 42 | 43 | path := s.blobPath(hash) 44 | lr, err := layer.Uncompressed() 45 | if err != nil { 46 | return fmt.Errorf("get blob contents: %w", err) 47 | } 48 | defer lr.Close() 49 | r := withProgress(lr, progress) 50 | 51 | f, err := createFile(incompletePath(path)) 52 | if err != nil { 53 | return fmt.Errorf("create blob file: %w", err) 54 | } 55 | defer os.Remove(incompletePath(path)) 56 | defer f.Close() 57 | 58 | if _, err := io.Copy(f, r); err != nil { 59 | return fmt.Errorf("copy blob %q to store: %w", hash.String(), err) 60 | } 61 | 62 | f.Close() // Rename will fail on Windows if the file is still open. 63 | if err := os.Rename(incompletePath(path), path); err != nil { 64 | return fmt.Errorf("rename blob file: %w", err) 65 | } 66 | return nil 67 | } 68 | 69 | // removeBlob removes the blob with the given hash from the store. 70 | func (s *LocalStore) removeBlob(hash v1.Hash) error { 71 | return os.Remove(s.blobPath(hash)) 72 | } 73 | 74 | func (s *LocalStore) hasBlob(hash v1.Hash) bool { 75 | if _, err := os.Stat(s.blobPath(hash)); err == nil { 76 | return true 77 | } 78 | return false 79 | } 80 | 81 | // createFile is a wrapper around os.Create that creates any parent directories as needed. 82 | func createFile(path string) (*os.File, error) { 83 | if err := os.MkdirAll(filepath.Dir(path), 0777); err != nil { 84 | return nil, fmt.Errorf("create parent directory %q: %w", filepath.Dir(path), err) 85 | } 86 | return os.Create(path) 87 | } 88 | 89 | // withProgress returns a reader that reports progress to the given channel. 90 | func withProgress(r io.Reader, progress chan<- v1.Update) io.Reader { 91 | if progress == nil { 92 | return r 93 | } 94 | return &ProgressReader{ 95 | Reader: r, 96 | ProgressChan: progress, 97 | } 98 | } 99 | 100 | // incompletePath returns the path to the incomplete file for the given path. 
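// writeBlob above relies on this ".incomplete" suffix to keep blob writes atomic:
// data is streamed to the incomplete path and only renamed to the final
// content-addressed path after the copy (and the file close) succeed, so a crash
// mid-pull never leaves a truncated blob where a complete one is expected. The same
// pattern in isolation, as a hypothetical standalone helper (not used by this
// package):
//
//	func atomicWrite(path string, r io.Reader) error {
//		tmp := path + ".incomplete"
//		if err := os.MkdirAll(filepath.Dir(path), 0777); err != nil {
//			return err
//		}
//		f, err := os.Create(tmp)
//		if err != nil {
//			return err
//		}
//		defer os.Remove(tmp) // best-effort cleanup; fails harmlessly after the rename
//		if _, err := io.Copy(f, r); err != nil {
//			f.Close()
//			return err
//		}
//		if err := f.Close(); err != nil { // close before rename, or Windows refuses
//			return err
//		}
//		return os.Rename(tmp, path)
//	}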
101 | func incompletePath(path string) string { 102 | return path + ".incomplete" 103 | } 104 | 105 | // writeConfigFile writes the model config JSON file to the blob store 106 | func (s *LocalStore) writeConfigFile(mdl v1.Image) error { 107 | hash, err := mdl.ConfigName() 108 | if err != nil { 109 | return fmt.Errorf("get digest: %w", err) 110 | } 111 | rcf, err := mdl.RawConfigFile() 112 | if err != nil { 113 | return fmt.Errorf("get raw manifest: %w", err) 114 | } 115 | return writeFile(s.blobPath(hash), rcf) 116 | } 117 | -------------------------------------------------------------------------------- /internal/store/blobs_test.go: -------------------------------------------------------------------------------- 1 | package store 2 | 3 | import ( 4 | "errors" 5 | "io" 6 | "os" 7 | "path/filepath" 8 | "testing" 9 | 10 | v1 "github.com/google/go-containerregistry/pkg/v1" 11 | "github.com/google/go-containerregistry/pkg/v1/static" 12 | ) 13 | 14 | func TestBlobs(t *testing.T) { 15 | tmpDir, err := os.MkdirTemp("", "blob-test") 16 | if err != nil { 17 | t.Fatalf("error creating temp dir: %v", err) 18 | } 19 | rootDir := filepath.Join(tmpDir, "store") 20 | store, err := New(Options{RootPath: rootDir}) 21 | if err != nil { 22 | t.Fatalf("error creating store: %v", err) 23 | } 24 | 25 | t.Run("writeBlob with missing dir", func(t *testing.T) { 26 | // remove blobs directory to ensure it is recreated as needed 27 | if err := os.RemoveAll(store.blobsDir()); err != nil { 28 | t.Fatalf("expected blobs directory not be present") 29 | } 30 | 31 | // create the blob 32 | expectedContent := "some data" 33 | blob := static.NewLayer([]byte(expectedContent), "application/vnd.example.some.mt") 34 | hash, err := blob.DiffID() 35 | if err != nil { 36 | t.Fatalf("error getting blob hash: %v", err) 37 | } 38 | 39 | // write the blob 40 | if err := store.writeBlob(blob, nil); err != nil { 41 | t.Fatalf("error writing blob: %v", err) 42 | } 43 | 44 | // ensure blob file exists 45 | content, err := os.ReadFile(store.blobPath(hash)) 46 | if err != nil { 47 | t.Fatalf("error reading blob file: %v", err) 48 | } 49 | 50 | // ensure correct content 51 | if string(content) != expectedContent { 52 | t.Fatalf("unexpected blob content: got %v expected %s", string(content), expectedContent) 53 | } 54 | 55 | // ensure incomplete blob file does not exist 56 | tmpFile := incompletePath(store.blobPath(hash)) 57 | if _, err := os.Stat(tmpFile); !errors.Is(err, os.ErrNotExist) { 58 | t.Fatalf("expected incomplete blob file %s not be present", tmpFile) 59 | } 60 | }) 61 | 62 | t.Run("writeBlob fails", func(t *testing.T) { 63 | // simulate lingering incomplete blob file (if program crashed) 64 | hash := v1.Hash{ 65 | Algorithm: "some-alg", 66 | Hex: "some-hash", 67 | } 68 | if err := writeFile(incompletePath(store.blobPath(hash)), []byte("incomplete")); err != nil { 69 | t.Fatalf("error creating incomplete blob file for test: %v", err) 70 | } 71 | 72 | if err := store.writeBlob(&fakeBlob{ 73 | readCloser: &errorReader{}, 74 | hash: hash, 75 | }, nil); err == nil { 76 | t.Fatalf("expected error writing blob") 77 | } 78 | 79 | // ensure blob file does not exist 80 | if _, err := os.ReadFile(store.blobPath(hash)); !errors.Is(err, os.ErrNotExist) { 81 | t.Fatalf("expected blob file not to exist") 82 | } 83 | 84 | // ensure incomplete file is not left behind 85 | if _, err := os.ReadFile(incompletePath(store.blobPath(hash))); !errors.Is(err, os.ErrNotExist) { 86 | t.Fatalf("expected incomplete blob file not to exist") 87 | } 88 | }) 89 
| 90 | t.Run("writeBlob reuses existing blob", func(t *testing.T) { 91 | // simulate existing blob 92 | hash := v1.Hash{ 93 | Algorithm: "some-alg", 94 | Hex: "some-hash", 95 | } 96 | if err := writeFile(store.blobPath(hash), []byte("some-data")); err != nil { 97 | t.Fatalf("error creating incomplete blob file for test: %v", err) 98 | } 99 | 100 | if err := store.writeBlob(&fakeBlob{ 101 | readCloser: &errorReader{}, // will error if existing blob is not reused 102 | hash: hash, 103 | }, nil); err != nil { 104 | t.Fatalf("error writing blob: %v", err) 105 | } 106 | 107 | // ensure blob file exists 108 | content, err := os.ReadFile(store.blobPath(hash)) 109 | if err != nil { 110 | t.Fatalf("error reading blob file: %v", err) 111 | } 112 | 113 | // ensure correct content 114 | if string(content) != "some-data" { 115 | t.Fatalf("unexpected blob content: got %v expected %s", string(content), "some-data") 116 | } 117 | }) 118 | } 119 | 120 | type fakeBlob struct { 121 | readCloser io.ReadCloser 122 | hash v1.Hash 123 | } 124 | 125 | func (f fakeBlob) DiffID() (v1.Hash, error) { 126 | return f.hash, nil 127 | } 128 | 129 | func (f fakeBlob) Uncompressed() (io.ReadCloser, error) { 130 | return f.readCloser, nil 131 | } 132 | 133 | var _ io.Reader = &errorReader{} 134 | 135 | type errorReader struct { 136 | } 137 | 138 | func (e errorReader) Read(p []byte) (n int, err error) { 139 | return 0, errors.New("fake error") 140 | } 141 | 142 | func (e errorReader) Close() error { 143 | return nil 144 | } 145 | -------------------------------------------------------------------------------- /internal/store/errors.go: -------------------------------------------------------------------------------- 1 | package store 2 | 3 | import ( 4 | "errors" 5 | ) 6 | 7 | var ErrModelNotFound = errors.New("model not found") 8 | -------------------------------------------------------------------------------- /internal/store/index.go: -------------------------------------------------------------------------------- 1 | package store 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "fmt" 7 | "os" 8 | "path/filepath" 9 | 10 | "github.com/google/go-containerregistry/pkg/name" 11 | v1 "github.com/google/go-containerregistry/pkg/v1" 12 | ) 13 | 14 | // Index represents the index of all models in the store 15 | type Index struct { 16 | Models []IndexEntry `json:"models"` 17 | } 18 | 19 | func (i Index) Tag(reference string, tag string) (Index, error) { 20 | tagRef, err := name.NewTag(tag) 21 | if err != nil { 22 | return Index{}, fmt.Errorf("invalid tag: %w", err) 23 | } 24 | 25 | result := Index{} 26 | var tagged bool 27 | for _, entry := range i.Models { 28 | if entry.MatchesReference(reference) { 29 | result.Models = append(result.Models, entry.Tag(tagRef)) 30 | tagged = true 31 | } else { 32 | result.Models = append(result.Models, entry.UnTag(tagRef)) 33 | } 34 | } 35 | if !tagged { 36 | return Index{}, ErrModelNotFound 37 | } 38 | 39 | return result, nil 40 | } 41 | 42 | func (i Index) UnTag(tag string) Index { 43 | tagRef, err := name.NewTag(tag) 44 | if err != nil { 45 | return Index{} 46 | } 47 | 48 | result := Index{ 49 | Models: make([]IndexEntry, 0, len(i.Models)), 50 | } 51 | for _, entry := range i.Models { 52 | result.Models = append(result.Models, entry.UnTag(tagRef)) 53 | } 54 | 55 | return result 56 | } 57 | 58 | func (i Index) Find(reference string) (IndexEntry, int, bool) { 59 | for n, entry := range i.Models { 60 | if entry.MatchesReference(reference) { 61 | return i.Models[n], n, true 62 | } 63 | } 64 | 
65 | return IndexEntry{}, 0, false 66 | } 67 | 68 | func (i Index) Remove(reference string) Index { 69 | var result Index 70 | for _, entry := range i.Models { 71 | if entry.MatchesReference(reference) { 72 | continue 73 | } 74 | result.Models = append(result.Models, entry) 75 | } 76 | 77 | return result 78 | } 79 | 80 | func (i Index) Add(entry IndexEntry) Index { 81 | _, _, ok := i.Find(entry.ID) 82 | if ok { 83 | return i 84 | } 85 | return Index{ 86 | Models: append(i.Models, entry), 87 | } 88 | } 89 | 90 | // indexPath returns the path to the index file 91 | func (s *LocalStore) indexPath() string { 92 | return filepath.Join(s.rootPath, "models.json") 93 | } 94 | 95 | // writeIndex writes the index to the index file 96 | func (s *LocalStore) writeIndex(index Index) error { 97 | // Marshal the models index 98 | modelsData, err := json.MarshalIndent(index, "", " ") 99 | if err != nil { 100 | return fmt.Errorf("marshaling models: %w", err) 101 | } 102 | 103 | // Write the models index 104 | if err := writeFile(s.indexPath(), modelsData); err != nil { 105 | return fmt.Errorf("writing models file: %w", err) 106 | } 107 | 108 | if err := s.ensureLayout(); err != nil { 109 | return fmt.Errorf("ensuring layout file exists: %w", err) 110 | } 111 | 112 | return nil 113 | } 114 | 115 | // readIndex reads the index from the index file 116 | func (s *LocalStore) readIndex() (Index, error) { 117 | // Read the models index 118 | modelsData, err := os.ReadFile(s.indexPath()) 119 | if errors.Is(err, os.ErrNotExist) { 120 | return Index{}, nil 121 | } else if err != nil { 122 | return Index{}, fmt.Errorf("reading models file: %w", err) 123 | } 124 | 125 | // Unmarshal the models index 126 | var index Index 127 | if err := json.Unmarshal(modelsData, &index); err != nil { 128 | return Index{}, fmt.Errorf("unmarshaling models: %w", err) 129 | } 130 | 131 | return index, nil 132 | } 133 | 134 | // IndexEntry represents a model with its metadata and tags 135 | type IndexEntry struct { 136 | // ID is the globally unique model identifier. 137 | ID string `json:"id"` 138 | // Tags are the list of tags associated with the model. 139 | Tags []string `json:"tags"` 140 | // Files are the GGUF files associated with the model. 
141 | Files []string `json:"files"` 142 | } 143 | 144 | func newEntry(image v1.Image) (IndexEntry, error) { 145 | digest, err := image.Digest() 146 | if err != nil { 147 | return IndexEntry{}, fmt.Errorf("getting digest: %w", err) 148 | } 149 | 150 | layers, err := image.Layers() 151 | if err != nil { 152 | return IndexEntry{}, fmt.Errorf("getting layers: %w", err) 153 | } 154 | files := make([]string, len(layers)+1) 155 | for i, layer := range layers { 156 | diffID, err := layer.DiffID() 157 | if err != nil { 158 | return IndexEntry{}, fmt.Errorf("getting diffID: %w", err) 159 | } 160 | files[i] = diffID.String() 161 | } 162 | cfgName, err := image.ConfigName() 163 | if err != nil { 164 | return IndexEntry{}, fmt.Errorf("getting config name: %w", err) 165 | } 166 | files[len(layers)] = cfgName.String() 167 | 168 | return IndexEntry{ 169 | ID: digest.String(), 170 | Files: files, 171 | }, nil 172 | } 173 | 174 | func (e IndexEntry) HasTag(tag string) bool { 175 | ref, err := name.NewTag(tag) 176 | if err != nil { 177 | return false 178 | } 179 | for _, t := range e.Tags { 180 | tr, err := name.ParseReference(t) 181 | if err != nil { 182 | continue 183 | } 184 | if tr.Name() == ref.Name() { 185 | return true 186 | } 187 | } 188 | return false 189 | } 190 | 191 | func (e IndexEntry) hasTag(tag name.Tag) bool { 192 | for _, t := range e.Tags { 193 | tr, err := name.ParseReference(t) 194 | if err != nil { 195 | continue 196 | } 197 | if tr.Name() == tag.Name() { 198 | return true 199 | } 200 | } 201 | return false 202 | } 203 | 204 | func (e IndexEntry) MatchesReference(reference string) bool { 205 | if e.ID == reference { 206 | return true 207 | } 208 | ref, err := name.ParseReference(reference) 209 | if err != nil { 210 | return false 211 | } 212 | if dgst, ok := ref.(name.Digest); ok { 213 | if dgst.DigestStr() == e.ID { 214 | return true 215 | } 216 | } 217 | return e.HasTag(reference) 218 | } 219 | 220 | func (e IndexEntry) Tag(tag name.Tag) IndexEntry { 221 | if e.hasTag(tag) { 222 | return e 223 | } 224 | return IndexEntry{ 225 | ID: e.ID, 226 | Tags: append(e.Tags, tag.String()), 227 | Files: e.Files, 228 | } 229 | } 230 | 231 | func (e IndexEntry) UnTag(tag name.Tag) IndexEntry { 232 | var tags []string 233 | for i, t := range e.Tags { 234 | tr, err := name.ParseReference(t) 235 | if err != nil { 236 | continue 237 | } 238 | if tr.Name() == tag.Name() { 239 | continue 240 | } 241 | tags = append(tags, e.Tags[i]) 242 | } 243 | return IndexEntry{ 244 | ID: e.ID, 245 | Tags: tags, 246 | Files: e.Files, 247 | } 248 | } 249 | -------------------------------------------------------------------------------- /internal/store/index_test.go: -------------------------------------------------------------------------------- 1 | package store_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/docker/model-distribution/internal/store" 7 | ) 8 | 9 | func TestMatchReference(t *testing.T) { 10 | type testCase struct { 11 | entry store.IndexEntry 12 | reference string 13 | shouldMatch bool 14 | description string 15 | } 16 | tcs := []testCase{ 17 | { 18 | entry: store.IndexEntry{ 19 | ID: "sha256:232a0650cd323d3b760854c4030f63ef11023d6eb3ef78327883f3f739f99def", 20 | Tags: []string{"some-repo:latest", "some-repo:some-tag"}, 21 | }, 22 | reference: "sha256:232a0650cd323d3b760854c4030f63ef11023d6eb3ef78327883f3f739f99def", 23 | shouldMatch: true, 24 | description: "ID match", 25 | }, 26 | { 27 | entry: store.IndexEntry{ 28 | ID: 
"sha256:232a0650cd323d3b760854c4030f63ef11023d6eb3ef78327883f3f739f99def", 29 | Tags: []string{"some-repo:latest", "some-repo:some-tag"}, 30 | }, 31 | reference: "some-repo:some-tag", 32 | shouldMatch: true, 33 | description: "exact tag match", 34 | }, 35 | { 36 | entry: store.IndexEntry{ 37 | ID: "sha256:232a0650cd323d3b760854c4030f63ef11023d6eb3ef78327883f3f739f99def", 38 | Tags: []string{"some-repo:latest", "some-repo:some-tag"}, 39 | }, 40 | reference: "some-repo", 41 | shouldMatch: true, 42 | description: "implicit tag match", 43 | }, 44 | { 45 | entry: store.IndexEntry{ 46 | ID: "sha256:232a0650cd323d3b760854c4030f63ef11023d6eb3ef78327883f3f739f99def", 47 | Tags: []string{"some-repo:latest", "some-repo:some-tag"}, 48 | }, 49 | reference: "docker.io/library/some-repo:latest", 50 | shouldMatch: true, 51 | description: "implicit registry match", 52 | }, 53 | { 54 | entry: store.IndexEntry{ 55 | ID: "sha256:232a0650cd323d3b760854c4030f63ef11023d6eb3ef78327883f3f739f99def", 56 | Tags: []string{"some-repo:latest", "some-repo:some-tag"}, 57 | }, 58 | reference: "docker.io/some-org/some-repo:some-tag", 59 | shouldMatch: false, 60 | description: "mismatch tag reference", 61 | }, 62 | { 63 | entry: store.IndexEntry{ 64 | ID: "sha256:232a0650cd323d3b760854c4030f63ef11023d6eb3ef78327883f3f739f99def", 65 | Tags: []string{"some-repo:latest", "some-repo:some-tag"}, 66 | }, 67 | reference: "docker.io/some-org/some-repo@sha256:232a0650cd323d3b760854c4030f63ef11023d6eb3ef78327883f3f739f99def", 68 | shouldMatch: true, 69 | description: "digest reference match", 70 | }, 71 | } 72 | for _, tc := range tcs { 73 | t.Run(tc.description, func(t *testing.T) { 74 | if tc.entry.MatchesReference(tc.reference) != tc.shouldMatch { 75 | t.Errorf("Expected %v, got %v", tc.shouldMatch, !tc.shouldMatch) 76 | } 77 | }) 78 | } 79 | } 80 | 81 | func TestTag(t *testing.T) { 82 | t.Run("Tagging an entry", func(t *testing.T) { 83 | idx := store.Index{ 84 | Models: []store.IndexEntry{ 85 | { 86 | ID: "some-id", 87 | Tags: []string{"some-tag"}, 88 | }, 89 | { 90 | ID: "other-id", 91 | Tags: []string{"other-tag"}, 92 | }, 93 | }, 94 | } 95 | idx, err := idx.Tag("some-id", "other-tag") 96 | if err != nil { 97 | t.Fatalf("Error tagging entry: %v", err) 98 | } 99 | // Check that both models are still present 100 | if len(idx.Models) != 2 { 101 | t.Fatalf("Expected 2 models, got %d", len(idx.Models)) 102 | } 103 | if idx.Models[0].ID != "some-id" { 104 | t.Fatalf("Expected ID 'some-id', got '%s'", idx.Models[0].ID) 105 | } 106 | if idx.Models[1].ID != "other-id" { 107 | t.Fatalf("Expected ID 'other-id', got '%s'", idx.Models[1].ID) 108 | } 109 | 110 | // Check that new tag is added to the first model 111 | if len(idx.Models[0].Tags) != 2 { 112 | t.Fatalf("Expected 2 tags, got %d", len(idx.Models[0].Tags)) 113 | } 114 | if idx.Models[0].Tags[1] != "other-tag" { 115 | t.Fatalf("Expected tag 'other-tag', got '%s'", idx.Models[0].Tags[1]) 116 | } 117 | 118 | // Check that tag is removed from the second model 119 | if len(idx.Models[1].Tags) != 0 { 120 | t.Fatalf("Expected 0 tags, got %d", len(idx.Models[1].Tags)) 121 | } 122 | 123 | // Try to add a redundant tag 124 | idx, err = idx.Tag("some-id", "other-tag") 125 | if err != nil { 126 | t.Fatalf("Error tagging entry: %v", err) 127 | } 128 | // Check that the tag was not added again 129 | if len(idx.Models[0].Tags) != 2 { 130 | t.Fatalf("Expected 2 tags, got %d", len(idx.Models[0].Tags)) 131 | } 132 | }) 133 | } 134 | 135 | func TestUntag(t *testing.T) { 136 | t.Run("UnTagging an 
entry", func(t *testing.T) { 137 | idx := store.Index{ 138 | Models: []store.IndexEntry{ 139 | { 140 | ID: "some-id", 141 | Tags: []string{"some-tag", "other-tag"}, 142 | }, 143 | { 144 | ID: "other-id", 145 | Tags: []string{}, 146 | }, 147 | }, 148 | } 149 | idx = idx.UnTag("other-tag") 150 | if len(idx.Models) != 2 { 151 | t.Fatalf("Expected 2 models, got %d", len(idx.Models)) 152 | } 153 | if len(idx.Models[0].Tags) != 1 { 154 | t.Fatalf("Expected 1 tag, got %d", len(idx.Models[0].Tags)) 155 | } 156 | }) 157 | } 158 | -------------------------------------------------------------------------------- /internal/store/layout.go: -------------------------------------------------------------------------------- 1 | package store 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "os" 7 | "path/filepath" 8 | ) 9 | 10 | // Layout represents the layout information of the store 11 | type Layout struct { 12 | Version string `json:"version"` 13 | } 14 | 15 | // layoutPath returns the path to the layout file 16 | func (s *LocalStore) layoutPath() string { 17 | return filepath.Join(s.rootPath, "layout.json") 18 | } 19 | 20 | // readLayout reads the layout file and returns the layout information 21 | func (s *LocalStore) readLayout() (Layout, error) { 22 | // Version returns the store version 23 | // Read the layout file 24 | layoutData, err := os.ReadFile(s.layoutPath()) 25 | if err != nil { 26 | return Layout{}, fmt.Errorf("read layout path path %q: %w", s.layoutPath(), err) 27 | } 28 | 29 | // Unmarshal the layout 30 | var layout Layout 31 | if err := json.Unmarshal(layoutData, &layout); err != nil { 32 | return Layout{}, fmt.Errorf("unmarshal layout: %w", err) 33 | } 34 | 35 | return layout, nil 36 | } 37 | 38 | // ensureLayout ensure a layout file exists 39 | func (s *LocalStore) ensureLayout() error { 40 | if _, err := os.Stat(s.layoutPath()); os.IsNotExist(err) { 41 | layout := Layout{ 42 | Version: CurrentVersion, 43 | } 44 | if err := s.writeLayout(layout); err != nil { 45 | return fmt.Errorf("initializing layout file: %w", err) 46 | } 47 | } 48 | return nil 49 | } 50 | 51 | // writeLayout write the layout file 52 | func (s *LocalStore) writeLayout(layout Layout) error { 53 | layoutData, err := json.MarshalIndent(layout, "", " ") 54 | if err != nil { 55 | return fmt.Errorf("marshaling layout: %w", err) 56 | } 57 | if err := writeFile(s.layoutPath(), layoutData); err != nil { 58 | return fmt.Errorf("writing layout file: %w", err) 59 | } 60 | return nil 61 | } 62 | -------------------------------------------------------------------------------- /internal/store/manifests.go: -------------------------------------------------------------------------------- 1 | package store 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path/filepath" 7 | 8 | "github.com/google/go-containerregistry/pkg/v1" 9 | ) 10 | 11 | const ( 12 | manifestsDir = "manifests" 13 | ) 14 | 15 | // manifestPath returns the path to the manifest file for the given hash. 
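// Manifests live under manifests/<algorithm>/<hex>, mirroring the
// blobs/<algorithm>/<hex> layout in blobs.go, so both kinds of file are addressed
// purely by digest. A short sketch of deriving such a path from raw manifest bytes;
// rootPath and the file name are stand-ins for illustration, and error handling is
// elided.
//
//	raw, _ := os.ReadFile("manifest.json")
//	digest, _, _ := v1.SHA256(bytes.NewReader(raw))
//	path := filepath.Join(rootPath, "manifests", digest.Algorithm, digest.Hex)
//	// e.g. <rootPath>/manifests/sha256/232a0650cd32...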
16 | func (s *LocalStore) manifestPath(hash v1.Hash) string { 17 | return filepath.Join(s.rootPath, manifestsDir, hash.Algorithm, hash.Hex) 18 | } 19 | 20 | // writeManifest writes the model's manifest to the store 21 | func (s *LocalStore) writeManifest(mdl v1.Image) error { 22 | digest, err := mdl.Digest() 23 | if err != nil { 24 | return fmt.Errorf("get digest: %w", err) 25 | } 26 | rm, err := mdl.RawManifest() 27 | if err != nil { 28 | return fmt.Errorf("get raw manifest: %w", err) 29 | } 30 | return writeFile(s.manifestPath(digest), rm) 31 | } 32 | 33 | // removeManifest removes the manifest file from the store 34 | func (s *LocalStore) removeManifest(hash v1.Hash) error { 35 | return os.Remove(s.manifestPath(hash)) 36 | } 37 | 38 | // writeFile is a wrapper around os.WriteFile that creates any parent directories as needed. 39 | func writeFile(path string, data []byte) error { 40 | if err := os.MkdirAll(filepath.Dir(path), 0777); err != nil { 41 | return fmt.Errorf("create parent directory %q: %w", filepath.Dir(path), err) 42 | } 43 | return os.WriteFile(path, data, 0666) 44 | } 45 | -------------------------------------------------------------------------------- /internal/store/model.go: -------------------------------------------------------------------------------- 1 | package store 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "fmt" 7 | "os" 8 | 9 | v1 "github.com/google/go-containerregistry/pkg/v1" 10 | "github.com/google/go-containerregistry/pkg/v1/partial" 11 | "github.com/google/go-containerregistry/pkg/v1/types" 12 | 13 | mdpartial "github.com/docker/model-distribution/internal/partial" 14 | mdtypes "github.com/docker/model-distribution/types" 15 | ) 16 | 17 | var _ v1.Image = &Model{} 18 | 19 | type Model struct { 20 | rawManifest []byte 21 | manifest *v1.Manifest 22 | rawConfigFile []byte 23 | layers []v1.Layer 24 | tags []string 25 | } 26 | 27 | func (s *LocalStore) newModel(digest v1.Hash, tags []string) (*Model, error) { 28 | rawManifest, err := os.ReadFile(s.manifestPath(digest)) 29 | if err != nil { 30 | return nil, fmt.Errorf("read manifest: %w", err) 31 | } 32 | 33 | manifest, err := v1.ParseManifest(bytes.NewReader(rawManifest)) 34 | if err != nil { 35 | return nil, fmt.Errorf("parse manifest: %w", err) 36 | } 37 | 38 | rawConfigFile, err := os.ReadFile(s.blobPath(manifest.Config.Digest)) 39 | if err != nil { 40 | return nil, fmt.Errorf("read config file: %w", err) 41 | } 42 | 43 | layers := make([]v1.Layer, len(manifest.Layers)) 44 | for i, ld := range manifest.Layers { 45 | layers[i] = &mdpartial.Layer{ 46 | Path: s.blobPath(ld.Digest), 47 | Descriptor: ld, 48 | } 49 | } 50 | 51 | return &Model{ 52 | rawManifest: rawManifest, 53 | manifest: manifest, 54 | rawConfigFile: rawConfigFile, 55 | tags: tags, 56 | layers: layers, 57 | }, err 58 | } 59 | 60 | func (m *Model) Layers() ([]v1.Layer, error) { 61 | return m.layers, nil 62 | } 63 | 64 | func (m *Model) MediaType() (types.MediaType, error) { 65 | return m.manifest.MediaType, nil 66 | } 67 | 68 | func (m *Model) Size() (int64, error) { 69 | return partial.Size(m) 70 | } 71 | 72 | func (m *Model) ConfigName() (v1.Hash, error) { 73 | return partial.ConfigName(m) 74 | } 75 | 76 | func (m *Model) ConfigFile() (*v1.ConfigFile, error) { 77 | return nil, errors.New("invalid for model") 78 | } 79 | 80 | func (m *Model) RawConfigFile() ([]byte, error) { 81 | return m.rawConfigFile, nil 82 | } 83 | 84 | func (m *Model) Digest() (v1.Hash, error) { 85 | return partial.Digest(m) 86 | } 87 | 88 | func (m *Model) Manifest() 
(*v1.Manifest, error) { 89 | return partial.Manifest(m) 90 | } 91 | 92 | func (m *Model) RawManifest() ([]byte, error) { 93 | return m.rawManifest, nil 94 | } 95 | 96 | func (m *Model) LayerByDigest(hash v1.Hash) (v1.Layer, error) { 97 | for _, l := range m.layers { 98 | d, err := l.Digest() 99 | if err != nil { 100 | return nil, fmt.Errorf("get digest: %w", err) 101 | } 102 | if d == hash { 103 | return l, nil 104 | } 105 | } 106 | return nil, fmt.Errorf("layer with digest %s not found", hash) 107 | } 108 | 109 | func (m *Model) LayerByDiffID(hash v1.Hash) (v1.Layer, error) { 110 | return m.LayerByDigest(hash) 111 | } 112 | 113 | func (m *Model) GGUFPath() (string, error) { 114 | return mdpartial.GGUFPath(m) 115 | } 116 | 117 | func (m *Model) Tags() []string { 118 | return m.tags 119 | } 120 | 121 | func (m *Model) ID() (string, error) { 122 | return mdpartial.ID(m) 123 | } 124 | 125 | func (m *Model) Config() (mdtypes.Config, error) { 126 | return mdpartial.Config(m) 127 | } 128 | 129 | func (m *Model) Descriptor() (mdtypes.Descriptor, error) { 130 | return mdpartial.Descriptor(m) 131 | } 132 | -------------------------------------------------------------------------------- /internal/store/store.go: -------------------------------------------------------------------------------- 1 | package store 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "os" 7 | "path/filepath" 8 | 9 | "github.com/docker/model-distribution/internal/progress" 10 | 11 | v1 "github.com/google/go-containerregistry/pkg/v1" 12 | ) 13 | 14 | const ( 15 | // CurrentVersion is the current version of the store layout 16 | CurrentVersion = "1.0.0" 17 | ) 18 | 19 | // LocalStore implements the Store interface for local storage 20 | type LocalStore struct { 21 | rootPath string 22 | } 23 | 24 | // RootPath returns the root path of the store 25 | func (s *LocalStore) RootPath() string { 26 | return s.rootPath 27 | } 28 | 29 | // Options represents options for creating a store 30 | type Options struct { 31 | RootPath string 32 | } 33 | 34 | // New creates a new LocalStore 35 | func New(opts Options) (*LocalStore, error) { 36 | store := &LocalStore{ 37 | rootPath: opts.RootPath, 38 | } 39 | 40 | // Initialize store if it doesn't exist 41 | if err := store.initialize(); err != nil { 42 | return nil, fmt.Errorf("initializing store: %w", err) 43 | } 44 | 45 | return store, nil 46 | } 47 | 48 | // Reset clears all contents of the store directory and reinitializes the store. 49 | // It removes all files and subdirectories within the store's root path, but preserves the root directory itself. 50 | // This allows the method to work correctly when the store directory is a mounted volume (e.g., in Docker CE). 
51 | func (s *LocalStore) Reset() error { 52 | entries, err := os.ReadDir(s.rootPath) 53 | if err != nil { 54 | return fmt.Errorf("reading store directory: %w", err) 55 | } 56 | 57 | for _, entry := range entries { 58 | entryPath := filepath.Join(s.rootPath, entry.Name()) 59 | if err := os.RemoveAll(entryPath); err != nil { 60 | return fmt.Errorf("removing %s: %w", entryPath, err) 61 | } 62 | } 63 | 64 | return s.initialize() 65 | } 66 | 67 | // initialize creates the store directory structure if it doesn't exist 68 | func (s *LocalStore) initialize() error { 69 | // Check if layout.json exists, create if not 70 | if err := s.ensureLayout(); err != nil { 71 | return err 72 | } 73 | 74 | // Check if models.json exists, create if not 75 | if _, err := os.Stat(s.indexPath()); os.IsNotExist(err) { 76 | if err := s.writeIndex(Index{ 77 | Models: []IndexEntry{}, 78 | }); err != nil { 79 | return fmt.Errorf("initializing index file: %w", err) 80 | } 81 | } 82 | 83 | return nil 84 | } 85 | 86 | // List lists all models in the store 87 | func (s *LocalStore) List() ([]IndexEntry, error) { 88 | index, err := s.readIndex() 89 | if err != nil { 90 | return nil, fmt.Errorf("reading models index: %w", err) 91 | } 92 | return index.Models, nil 93 | } 94 | 95 | // Delete deletes a model by reference 96 | func (s *LocalStore) Delete(ref string) error { 97 | idx, err := s.readIndex() 98 | if err != nil { 99 | return fmt.Errorf("reading models file: %w", err) 100 | } 101 | model, _, ok := idx.Find(ref) 102 | if !ok { 103 | return ErrModelNotFound 104 | } 105 | 106 | // Remove manifest file 107 | if digest, err := v1.NewHash(model.ID); err != nil { 108 | fmt.Printf("Warning: failed to parse manifest digest %s: %v\n", digest, err) 109 | } else if err := s.removeManifest(digest); err != nil { 110 | fmt.Printf("Warning: failed to remove manifest %q: %v\n", 111 | digest, err, 112 | ) 113 | } 114 | // Before deleting blobs, check if they are referenced by other models 115 | blobRefs := make(map[string]int) 116 | for _, m := range idx.Models { 117 | if m.ID == model.ID { 118 | continue // Skip the model being deleted 119 | } 120 | for _, file := range m.Files { 121 | blobRefs[file]++ 122 | } 123 | } 124 | // Only delete blobs that are not referenced by other models 125 | for _, blobFile := range model.Files { 126 | if blobRefs[blobFile] > 0 { 127 | // Skip deletion if blob is referenced by other models 128 | continue 129 | } 130 | hash, err := v1.NewHash(blobFile) 131 | if err != nil { 132 | fmt.Printf("Warning: failed to parse blob hash %s: %v\n", blobFile, err) 133 | continue 134 | } 135 | if err := s.removeBlob(hash); err != nil { 136 | // Just log the error but don't fail the operation 137 | fmt.Printf("Warning: failed to remove blob %q from store: %v\n", hash.String(), err) 138 | } 139 | } 140 | 141 | idx = idx.Remove(model.ID) 142 | 143 | return s.writeIndex(idx) 144 | } 145 | 146 | // AddTags adds tags to an existing model 147 | func (s *LocalStore) AddTags(ref string, newTags []string) error { 148 | index, err := s.readIndex() 149 | if err != nil { 150 | return fmt.Errorf("reading models file: %w", err) 151 | } 152 | for _, t := range newTags { 153 | index, err = index.Tag(ref, t) 154 | if err != nil { 155 | return fmt.Errorf("tagging model: %w", err) 156 | } 157 | } 158 | 159 | return s.writeIndex(index) 160 | } 161 | 162 | // RemoveTags removes tags from models 163 | func (s *LocalStore) RemoveTags(tags []string) error { 164 | index, err := s.readIndex() 165 | if err != nil { 166 | return 
fmt.Errorf("reading modelss index: %w", err) 167 | } 168 | for _, tag := range tags { 169 | index = index.UnTag(tag) 170 | } 171 | return s.writeIndex(index) 172 | } 173 | 174 | // Version returns the store version 175 | func (s *LocalStore) Version() string { 176 | layout, err := s.readLayout() 177 | if err != nil { 178 | return "unknown" 179 | } 180 | 181 | return layout.Version 182 | } 183 | 184 | // Write writes a model to the store 185 | func (s *LocalStore) Write(mdl v1.Image, tags []string, w io.Writer) error { 186 | 187 | // Write the config JSON file 188 | if err := s.writeConfigFile(mdl); err != nil { 189 | return fmt.Errorf("writing config file: %w", err) 190 | } 191 | 192 | // Write the blobs 193 | layers, err := mdl.Layers() 194 | if err != nil { 195 | return fmt.Errorf("getting layers: %w", err) 196 | } 197 | 198 | for _, layer := range layers { 199 | var pr *progress.Reporter 200 | var progressChan chan<- v1.Update 201 | if w != nil { 202 | pr = progress.NewProgressReporter(w, progress.PullMsg, layer) 203 | progressChan = pr.Updates() 204 | } 205 | if err := s.writeBlob(layer, progressChan); err != nil { 206 | close(progressChan) 207 | return fmt.Errorf("writing blob: %w", err) 208 | } 209 | if pr != nil { 210 | close(progressChan) 211 | } 212 | } 213 | 214 | // Write the manifest 215 | if err := s.writeManifest(mdl); err != nil { 216 | return fmt.Errorf("writing manifest: %w", err) 217 | } 218 | 219 | // Add the model to the index 220 | idx, err := s.readIndex() 221 | if err != nil { 222 | return fmt.Errorf("reading models: %w", err) 223 | } 224 | entry, err := newEntry(mdl) 225 | if err != nil { 226 | return fmt.Errorf("creating index entry: %w", err) 227 | } 228 | 229 | // Add the model tags 230 | idx = idx.Add(entry) 231 | for _, tag := range tags { 232 | updatedIdx, err := idx.Tag(entry.ID, tag) 233 | if err != nil { 234 | fmt.Printf("Warning: failed to tag model %q with tag %q: %v\n", entry.ID, tag, err) 235 | continue 236 | } 237 | idx = updatedIdx 238 | } 239 | 240 | return s.writeIndex(idx) 241 | } 242 | 243 | // Read reads a model from the store by reference (either tag or ID) 244 | func (s *LocalStore) Read(reference string) (*Model, error) { 245 | models, err := s.List() 246 | if err != nil { 247 | return nil, fmt.Errorf("reading models file: %w", err) 248 | } 249 | 250 | // Find the model by tag 251 | for _, model := range models { 252 | if model.MatchesReference(reference) { 253 | hash, err := v1.NewHash(model.ID) 254 | if err != nil { 255 | return nil, fmt.Errorf("parsing hash: %w", err) 256 | } 257 | return s.newModel(hash, model.Tags) 258 | } 259 | } 260 | 261 | return nil, ErrModelNotFound 262 | } 263 | 264 | // ProgressReader wraps an io.Reader to track reading progress 265 | type ProgressReader struct { 266 | Reader io.Reader 267 | ProgressChan chan<- v1.Update 268 | Total int64 269 | } 270 | 271 | func (pr *ProgressReader) Read(p []byte) (int, error) { 272 | n, err := pr.Reader.Read(p) 273 | pr.Total += int64(n) 274 | if err == io.EOF { 275 | pr.ProgressChan <- v1.Update{Complete: pr.Total} 276 | } else if n > 0 { 277 | select { 278 | case pr.ProgressChan <- v1.Update{Complete: pr.Total}: 279 | default: // if the progress channel is full, it skips sending rather than blocking the Read() call. 
280 | } 281 | } 282 | return n, err 283 | } 284 | -------------------------------------------------------------------------------- /internal/store/store_test.go: -------------------------------------------------------------------------------- 1 | package store_test 2 | 3 | import ( 4 | "crypto/sha256" 5 | "encoding/hex" 6 | "errors" 7 | "fmt" 8 | "os" 9 | "path/filepath" 10 | "strings" 11 | "testing" 12 | 13 | "github.com/docker/model-distribution/internal/gguf" 14 | "github.com/docker/model-distribution/internal/mutate" 15 | "github.com/docker/model-distribution/internal/partial" 16 | "github.com/docker/model-distribution/internal/store" 17 | "github.com/docker/model-distribution/types" 18 | ) 19 | 20 | // TestStoreAPI tests the store API directly 21 | func TestStoreAPI(t *testing.T) { 22 | // Create a temporary directory for the test store 23 | tempDir, err := os.MkdirTemp("", "store-api-test") 24 | if err != nil { 25 | t.Fatalf("Failed to create temp directory: %v", err) 26 | } 27 | defer os.RemoveAll(tempDir) 28 | 29 | // Create store 30 | storePath := filepath.Join(tempDir, "api-model-store") 31 | s, err := store.New(store.Options{ 32 | RootPath: storePath, 33 | }) 34 | if err != nil { 35 | t.Fatalf("Failed to create store: %v", err) 36 | } 37 | // Everything must handle directory deletion 38 | if err := os.RemoveAll(storePath); err != nil { 39 | t.Fatalf("Failed to remove store directory: %v", err) 40 | } 41 | 42 | model := newTestModel(t) 43 | layers, err := model.Layers() 44 | if err != nil { 45 | t.Fatalf("Failed to get layers: %v", err) 46 | } 47 | ggufDiffID, err := layers[0].DiffID() 48 | if err != nil { 49 | t.Fatalf("Failed to get diff ID: %v", err) 50 | } 51 | expectedBlobHash := ggufDiffID.String() 52 | 53 | digest, err := model.Digest() 54 | if err != nil { 55 | t.Fatalf("Digest failed: %v", err) 56 | } 57 | if err := s.Write(model, []string{"api-model:latest"}, nil); err != nil { 58 | t.Fatalf("Write failed: %v", err) 59 | } 60 | 61 | t.Run("ReadByTag", func(t *testing.T) { 62 | mdl2, err := s.Read("api-model:latest") 63 | if err != nil { 64 | t.Fatalf("Read failed: %v", err) 65 | } 66 | readDigest, err := mdl2.Digest() 67 | if err != nil { 68 | t.Fatalf("Digest failed: %v", err) 69 | } 70 | if digest != readDigest { 71 | t.Fatalf("Digest mismatch %s != %s", digest.Hex, readDigest.Hex) 72 | } 73 | }) 74 | 75 | t.Run("ReadByID", func(t *testing.T) { 76 | id, err := model.ID() 77 | if err != nil { 78 | t.Fatalf("ID failed: %v", err) 79 | } 80 | mdl2, err := s.Read(id) 81 | if err != nil { 82 | t.Fatalf("Read failed: %v", err) 83 | } 84 | readDigest, err := mdl2.Digest() 85 | if err != nil { 86 | t.Fatalf("Digest failed: %v", err) 87 | } 88 | if digest != readDigest { 89 | t.Fatalf("Digest mismatch %s != %s", digest.Hex, readDigest.Hex) 90 | } 91 | if !containsTag(mdl2.Tags(), "api-model:latest") { 92 | t.Errorf("Expected tag api-model:latest, got %v", mdl2.Tags()) 93 | } 94 | 95 | }) 96 | 97 | t.Run("ReadNotFound", func(t *testing.T) { 98 | if _, err := s.Read("non-existent-model:latest"); !errors.Is(err, store.ErrModelNotFound) { 99 | t.Fatalf("Expected ErrModelNotFound got: %v", err) 100 | } 101 | }) 102 | 103 | // Test List 104 | t.Run("List", func(t *testing.T) { 105 | models, err := s.List() 106 | if err != nil { 107 | t.Fatalf("List failed: %v", err) 108 | } 109 | if len(models) != 1 { 110 | t.Fatalf("Expected 1 model, got %d", len(models)) 111 | } 112 | if !containsTag(models[0].Tags, "api-model:latest") { 113 | t.Errorf("Expected tag api-model:latest, got %v", 
models[0].Tags) 114 | } 115 | if len(models[0].Files) != 3 { 116 | t.Fatalf("Expected 3 files (gguf, license, config), got %d", len(models[0].Files)) 117 | } 118 | if models[0].Files[0] != expectedBlobHash { 119 | t.Errorf("Expected blob hash %s, got %s", expectedBlobHash, models[0].Files[0]) 120 | } 121 | }) 122 | 123 | // Test AddTags 124 | t.Run("AddTags", func(t *testing.T) { 125 | err := s.AddTags("api-model:latest", []string{"api-v1.0", "api-stable"}) 126 | if err != nil { 127 | t.Fatalf("AddTags failed: %v", err) 128 | } 129 | 130 | // Verify tags were added to model 131 | model, err := s.Read("api-model:latest") 132 | if err != nil { 133 | t.Fatalf("GetByTag failed: %v", err) 134 | } 135 | if !containsTag(model.Tags(), "api-v1.0") || !containsTag(model.Tags(), "api-stable") { 136 | t.Errorf("Expected new tags, got %v", model.Tags()) 137 | } 138 | 139 | // Verify tags were added to list 140 | models, err := s.List() 141 | if err != nil { 142 | t.Fatalf("List failed: %v", err) 143 | } 144 | if len(models) != 1 { 145 | t.Fatalf("Expected 1 model, got %d", len(models)) 146 | } 147 | if len(models[0].Tags) != 3 { 148 | t.Fatalf("Expected 3 tags, got %d", len(models[0].Tags)) 149 | } 150 | }) 151 | 152 | // Test RemoveTags 153 | t.Run("RemoveTags", func(t *testing.T) { 154 | err := s.RemoveTags([]string{"api-model:api-v1.0"}) 155 | if err != nil { 156 | t.Fatalf("RemoveTags failed: %v", err) 157 | } 158 | 159 | // Verify tag was removed from list 160 | models, err := s.List() 161 | if err != nil { 162 | t.Fatalf("List failed: %v", err) 163 | } 164 | for _, model := range models { 165 | if containsTag(model.Tags, "api-model:api-v1.0") { 166 | t.Errorf("Tag should have been removed, but still present: %v", model.Tags) 167 | } 168 | if model.Files[0] != expectedBlobHash { 169 | t.Errorf("Expected blob hash %s, got %s", expectedBlobHash, model.Files[0]) 170 | } 171 | } 172 | 173 | // Verify read by tag fails 174 | if _, err = s.Read("api-model:api-v1.0"); err == nil { 175 | t.Errorf("Expected read error after tag removal, got nil") 176 | } 177 | }) 178 | 179 | // Test Delete 180 | t.Run("Delete", func(t *testing.T) { 181 | err := s.Delete("api-model:latest") 182 | if err != nil { 183 | t.Fatalf("Delete failed: %v", err) 184 | } 185 | 186 | // Verify model with that tag is gone 187 | _, err = s.Read("api-model:latest") 188 | if err == nil { 189 | t.Errorf("Expected error after deletion, got nil") 190 | } 191 | }) 192 | 193 | // Test Delete Non Existent Model 194 | t.Run("Delete", func(t *testing.T) { 195 | err := s.Delete("non-existent-model:latest") 196 | if !errors.Is(err, store.ErrModelNotFound) { 197 | t.Fatalf("Expected ErrModelNotFound, got %v", err) 198 | } 199 | }) 200 | 201 | // Test that Delete removes the blob files 202 | t.Run("DeleteRemovesBlobs", func(t *testing.T) { 203 | // Create a new model with unique content 204 | modelContent := []byte("unique content for blob deletion test") 205 | modelPath := filepath.Join(tempDir, "blob-deletion-test.gguf") 206 | if err := os.WriteFile(modelPath, modelContent, 0644); err != nil { 207 | t.Fatalf("Failed to create test model file: %v", err) 208 | } 209 | 210 | // Calculate the blob hash to find it later 211 | hash := sha256.Sum256(modelContent) 212 | blobHash := hex.EncodeToString(hash[:]) 213 | 214 | // Add model to store with a unique tag 215 | mdl, err := gguf.NewModel(modelPath) 216 | if err != nil { 217 | t.Fatalf("Create model failed: %v", err) 218 | } 219 | 220 | if err := s.Write(mdl, []string{"blob-test:latest", 
"blob-test:other"}, nil); err != nil { 221 | t.Fatalf("Write failed: %v", err) 222 | } 223 | 224 | // Get the blob path 225 | blobPath := filepath.Join(storePath, "blobs", "sha256", blobHash) 226 | 227 | // Verify the blob exists on disk before deletion 228 | if _, err := os.Stat(blobPath); err != nil { 229 | t.Fatalf("Failed to stat blob at path '%s': %v", blobPath, err) 230 | } 231 | 232 | // Get the manifest path 233 | digest, err := mdl.Digest() 234 | if err != nil { 235 | t.Fatalf("Failed to get digest: %v", err) 236 | } 237 | 238 | // Verify the model manifest exists 239 | manifestPath := filepath.Join(storePath, "manifests", "sha256", digest.Hex) 240 | if _, err := os.Stat(manifestPath); err != nil { 241 | t.Fatalf("Failed to stat manifest at path '%s': %v", manifestPath, err) 242 | } 243 | 244 | // Delete the model 245 | if err := s.Delete("blob-test:latest"); err != nil { 246 | t.Fatalf("Delete failed: %v", err) 247 | } 248 | 249 | // Verify the blob no longer exists on disk after deletion 250 | if _, err := os.Stat(blobPath); !os.IsNotExist(err) { 251 | t.Errorf("Blob file still exists after deletion: %s", blobPath) 252 | } 253 | 254 | // Verify the manifest no longer exists on disk after deletion 255 | if _, err := os.Stat(manifestPath); !os.IsNotExist(err) { 256 | t.Errorf("Manifest file still exists after deletion: %s", blobPath) 257 | } 258 | }) 259 | 260 | // Test that shared blobs between different models are not deleted 261 | t.Run("SharedBlobsPreservation", func(t *testing.T) { 262 | // Create a model file with content that will be shared 263 | sharedContent := []byte("shared content for multiple models test") 264 | sharedModelPath := filepath.Join(tempDir, "shared-model.gguf") 265 | if err := os.WriteFile(sharedModelPath, sharedContent, 0644); err != nil { 266 | t.Fatalf("Failed to create shared model file: %v", err) 267 | } 268 | 269 | // Calculate the blob hash to find it later 270 | hash := sha256.Sum256(sharedContent) 271 | blobHash := hex.EncodeToString(hash[:]) 272 | expectedBlobDigest := fmt.Sprintf("sha256:%s", blobHash) 273 | 274 | // Create first model with the shared content 275 | model1, err := gguf.NewModel(sharedModelPath) 276 | if err != nil { 277 | t.Fatalf("Create first model failed: %v", err) 278 | } 279 | 280 | // Write the first model 281 | if err := s.Write(model1, []string{"shared-model-1:latest"}, nil); err != nil { 282 | t.Fatalf("Write first model failed: %v", err) 283 | } 284 | 285 | // Create second model with the same shared content 286 | model2, err := gguf.NewModel(sharedModelPath) 287 | if err != nil { 288 | t.Fatalf("Create second model failed: %v", err) 289 | } 290 | 291 | // Write the second model 292 | if err := s.Write(model2, []string{"shared-model-2:latest"}, nil); err != nil { 293 | t.Fatalf("Write second model failed: %v", err) 294 | } 295 | 296 | // Get the blob path 297 | blobPath := filepath.Join(storePath, "blobs", "sha256", blobHash) 298 | 299 | // Get the config blob paths (not shared) 300 | name1, err := model1.ConfigName() 301 | if err != nil { 302 | t.Fatalf("Failed to get config name: %v", err) 303 | } 304 | config1Path := filepath.Join(storePath, "blobs", "sha256", name1.Hex) 305 | name2, err := model2.ConfigName() 306 | if err != nil { 307 | t.Fatalf("Failed to get config name: %v", err) 308 | } 309 | config2Path := filepath.Join(storePath, "blobs", "sha256", name2.Hex) 310 | 311 | // Verify the blobs exists on disk 312 | if _, err := os.Stat(blobPath); os.IsNotExist(err) { 313 | t.Fatalf("Shared blob file doesn't 
exist: %s", blobPath) 314 | } 315 | if _, err := os.Stat(config1Path); os.IsNotExist(err) { 316 | t.Fatalf("Model 1 config blob file doesn't exist: %s", config1Path) 317 | } 318 | if _, err := os.Stat(config2Path); os.IsNotExist(err) { 319 | t.Fatalf("Model 2 config blob file doesn't exist: %s", config2Path) 320 | } 321 | 322 | // Delete the first model 323 | if err := s.Delete("shared-model-1:latest"); err != nil { 324 | t.Fatalf("Delete first model failed: %v", err) 325 | } 326 | 327 | // Verify the shared blob still exists on disk after deleting the first model 328 | if _, err := os.Stat(blobPath); os.IsNotExist(err) { 329 | t.Errorf("Shared blob file was incorrectly removed: %s", blobPath) 330 | } 331 | 332 | // Verify the first model config blob does not exist 333 | if _, err := os.Stat(config1Path); !os.IsNotExist(err) { 334 | t.Errorf("Model 1 config blob should have been removed: %s", config1Path) 335 | } 336 | 337 | // Verify the second model config blob still exists 338 | if _, err := os.Stat(blobPath); os.IsNotExist(err) { 339 | t.Errorf("Model 2 config blob file was incorrectly removed: %s", config2Path) 340 | } 341 | 342 | // Verify the second model is still in the index 343 | models, err := s.List() 344 | if err != nil { 345 | t.Fatalf("List failed: %v", err) 346 | } 347 | 348 | var foundModel bool 349 | for _, model := range models { 350 | if containsTag(model.Tags, "shared-model-2:latest") { 351 | foundModel = true 352 | // Verify the blob is still associated with the model 353 | if len(model.Files) != 2 { 354 | t.Errorf("Expected 2 blobs, got %d", len(model.Files)) 355 | } 356 | if model.Files[0] != expectedBlobDigest { 357 | t.Errorf("Expected blob %s, got %v", expectedBlobDigest, model.Files) 358 | } 359 | if model.Files[1] != name2.String() { 360 | t.Errorf("Expected blob %s, got %v", expectedBlobDigest, model.Files) 361 | } 362 | break 363 | } 364 | } 365 | 366 | if !foundModel { 367 | t.Errorf("Second model not found after deleting first model") 368 | } 369 | 370 | // Delete the second model 371 | if err := s.Delete("shared-model-2:latest"); err != nil { 372 | t.Fatalf("Delete second model failed: %v", err) 373 | } 374 | 375 | // Now the blob should be deleted since no models reference it 376 | if _, err := os.Stat(blobPath); !os.IsNotExist(err) { 377 | t.Errorf("Shared blob file still exists after deleting all referencing models: %s", blobPath) 378 | } 379 | }) 380 | } 381 | 382 | // TestIncompleteFileHandling tests that files are created with .incomplete suffix and renamed on success 383 | func TestIncompleteFileHandling(t *testing.T) { 384 | // Create a temporary directory for the test store 385 | tempDir, err := os.MkdirTemp("", "incomplete-file-test") 386 | if err != nil { 387 | t.Fatalf("Failed to create temp directory: %v", err) 388 | } 389 | defer os.RemoveAll(tempDir) 390 | 391 | // Create a temporary model file with known content 392 | modelContent := []byte("test model content for incomplete file test") 393 | modelPath := filepath.Join(tempDir, "incomplete-test-model.gguf") 394 | if err := os.WriteFile(modelPath, modelContent, 0644); err != nil { 395 | t.Fatalf("Failed to create test model file: %v", err) 396 | } 397 | 398 | // Calculate expected blob hash 399 | hash := sha256.Sum256(modelContent) 400 | blobHash := hex.EncodeToString(hash[:]) 401 | 402 | // Create store 403 | storePath := filepath.Join(tempDir, "incomplete-model-store") 404 | s, err := store.New(store.Options{ 405 | RootPath: storePath, 406 | }) 407 | if err != nil { 408 | t.Fatalf("Failed 
to create store: %v", err) 409 | } 410 | 411 | // Create the blobs directory 412 | blobsDir := filepath.Join(storePath, "blobs", "sha256") 413 | if err := os.MkdirAll(blobsDir, 0755); err != nil { 414 | t.Fatalf("Failed to create blobs directory: %v", err) 415 | } 416 | 417 | // Create an incomplete file directly 418 | incompleteFilePath := filepath.Join(blobsDir, blobHash+".incomplete") 419 | if err := os.WriteFile(incompleteFilePath, modelContent, 0644); err != nil { 420 | t.Fatalf("Failed to create incomplete file: %v", err) 421 | } 422 | 423 | // Verify the incomplete file exists 424 | if _, err := os.Stat(incompleteFilePath); os.IsNotExist(err) { 425 | t.Fatalf("Failed to create test .incomplete file") 426 | } 427 | 428 | // Create a model 429 | mdl, err := gguf.NewModel(modelPath) 430 | if err != nil { 431 | t.Fatalf("Create model failed: %v", err) 432 | } 433 | 434 | // Write the model - this should clean up the incomplete file and create the final file 435 | if err := s.Write(mdl, []string{"incomplete-test:latest"}, nil); err != nil { 436 | t.Fatalf("Write failed: %v", err) 437 | } 438 | 439 | // Verify that no .incomplete files remain after successful write 440 | files, err := os.ReadDir(blobsDir) 441 | if err != nil { 442 | t.Fatalf("Failed to read blobs directory: %v", err) 443 | } 444 | 445 | for _, file := range files { 446 | if strings.HasSuffix(file.Name(), ".incomplete") { 447 | t.Errorf("Found .incomplete file after successful write: %s", file.Name()) 448 | } 449 | } 450 | 451 | // Verify the blob exists with its final name 452 | blobPath := filepath.Join(blobsDir, blobHash) 453 | if _, err := os.Stat(blobPath); os.IsNotExist(err) { 454 | t.Errorf("Blob file doesn't exist at expected path: %s", blobPath) 455 | } 456 | } 457 | 458 | // Helper function to check if a tag is in a slice of tags 459 | func containsTag(tags []string, tag string) bool { 460 | for _, t := range tags { 461 | if t == tag { 462 | return true 463 | } 464 | } 465 | return false 466 | } 467 | 468 | func newTestModel(t *testing.T) types.ModelArtifact { 469 | var mdl types.ModelArtifact 470 | var err error 471 | 472 | mdl, err = gguf.NewModel(filepath.Join("testdata", "dummy.gguf")) 473 | if err != nil { 474 | t.Fatalf("failed to create model from gguf file: %v", err) 475 | } 476 | licenseLayer, err := partial.NewLayer(filepath.Join("testdata", "license.txt"), types.MediaTypeLicense) 477 | if err != nil { 478 | t.Fatalf("failed to create license layer: %v", err) 479 | } 480 | mdl = mutate.AppendLayers(mdl, licenseLayer) 481 | return mdl 482 | } 483 | -------------------------------------------------------------------------------- /internal/store/testdata/dummy.gguf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/docker/model-distribution/b377026db94a863fceaec1a3088f924e80825888/internal/store/testdata/dummy.gguf -------------------------------------------------------------------------------- /internal/store/testdata/license.txt: -------------------------------------------------------------------------------- 1 | FAKE LICENSE 2 | -------------------------------------------------------------------------------- /internal/utils/utils.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "net/http" 7 | "net/url" 8 | "os" 9 | "strings" 10 | ) 11 | 12 | // FormatBytes converts bytes to a human-readable string with appropriate unit 13 | func FormatBytes(bytes 
int) string { 14 | size := float64(bytes) 15 | var unit string 16 | switch { 17 | case size >= 1<<30: 18 | size /= 1 << 30 19 | unit = "GB" 20 | case size >= 1<<20: 21 | size /= 1 << 20 22 | unit = "MB" 23 | case size >= 1<<10: 24 | size /= 1 << 10 25 | unit = "KB" 26 | default: 27 | unit = "bytes" 28 | } 29 | return fmt.Sprintf("%.2f %s", size, unit) 30 | } 31 | 32 | // ShowProgress displays a progress bar for data transfer operations 33 | func ShowProgress(operation string, progressChan chan int64, totalSize int64) { 34 | for bytesComplete := range progressChan { 35 | if totalSize > 0 { 36 | mbComplete := float64(bytesComplete) / (1024 * 1024) 37 | mbTotal := float64(totalSize) / (1024 * 1024) 38 | fmt.Printf("\r%s: %.2f MB / %.2f MB", operation, mbComplete, mbTotal) 39 | } else { 40 | mb := float64(bytesComplete) / (1024 * 1024) 41 | fmt.Printf("\r%s: %.2f MB", operation, mb) 42 | } 43 | } 44 | fmt.Println() // Move to new line after progress 45 | } 46 | 47 | // ReadContent reads content from a local file or URL 48 | func ReadContent(source string) ([]byte, error) { 49 | // Check if the source is a URL 50 | if strings.HasPrefix(source, "http://") || strings.HasPrefix(source, "https://") { 51 | // Parse the URL 52 | _, err := url.Parse(source) 53 | if err != nil { 54 | return nil, fmt.Errorf("invalid URL: %v", err) 55 | } 56 | 57 | // Make HTTP request 58 | resp, err := http.Get(source) 59 | if err != nil { 60 | return nil, fmt.Errorf("failed to download file: %v", err) 61 | } 62 | defer resp.Body.Close() 63 | 64 | if resp.StatusCode != http.StatusOK { 65 | return nil, fmt.Errorf("failed to download file: HTTP status %d", resp.StatusCode) 66 | } 67 | 68 | // Create progress reader 69 | contentLength := resp.ContentLength 70 | progressChan := make(chan int64, 100) 71 | 72 | // Start progress reporting goroutine 73 | go ShowProgress("Downloading", progressChan, contentLength) 74 | 75 | // Create a wrapper reader to track progress 76 | progressReader := &ProgressReader{ 77 | Reader: resp.Body, 78 | ProgressChan: progressChan, 79 | } 80 | 81 | // Read the content 82 | content, err := io.ReadAll(progressReader) 83 | close(progressChan) 84 | return content, err 85 | } 86 | 87 | // If not a URL, treat as local file path 88 | return os.ReadFile(source) 89 | } 90 | 91 | // ProgressReader wraps an io.Reader to track reading progress 92 | type ProgressReader struct { 93 | Reader io.Reader 94 | ProgressChan chan int64 95 | Total int64 96 | } 97 | 98 | func (pr *ProgressReader) Read(p []byte) (int, error) { 99 | n, err := pr.Reader.Read(p) 100 | if n > 0 { 101 | pr.Total += int64(n) 102 | pr.ProgressChan <- pr.Total 103 | } 104 | return n, err 105 | } 106 | -------------------------------------------------------------------------------- /registry/artifact.go: -------------------------------------------------------------------------------- 1 | package registry 2 | 3 | import ( 4 | "github.com/docker/model-distribution/internal/partial" 5 | "github.com/docker/model-distribution/types" 6 | v1 "github.com/google/go-containerregistry/pkg/v1" 7 | ) 8 | 9 | var _ types.ModelArtifact = &artifact{} 10 | 11 | type artifact struct { 12 | v1.Image 13 | } 14 | 15 | func (a *artifact) ID() (string, error) { 16 | return partial.ID(a) 17 | } 18 | 19 | func (a *artifact) Config() (types.Config, error) { 20 | return partial.Config(a) 21 | } 22 | 23 | func (a *artifact) Descriptor() (types.Descriptor, error) { 24 | return partial.Descriptor(a) 25 | } 26 | 
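
As a quick illustration of the `internal/utils` helpers shown above, here is a minimal, hypothetical sketch (not part of the repository) that downloads a file with `utils.ReadContent` and prints its size with `utils.FormatBytes`. It assumes it is compiled from within this module, since `internal` packages cannot be imported by outside code, and the URL below is a placeholder.

```go
package main

import (
	"fmt"
	"log"

	"github.com/docker/model-distribution/internal/utils"
)

func main() {
	// For http(s) sources, ReadContent streams the response body and reports
	// progress on stdout via ShowProgress; otherwise it falls back to os.ReadFile.
	data, err := utils.ReadContent("https://example.com/models/dummy.gguf") // placeholder URL
	if err != nil {
		log.Fatalf("read content: %v", err)
	}

	// FormatBytes picks a human-readable unit (bytes, KB, MB, or GB).
	fmt.Printf("downloaded %s\n", utils.FormatBytes(len(data)))
}
```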
-------------------------------------------------------------------------------- /registry/client.go: -------------------------------------------------------------------------------- 1 | package registry 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | "strings" 9 | 10 | "github.com/docker/model-distribution/internal/progress" 11 | "github.com/docker/model-distribution/types" 12 | "github.com/google/go-containerregistry/pkg/authn" 13 | "github.com/google/go-containerregistry/pkg/name" 14 | "github.com/google/go-containerregistry/pkg/v1/remote" 15 | ) 16 | 17 | const ( 18 | DefaultUserAgent = "model-distribution" 19 | ) 20 | 21 | var ( 22 | DefaultTransport = remote.DefaultTransport 23 | ) 24 | 25 | type Client struct { 26 | transport http.RoundTripper 27 | userAgent string 28 | keychain authn.Keychain 29 | } 30 | 31 | type ClientOption func(*Client) 32 | 33 | func WithTransport(transport http.RoundTripper) ClientOption { 34 | return func(c *Client) { 35 | if transport != nil { 36 | c.transport = transport 37 | } 38 | } 39 | } 40 | 41 | func WithUserAgent(userAgent string) ClientOption { 42 | return func(c *Client) { 43 | if userAgent != "" { 44 | c.userAgent = userAgent 45 | } 46 | } 47 | } 48 | 49 | func NewClient(opts ...ClientOption) *Client { 50 | client := &Client{ 51 | transport: remote.DefaultTransport, 52 | userAgent: DefaultUserAgent, 53 | keychain: authn.DefaultKeychain, 54 | } 55 | for _, opt := range opts { 56 | opt(client) 57 | } 58 | return client 59 | } 60 | 61 | func (c *Client) Model(ctx context.Context, reference string) (types.ModelArtifact, error) { 62 | // Parse the reference 63 | ref, err := name.ParseReference(reference) 64 | if err != nil { 65 | return nil, NewReferenceError(reference, err) 66 | } 67 | 68 | // Return the artifact at the given reference 69 | remoteImg, err := remote.Image(ref, 70 | remote.WithContext(ctx), 71 | remote.WithAuthFromKeychain(c.keychain), 72 | remote.WithTransport(c.transport), 73 | remote.WithUserAgent(c.userAgent), 74 | ) 75 | if err != nil { 76 | errStr := err.Error() 77 | if strings.Contains(errStr, "UNAUTHORIZED") { 78 | return nil, NewRegistryError(reference, "UNAUTHORIZED", "Authentication required for this model", err) 79 | } 80 | if strings.Contains(errStr, "MANIFEST_UNKNOWN") { 81 | return nil, NewRegistryError(reference, "MANIFEST_UNKNOWN", "Model not found", err) 82 | } 83 | if strings.Contains(errStr, "NAME_UNKNOWN") { 84 | return nil, NewRegistryError(reference, "NAME_UNKNOWN", "Repository not found", err) 85 | } 86 | return nil, NewRegistryError(reference, "UNKNOWN", err.Error(), err) 87 | } 88 | return &artifact{remoteImg}, nil 89 | } 90 | 91 | type Target struct { 92 | reference name.Reference 93 | transport http.RoundTripper 94 | userAgent string 95 | keychain authn.Keychain 96 | } 97 | 98 | func (c *Client) NewTarget(tag string) (*Target, error) { 99 | ref, err := name.NewTag(tag) 100 | if err != nil { 101 | return nil, fmt.Errorf("invalid tag: %q: %w", tag, err) 102 | } 103 | return &Target{ 104 | reference: ref, 105 | transport: c.transport, 106 | userAgent: c.userAgent, 107 | keychain: c.keychain, 108 | }, nil 109 | } 110 | 111 | func (t *Target) Write(ctx context.Context, model types.ModelArtifact, progressWriter io.Writer) error { 112 | pr := progress.NewProgressReporter(progressWriter, progress.PushMsg, nil) 113 | defer pr.Wait() 114 | 115 | if err := remote.Write(t.reference, model, 116 | remote.WithContext(ctx), 117 | remote.WithAuthFromKeychain(t.keychain), 118 | 
remote.WithTransport(t.transport), 119 | remote.WithUserAgent(t.userAgent), 120 | remote.WithProgress(pr.Updates()), 121 | ); err != nil { 122 | return fmt.Errorf("write to registry %q: %w", t.reference.String(), err) 123 | } 124 | return nil 125 | } 126 | -------------------------------------------------------------------------------- /registry/errors.go: -------------------------------------------------------------------------------- 1 | package registry 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | 7 | "github.com/docker/model-distribution/types" 8 | ) 9 | 10 | var ( 11 | ErrInvalidReference = errors.New("invalid model reference") 12 | ErrModelNotFound = errors.New("model not found") 13 | ErrUnauthorized = errors.New("unauthorized access to model") 14 | ErrUnsupportedMediaType = errors.New(fmt.Sprintf( 15 | "client supports only models of type %q and older - try upgrading", 16 | types.MediaTypeModelConfigV01, 17 | )) 18 | ) 19 | 20 | // ReferenceError represents an error related to an invalid model reference 21 | type ReferenceError struct { 22 | Reference string 23 | Err error 24 | } 25 | 26 | func (e *ReferenceError) Error() string { 27 | return fmt.Sprintf("invalid model reference %q: %v", e.Reference, e.Err) 28 | } 29 | 30 | func (e *ReferenceError) Unwrap() error { 31 | return e.Err 32 | } 33 | 34 | // Is implements error matching for ReferenceError 35 | func (e *ReferenceError) Is(target error) bool { 36 | return target == ErrInvalidReference 37 | } 38 | 39 | // Error represents an error returned by an OCI registry 40 | type Error struct { 41 | Reference string 42 | // Code should be one of error codes defined in the distribution spec 43 | // (see https://github.com/opencontainers/distribution-spec/blob/583e014d15418d839d67f68152bc2c83821770e0/spec.md#error-codes) 44 | Code string 45 | Message string 46 | Err error 47 | } 48 | 49 | func (e Error) Error() string { 50 | return fmt.Sprintf("failed to pull model %q: %s - %s", e.Reference, e.Code, e.Message) 51 | } 52 | 53 | func (e Error) Unwrap() error { 54 | return e.Err 55 | } 56 | 57 | // Is implements error matching for Error 58 | func (e Error) Is(target error) bool { 59 | switch target { 60 | case ErrModelNotFound: 61 | return e.Code == "MANIFEST_UNKNOWN" || e.Code == "NAME_UNKNOWN" 62 | case ErrUnauthorized: 63 | return e.Code == "UNAUTHORIZED" 64 | default: 65 | return false 66 | } 67 | } 68 | 69 | // NewReferenceError creates a new ReferenceError 70 | func NewReferenceError(reference string, err error) error { 71 | return &ReferenceError{ 72 | Reference: reference, 73 | Err: err, 74 | } 75 | } 76 | 77 | // NewRegistryError creates a new Error 78 | func NewRegistryError(reference, code, message string, err error) error { 79 | return &Error{ 80 | Reference: reference, 81 | Code: code, 82 | Message: message, 83 | Err: err, 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /scripts/model-push/README.md: -------------------------------------------------------------------------------- 1 | # Model Push Script 2 | 3 | This script automates the process of converting models from Hugging Face and pushing them to a container registry using the model-distribution-tool. 
4 | 5 | ## Prerequisites 6 | 7 | - Docker installed and running 8 | - Hugging Face account and API token 9 | - model-distribution-tool built (run `make build` in the project root) 10 | 11 | ## Usage 12 | 13 | ```bash 14 | ./push-model.sh [OPTIONS] 15 | ``` 16 | 17 | ### Options 18 | 19 | - `--hf-model HF_NAME/HF_REPO`: Hugging Face model name/repository (required) 20 | - `--repository USER/REPOSITORY`: Target repository (required) 21 | - `--weights WEIGHTS`: Model weights tag (required) 22 | - `--licenses PATH[,PATH,...]`: Comma-separated paths to license files (required) 23 | - `--models-dir PATH`: Path to store models (default: ./models) 24 | - `--hf-token TOKEN`: Hugging Face token (required) 25 | - `--quantization TYPE`: Quantization type to use (default: Q4_K_M) 26 | - `--skip-f16`: Skip pushing the F16 (non-quantized) version 27 | - `--help`: Display help message 28 | 29 | ### Quantization Types 30 | 31 | The following quantization types are supported: 32 | 33 | - `Q4_0`, `Q4_1`: 4-bit quantization (different methods) 34 | - `Q5_0`, `Q5_1`: 5-bit quantization (different methods) 35 | - `Q8_0`, `Q8_1`: 8-bit quantization (different methods) 36 | - `Q2_K`, `Q3_K_S`, `Q3_K_M`, `Q3_K_L`: K-quant with 2-3 bits 37 | - `Q4_K_S`, `Q4_K_M`: K-quant with 4 bits (small and medium, Q4_K_M is default) 38 | - `Q5_K_S`, `Q5_K_M`: K-quant with 5 bits (small and medium) 39 | - `Q6_K`: K-quant with 6 bits 40 | - `F16`: 16-bit floating point (no quantization) 41 | - `F32`: 32-bit floating point (no quantization) 42 | 43 | ### Examples 44 | 45 | Basic usage with default quantization (Q4_K_M): 46 | ```bash 47 | ./push-model.sh \ 48 | --hf-model meta-llama/Llama-2-7b-chat-hf \ 49 | --repository myregistry.com/models/llama \ 50 | --weights 7B \ 51 | --hf-token hf_xxx \ 52 | --licenses ./assets/license.txt 53 | ``` 54 | 55 | Using a specific quantization type: 56 | ```bash 57 | ./push-model.sh \ 58 | --hf-model meta-llama/Llama-2-7b-chat-hf \ 59 | --repository myregistry.com/models/llama \ 60 | --weights 7B \ 61 | --hf-token hf_xxx \ 62 | --quantization Q8_0 \ 63 | --licenses ./assets/license.txt 64 | ``` 65 | 66 | Skip pushing the F16 version: 67 | ```bash 68 | ./push-model.sh \ 69 | --hf-model meta-llama/Llama-2-7b-chat-hf \ 70 | --repository myregistry.com/models/llama \ 71 | --weights 7B \ 72 | --hf-token hf_xxx \ 73 | --skip-f16 \ 74 | --licenses ./assets/license.txt 75 | ``` 76 | 77 | Push only the F16 version (no quantization): 78 | ```bash 79 | ./push-model.sh \ 80 | --hf-model meta-llama/Llama-2-7b-chat-hf \ 81 | --repository myregistry.com/models/llama \ 82 | --weights 7B \ 83 | --hf-token hf_xxx \ 84 | --quantization F16 \ 85 | --licenses ./assets/license.txt 86 | ``` 87 | 88 | Using multiple license files: 89 | ```bash 90 | ./push-model.sh \ 91 | --hf-model meta-llama/Llama-2-7b-chat-hf \ 92 | --repository myregistry.com/models/llama \ 93 | --weights 7B \ 94 | --hf-token hf_xxx \ 95 | --licenses ./assets/license1.txt,./assets/license2.txt,./assets/license3.txt 96 | ``` 97 | 98 | ## Process 99 | 100 | The script performs the following steps: 101 | 102 | 1. Runs a Docker container to convert the model from Hugging Face to GGUF format with the specified quantization 103 | 2. Verifies both the quantized model and F16 model files were created successfully 104 | 3. Checks for the license files 105 | 4. Pushes the quantized model to the specified repository 106 | 5. 
Pushes the F16 model to the same repository with a "-F16" suffix in the tag (unless skipped) 107 | 108 | ## Notes 109 | 110 | - The script creates the models directory if it doesn't exist 111 | - By default, it pushes both the quantized version and the F16 version of the model 112 | - The F16 version is pushed with a "-F16" suffix added to the tag 113 | - If any license file is not found, the script will display an error and exit 114 | - You can specify multiple license files by separating them with commas: `--licenses file1.txt,file2.txt` 115 | - You can skip pushing the F16 version with the `--skip-f16` flag 116 | - If you specify `--quantization F16`, only the F16 version will be pushed 117 | - The script will exit with an error if any critical step fails (Docker not installed, model conversion fails, etc.) 118 | -------------------------------------------------------------------------------- /scripts/model-push/llama-converter/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ghcr.io/ggml-org/llama.cpp:full-b5227 2 | 3 | # Install git-lfs to handle Hugging Face repositories 4 | RUN apt-get update && apt-get install -y git-lfs && \ 5 | git lfs install 6 | 7 | # Copy the modified entrypoint script 8 | COPY entrypoint.sh /entrypoint.sh 9 | RUN chmod +x /entrypoint.sh 10 | 11 | # Allow passing Hugging Face API Token as an environment variable 12 | ENV HUGGINGFACE_TOKEN="" 13 | 14 | ENTRYPOINT ["/entrypoint.sh"] 15 | -------------------------------------------------------------------------------- /scripts/model-push/llama-converter/entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | # Default quantization type 5 | QUANTIZATION="Q4_K_M" 6 | 7 | # Parse arguments 8 | while [[ $# -gt 0 ]]; do 9 | case "$1" in 10 | '--from-hf'|'-f') 11 | FROM_HF=true 12 | shift 13 | HF_REPO="$1" 14 | shift 15 | ;; 16 | '--quantization'|'-q') 17 | shift 18 | QUANTIZATION="$1" 19 | shift 20 | ;; 21 | *) 22 | # Pass other arguments to the original entrypoint 23 | EXTRA_ARGS+=("$1") 24 | shift 25 | ;; 26 | esac 27 | done 28 | 29 | # Validate quantization type 30 | VALID_QUANTIZATIONS=("Q4_0" "Q4_1" "Q5_0" "Q5_1" "Q8_0" "Q8_1" "Q2_K" "Q3_K_S" "Q3_K_M" "Q3_K_L" "Q4_K_S" "Q4_K_M" "Q5_K_S" "Q5_K_M" "Q6_K" "F16" "F32") 31 | VALID=false 32 | for q in "${VALID_QUANTIZATIONS[@]}"; do 33 | if [[ "$q" == "$QUANTIZATION" ]]; then 34 | VALID=true 35 | break 36 | fi 37 | done 38 | 39 | if [[ "$VALID" == "false" ]]; then 40 | echo "Error: Invalid quantization type: $QUANTIZATION" 41 | echo "Valid options are: ${VALID_QUANTIZATIONS[*]}" 42 | exit 1 43 | fi 44 | 45 | if [[ "$FROM_HF" == "true" ]]; then 46 | TARGET_DIR="/models/$(basename $HF_REPO)" 47 | 48 | if [[ -z "$HUGGINGFACE_TOKEN" ]]; then 49 | echo "Error: Hugging Face token is missing. Set HUGGINGFACE_TOKEN environment variable." 50 | exit 1 51 | fi 52 | 53 | if [[ -d "$TARGET_DIR" ]]; then 54 | echo "Repository already cloned at $TARGET_DIR. Skipping cloning." 55 | else 56 | echo "Cloning Hugging Face repository: $HF_REPO into $TARGET_DIR..." 57 | git lfs install 58 | git clone --depth=1 "https://user:$HUGGINGFACE_TOKEN@huggingface.co/$HF_REPO" "$TARGET_DIR" 59 | fi 60 | 61 | echo "Running conversion..." 
62 | python3 ./convert_hf_to_gguf.py "$TARGET_DIR" 63 | 64 | # Find the correct *-F16.gguf file 65 | GGUF_FILE=$(find "$TARGET_DIR" -type f -name "*-F16.gguf" | head -n 1) 66 | 67 | if [[ -z "$GGUF_FILE" ]]; then 68 | echo "Error: No F16 GGUF file found in $TARGET_DIR." 69 | exit 1 70 | fi 71 | 72 | # Skip quantization if F16 is requested 73 | if [[ "$QUANTIZATION" == "F16" ]]; then 74 | echo "F16 format requested, skipping quantization..." 75 | else 76 | echo "Converting to $QUANTIZATION quantization..." 77 | ./llama-quantize "$GGUF_FILE" "$QUANTIZATION" 78 | fi 79 | else 80 | # If not processing from Hugging Face, pass all arguments to the original entrypoint 81 | exec ./entrypoint.sh "${EXTRA_ARGS[@]}" 82 | fi 83 | -------------------------------------------------------------------------------- /scripts/model-push/push-model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | # Default values 5 | DEFAULT_LICENSE_PATH="$(pwd)/assets/license.txt" 6 | DEFAULT_MODELS_DIR="$(pwd)/models" 7 | DEFAULT_QUANTIZATION="Q4_K_M" 8 | SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" 9 | PROJECT_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)" 10 | 11 | # Function to display usage information 12 | usage() { 13 | echo "Usage: $0 [OPTIONS]" 14 | echo 15 | echo "Options:" 16 | echo " --hf-model HF_NAME/HF_REPO Hugging Face model name/repository (required)" 17 | echo " --repository USER/REPOSITORY Target repository (required)" 18 | echo " --weights MODEL_WEIGHTS Model weights tag (required)" 19 | echo " --licenses PATH[,PATH,...] Paths to license files (comma-separated, required)" 20 | echo " --models-dir PATH Path to store models (default: ${DEFAULT_MODELS_DIR})" 21 | echo " --hf-token TOKEN Hugging Face token (required)" 22 | echo " --quantization TYPE Quantization type to use (default: ${DEFAULT_QUANTIZATION})" 23 | echo " --skip-f16 Skip pushing the F16 (non-quantized) version" 24 | echo " --help Display this help message" 25 | echo 26 | echo "Available quantization types:" 27 | echo " Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, Q8_1, Q2_K, Q3_K_S, Q3_K_M, Q3_K_L," 28 | echo " Q4_K_S, Q4_K_M (default), Q5_K_S, Q5_K_M, Q6_K, F16, F32" 29 | echo 30 | echo "Examples:" 31 | echo " $0 --hf-model meta-llama/Llama-2-7b-chat-hf --repository myregistry.com/models/llama --weights 7B --licenses ./assets/license.txt --hf-token hf_xxx" 32 | echo " $0 --hf-model meta-llama/Llama-2-7b-chat-hf --repository myregistry.com/models/llama --weights 7B --licenses ./assets/license1.txt,./assets/license2.txt --hf-token hf_xxx" 33 | exit 1 34 | } 35 | 36 | # Parse command line arguments 37 | while [[ $# -gt 0 ]]; do 38 | case "$1" in 39 | --hf-model) 40 | HF_MODEL="$2" 41 | shift 2 42 | ;; 43 | --repository) 44 | REPOSITORY="$2" 45 | shift 2 46 | ;; 47 | --weights) 48 | WEIGHTS="$2" 49 | shift 2 50 | ;; 51 | --licenses) 52 | LICENSE_PATHS="$2" 53 | shift 2 54 | ;; 55 | --models-dir) 56 | MODELS_DIR="$2" 57 | shift 2 58 | ;; 59 | --hf-token) 60 | HF_TOKEN="$2" 61 | shift 2 62 | ;; 63 | --quantization) 64 | QUANTIZATION="$2" 65 | shift 2 66 | ;; 67 | --skip-f16) 68 | SKIP_F16=true 69 | shift 70 | ;; 71 | --help) 72 | usage 73 | ;; 74 | *) 75 | echo "Unknown option: $1" 76 | usage 77 | ;; 78 | esac 79 | done 80 | 81 | # Validate required parameters 82 | if [ -z "$HF_MODEL" ]; then 83 | echo "Error: Hugging Face model (--hf-model) is required" 84 | usage 85 | fi 86 | 87 | if [ -z "$REPOSITORY" ]; then 88 | echo "Error: Repository (--repository) is required" 89 | usage 90 | 
fi 91 | 92 | if [ -z "$WEIGHTS" ]; then 93 | echo "Error: Weights tag (--weights) is required" 94 | usage 95 | fi 96 | 97 | if [ -z "$LICENSE_PATHS" ]; then 98 | echo "Error: License paths (--licenses) are required" 99 | usage 100 | fi 101 | 102 | if [ -z "$HF_TOKEN" ]; then 103 | echo "Error: Hugging Face token (--hf-token) is required" 104 | usage 105 | fi 106 | 107 | # Set default values if not provided 108 | if [ -z "$LICENSE_PATHS" ]; then 109 | LICENSE_PATHS="$DEFAULT_LICENSE_PATH" 110 | fi 111 | MODELS_DIR="${MODELS_DIR:-$DEFAULT_MODELS_DIR}" 112 | QUANTIZATION="${QUANTIZATION:-$DEFAULT_QUANTIZATION}" 113 | SKIP_F16="${SKIP_F16:-false}" 114 | 115 | # Create models directory if it doesn't exist 116 | mkdir -p "$MODELS_DIR" 117 | 118 | # Check if Docker is installed 119 | if ! command -v docker &> /dev/null; then 120 | echo "Error: Docker is not installed or not in PATH" 121 | exit 1 122 | fi 123 | 124 | # Check if model-distribution-tool exists 125 | if [ ! -f "${PROJECT_ROOT}/bin/model-distribution-tool" ]; then 126 | echo "Error: model-distribution-tool not found at ${PROJECT_ROOT}/bin/model-distribution-tool" 127 | echo "Please build the tool first with 'make build'" 128 | exit 1 129 | fi 130 | 131 | # Construct the full target reference 132 | TARGET="${REPOSITORY}:${WEIGHTS}-${QUANTIZATION}" 133 | 134 | echo "=== Model Push Script ===" 135 | echo "Hugging Face Model: $HF_MODEL" 136 | echo "Repository: $REPOSITORY" 137 | echo "Weights: $WEIGHTS" 138 | echo "License Paths: $LICENSE_PATHS" 139 | echo "Models Directory: $MODELS_DIR" 140 | echo "Quantization: $QUANTIZATION" 141 | echo "Skip F16 Version: $SKIP_F16" 142 | echo "Full Target: $TARGET" 143 | echo 144 | 145 | # Step 1: Run Docker container to convert the model from Hugging Face 146 | echo "Step 1: Converting model from Hugging Face..." 147 | docker build -t docker/llama-converter:latest llama-converter 148 | docker run --rm \ 149 | -e HUGGINGFACE_TOKEN="$HF_TOKEN" \ 150 | -v "$MODELS_DIR:/models" \ 151 | docker/llama-converter:latest \ 152 | --from-hf "$HF_MODEL" --quantization "$QUANTIZATION" 153 | 154 | # Get the model name from the HF_MODEL 155 | MODEL_NAME="$(echo "$HF_MODEL" | sed 's/.*\///')" 156 | MODEL_DIR="$MODELS_DIR/$MODEL_NAME" 157 | 158 | # Define paths for both model versions 159 | if [[ "$QUANTIZATION" == "F16" ]]; then 160 | # If F16 is requested, there's only one model file 161 | QUANTIZED_MODEL_FILE="$MODEL_DIR"/"$MODEL_NAME"-F16.gguf 162 | F16_MODEL_FILE="$QUANTIZED_MODEL_FILE" 163 | else 164 | # For other quantization types, we have both quantized and F16 versions 165 | QUANTIZED_MODEL_FILE="$MODEL_DIR/ggml-model-$QUANTIZATION.gguf" 166 | F16_MODEL_FILE="$MODEL_DIR"/"$MODEL_NAME"-F16.gguf 167 | fi 168 | 169 | # Check if the quantized model file exists 170 | if [ ! -f "$QUANTIZED_MODEL_FILE" ]; then 171 | echo "Error: Quantized model file not found at $QUANTIZED_MODEL_FILE" 172 | exit 1 173 | fi 174 | 175 | echo "Quantized model file: $QUANTIZED_MODEL_FILE" 176 | 177 | # Check if the F16 model file exists (if we're not skipping it) 178 | if [ "$SKIP_F16" != "true" ] && [ "$QUANTIZATION" != "F16" ]; then 179 | if [ ! -f "$F16_MODEL_FILE" ]; then 180 | echo "Warning: F16 model file not found. Skipping F16 model push." 181 | SKIP_F16=true 182 | else 183 | echo "F16 model file: $F16_MODEL_FILE" 184 | fi 185 | fi 186 | 187 | # Step 2: Check for license files 188 | echo "Step 2: Checking for license files..." 
189 | LICENSE_FLAGS="" 190 | IFS=',' read -ra LICENSE_FILES <<< "$LICENSE_PATHS" 191 | for LICENSE_FILE in "${LICENSE_FILES[@]}"; do 192 | if [ ! -f "$LICENSE_FILE" ]; then 193 | echo "Error: License file not found at $LICENSE_FILE" 194 | exit 1 195 | else 196 | echo "License file found: $LICENSE_FILE" 197 | LICENSE_FLAGS="$LICENSE_FLAGS --licenses $LICENSE_FILE" 198 | fi 199 | done 200 | 201 | if [ -z "$LICENSE_FLAGS" ]; then 202 | echo "Error: No valid license files provided" 203 | exit 1 204 | fi 205 | 206 | # Step 3: Push the model(s) to the repository 207 | echo "Step 3: Pushing model(s) to the repository..." 208 | 209 | echo "Pushing quantized model ($QUANTIZATION) to $TARGET..." 210 | "${PROJECT_ROOT}/bin/model-distribution-tool" package $LICENSE_FLAGS "$QUANTIZED_MODEL_FILE" "$TARGET" 211 | 212 | # Push the F16 model if not skipped and not already pushed (when QUANTIZATION=F16) 213 | if [ "$SKIP_F16" != "true" ] && [ "$QUANTIZATION" != "F16" ]; then 214 | # Create F16 tag by appending "-F16" to the weights 215 | F16_TARGET="${REPOSITORY}:${WEIGHTS}-F16" 216 | echo "Pushing F16 model to $F16_TARGET..." 217 | "${PROJECT_ROOT}/bin/model-distribution-tool" package $LICENSE_FLAGS "$F16_MODEL_FILE" "$F16_TARGET" 218 | echo "F16 model successfully pushed to $F16_TARGET" 219 | fi 220 | 221 | echo "=== Model successfully pushed ===" 222 | echo "Hugging Face Model: $HF_MODEL" 223 | echo "Repository: $REPOSITORY" 224 | echo "Weights: $WEIGHTS" 225 | echo "License Paths: $LICENSE_PATHS" 226 | echo "Models Directory: $MODELS_DIR" 227 | echo "Quantization: $QUANTIZATION" 228 | echo "Skip F16 Version: $SKIP_F16" 229 | echo "Full Target: $TARGET" 230 | echo 231 | -------------------------------------------------------------------------------- /types/config.go: -------------------------------------------------------------------------------- 1 | package types 2 | 3 | import ( 4 | "strings" 5 | "time" 6 | 7 | v1 "github.com/google/go-containerregistry/pkg/v1" 8 | "github.com/google/go-containerregistry/pkg/v1/types" 9 | ) 10 | 11 | const ( 12 | // modelConfigPrefix is the prefix for all versioned model config media types. 13 | modelConfigPrefix = "application/vnd.docker.ai.model.config" 14 | 15 | // MediaTypeModelConfigV01 is the media type for the model config json. 16 | MediaTypeModelConfigV01 = types.MediaType("application/vnd.docker.ai.model.config.v0.1+json") 17 | 18 | // MediaTypeGGUF indicates a file in GGUF version 3 format, containing a tensor model. 19 | MediaTypeGGUF = types.MediaType("application/vnd.docker.ai.gguf.v3") 20 | 21 | // MediaTypeLicense indicates a plain text file containing a license 22 | MediaTypeLicense = types.MediaType("application/vnd.docker.ai.license") 23 | 24 | FormatGGUF = Format("gguf") 25 | ) 26 | 27 | func IsModelConfig(mt types.MediaType) bool { 28 | return strings.HasPrefix(string(mt), string(MediaTypeModelConfigV01)) 29 | } 30 | 31 | type Format string 32 | 33 | type ConfigFile struct { 34 | Config Config `json:"config"` 35 | Descriptor Descriptor `json:"descriptor"` 36 | RootFS v1.RootFS `json:"rootfs"` 37 | } 38 | 39 | // Config describes the model. 40 | type Config struct { 41 | Format Format `json:"format,omitempty"` 42 | Quantization string `json:"quantization,omitempty"` 43 | Parameters string `json:"parameters,omitempty"` 44 | Architecture string `json:"architecture,omitempty"` 45 | Size string `json:"size,omitempty"` 46 | GGUF map[string]string `json:"gguf,omitempty"` 47 | } 48 | 49 | // Descriptor provides metadata about the provenance of the model. 
50 | type Descriptor struct { 51 | Created *time.Time `json:"created,omitempty"` 52 | } 53 | -------------------------------------------------------------------------------- /types/model.go: -------------------------------------------------------------------------------- 1 | package types 2 | 3 | import ( 4 | v1 "github.com/google/go-containerregistry/pkg/v1" 5 | ) 6 | 7 | type Model interface { 8 | ID() (string, error) 9 | GGUFPath() (string, error) 10 | Config() (Config, error) 11 | Tags() []string 12 | Descriptor() (Descriptor, error) 13 | } 14 | 15 | type ModelArtifact interface { 16 | ID() (string, error) 17 | Config() (Config, error) 18 | Descriptor() (Descriptor, error) 19 | v1.Image 20 | } 21 | --------------------------------------------------------------------------------
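
To show how the pieces above fit together, here is a minimal, hypothetical end-to-end sketch (not part of the codebase): it pulls a model with the `registry` client, writes it into a `LocalStore`, and reads it back through the model accessors. It assumes it is built from within this module, because `internal/store` is not importable by external packages; the registry reference, tag, and store path below are placeholders.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"log"
	"os"

	"github.com/docker/model-distribution/internal/store"
	"github.com/docker/model-distribution/registry"
)

func main() {
	ctx := context.Background()

	// Pull the model artifact; layers are fetched lazily by go-containerregistry.
	client := registry.NewClient(registry.WithUserAgent("model-distribution-example"))
	mdl, err := client.Model(ctx, "myregistry.example.com/models/llama:7B-Q4_K_M") // placeholder reference
	if errors.Is(err, registry.ErrModelNotFound) {
		log.Fatal("model does not exist in the registry")
	} else if err != nil {
		log.Fatalf("pull: %v", err)
	}

	// Write it into a local content-addressed store, reporting progress to stdout.
	s, err := store.New(store.Options{RootPath: "/tmp/model-store"}) // placeholder path
	if err != nil {
		log.Fatalf("open store: %v", err)
	}
	if err := s.Write(mdl, []string{"llama:7B-Q4_K_M"}, os.Stdout); err != nil {
		log.Fatalf("store model: %v", err)
	}

	// Read it back by tag and inspect it through the Model accessors.
	local, err := s.Read("llama:7B-Q4_K_M")
	if err != nil {
		log.Fatalf("read model: %v", err)
	}
	ggufPath, err := local.GGUFPath()
	if err != nil {
		log.Fatalf("gguf path: %v", err)
	}
	cfg, err := local.Config()
	if err != nil {
		log.Fatalf("config: %v", err)
	}
	fmt.Printf("GGUF at %s (format=%s, quantization=%s)\n", ggufPath, cfg.Format, cfg.Quantization)
}
```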