├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.yml │ ├── config.yml │ └── feature_request.yml └── workflows │ ├── go.yml │ └── user-input.yml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── devtools ├── generate_discovery_client.sh └── pre-push-hook.sh ├── genai ├── caching.go ├── caching_test.go ├── chat.go ├── client.go ├── client_test.go ├── config.yaml ├── content.go ├── debug.go ├── doc.go ├── embed.go ├── example_test.go ├── files.go ├── files_test.go ├── generate.sh ├── generativelanguagepb_veneer.gen.go ├── internal │ ├── cmd │ │ └── gen-examples │ │ │ └── gen-examples.go │ ├── generativelanguage │ │ └── v1beta │ │ │ ├── generativelanguage-api.json │ │ │ └── generativelanguage-gen.go │ ├── gensupport │ │ ├── README │ │ ├── buffer.go │ │ ├── buffer_test.go │ │ ├── doc.go │ │ ├── error.go │ │ ├── error_test.go │ │ ├── json.go │ │ ├── json_test.go │ │ ├── jsonfloat.go │ │ ├── jsonfloat_test.go │ │ ├── media.go │ │ ├── media_test.go │ │ ├── params.go │ │ ├── params_test.go │ │ ├── resumable.go │ │ ├── resumable_test.go │ │ ├── retry.go │ │ ├── retryable_linux.go │ │ ├── send.go │ │ ├── send_test.go │ │ ├── util_test.go │ │ ├── version.go │ │ └── version_test.go │ ├── samples │ │ └── docs-snippets_test.go │ ├── testhelpers │ │ └── testhelpers.go │ └── version.go ├── license.txt ├── list_models.go ├── option.go └── testdata │ ├── 1251.txt │ ├── Cajun_instruments.jpg │ ├── a11.txt │ ├── badencoding.txt │ ├── earth.mp4 │ ├── organ.jpg │ ├── personWorkingOnComputer.jpg │ ├── poem.txt │ ├── sample.mp3 │ └── test.pdf ├── go.mod ├── go.sum └── license_test.go /.github/ISSUE_TEMPLATE/bug_report.yml: -------------------------------------------------------------------------------- 1 | name: Bug report 2 | description: Use this template to report bugs 3 | labels: ["type:bug"] 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: > 8 | **Note:** If this is a support question (e.g. 
_How do I do XYZ?_), please visit the [official discussion forum](https://discuss.ai.google.dev/). This is a great place to interact with developers, and to learn, share, and support each other. 9 | - type: textarea 10 | id: description 11 | attributes: 12 | label: > 13 | Description of the bug: 14 | - type: textarea 15 | id: behavior 16 | attributes: 17 | label: > 18 | Actual vs expected behavior: 19 | - type: textarea 20 | id: info 21 | attributes: 22 | label: > 23 | Any other information you'd like to share? 24 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: Support - Official Gemini API discussion forum 4 | url: https://discuss.ai.google.dev/ 5 | about: > 6 | For non-SDK issues like service downtime issues, account- or project-specific issues, or 7 | questions (e.g. _How do I do XYZ?_) visit the official discussion forum. This is a great place 8 | to interact with developers, and to learn, share, and support each other. 9 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | name: Feature request 2 | description: Use this template to suggest a new feature 3 | labels: ["type:feature request"] 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: > 8 | **Note:** If this is a support question (e.g. _How do I do XYZ?_), please visit the [official discussion forum](https://discuss.ai.google.dev/). This is a great place to interact with developers, and to learn, share, and support each other. 
9 | - type: textarea 10 | id: description 11 | attributes: 12 | label: > 13 | Description of the feature request: 14 | - type: textarea 15 | id: behavior 16 | attributes: 17 | label: > 18 | What problem are you trying to solve with this feature? 19 | - type: textarea 20 | id: info 21 | attributes: 22 | label: > 23 | Any other information you'd like to share? 24 | -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | # This workflow will build a golang project 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go 3 | 4 | name: Go 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | branches: [ "main" ] 11 | 12 | jobs: 13 | 14 | build: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v3 18 | 19 | - name: Set up Go 20 | uses: actions/setup-go@v4 21 | with: 22 | go-version: '1.21' 23 | 24 | - name: Build 25 | run: go build -v ./... 26 | 27 | - name: Test 28 | run: go test -v ./... 
29 | -------------------------------------------------------------------------------- /.github/workflows/user-input.yml: -------------------------------------------------------------------------------- 1 | name: Manage awaiting user response 2 | 3 | on: 4 | issue_comment: 5 | types: [created] 6 | pull_request_review_comment: 7 | types: [created] 8 | 9 | jobs: 10 | remove_label: 11 | runs-on: ubuntu-latest 12 | if: "contains(github.event.issue.labels.*.name, 'status: awaiting user response')" 13 | steps: 14 | - uses: actions-ecosystem/action-remove-labels@v1 15 | with: 16 | labels: "status: awaiting user response" 17 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We would love to accept your patches and contributions to this project. 4 | 5 | ## Before you begin 6 | 7 | ### Sign our Contributor License Agreement 8 | 9 | Contributions to this project must be accompanied by a 10 | [Contributor License Agreement](https://cla.developers.google.com/about) (CLA). 11 | You (or your employer) retain the copyright to your contribution; this simply 12 | gives us permission to use and redistribute your contributions as part of the 13 | project. 14 | 15 | If you or your current employer have already signed the Google CLA (even if it 16 | was for a different project), you probably don't need to do it again. 17 | 18 | Visit <https://cla.developers.google.com/> to see your current agreements or to 19 | sign a new one. 20 | 21 | ### Review our Community Guidelines 22 | 23 | This project follows [Google's Open Source Community 24 | Guidelines](https://opensource.google/conduct/). 25 | 26 | ## Contribution process 27 | 28 | 1. Clone this repo 29 | 2. Run tests with `go test ./...`; the "live" tests will be skipped 30 | unless a valid API key is set with the `GEMINI_API_KEY` environment variable.
31 | 32 | ### Code Reviews 33 | 34 | All submissions, including submissions by project members, require review. We 35 | use [GitHub pull requests](https://docs.github.com/articles/about-pull-requests) 36 | for this purpose. 37 | 38 | ## For Maintainers 39 | 40 | ### Preparation 41 | 42 | Install the pre-push hook: 43 | ``` 44 | cp devtools/pre-push-hook.sh .git/hooks/pre-push 45 | ``` 46 | 47 | ### Creating a new release 48 | 49 | This repo consists of a single Go module. 50 | To increase the minor or patch version of the module: 51 | 52 | 1. Run `git pull --tags` to get the up-to-date upstream tags. 53 | 2. Determine the desired tag, using `git tag -l` to see existing tags 54 | and incrementing as appropriate. We will call the result TAG in 55 | these instructions. It should be of the form `vX.Y.Z`. 56 | 3. Update the version in genai/internal/version.go to match TAG (without the leading `v`). 57 | 4. Send a PR with that change. The pre-push hook should complain, so 58 | pass the `--no-verify` flag to `git push`. 59 | 5. Submit the PR when approved. _No other PRs should be submitted until 60 | the following steps have been completed._ 61 | 6. Run `git pull` to get the submitted PR locally. You should be on main. 62 | 7. Run `git tag TAG` to tag the repo locally. 63 | 8. Run `git push origin TAG`. If the pre-push hook complains here, something 64 | is wrong; stop and review. 65 | 9. Use the [GitHub UI](https://github.com/google/generative-ai-go/releases) to 66 | create the release. Use TAG as the name. 67 | Provide release notes by summarizing the result of `git log PREVTAG..`, 68 | where PREVTAG is the previous release tag. 69 | 10. Visit https://pkg.go.dev/github.com/google/generative-ai-go@TAG and request 70 | that the version be processed.
71 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [Deprecated] Google AI Go SDK for the Gemini API 2 | 3 | With Gemini 2.0, we took the chance to create a single unified SDK for all developers who want to use Google's GenAI models (Gemini, Veo, Imagen, etc). As part of that process, we took all of the feedback from this SDK and what developers like about other SDKs in the ecosystem to create the [Google Gen AI SDK](https://github.com/googleapis/go-genai). 4 | 5 | The Gemini API docs are fully updated to show examples of the new Google Gen AI SDK: [Get started](https://ai.google.dev/gemini-api/docs/quickstart?lang=go). 6 | 7 | We know how disruptive an SDK change can be and don't take this change lightly, but our goal is to create an extremely simple and clear path for developers to build with our models so it felt necessary to make this change. 8 | 9 | Thank you for building with Gemini and [let us know](https://discuss.ai.google.dev/c/gemini-api/4) if you need any help! 
10 | 11 | **Please be advised that this repository is now considered legacy.** For the latest features, performance improvements, and active development, we strongly recommend migrating to the official **[Google Generative AI SDK for Go](https://github.com/googleapis/go-genai)**. 12 | 13 | **Support Plan for this Repository:** 14 | 15 | * **Limited Maintenance:** Development is now restricted to **critical bug fixes only**. No new features will be added. 16 | * **Purpose:** This limited support aims to provide stability for users while they transition to the new SDK. 17 | * **End-of-Life Date:** All support for this repository (including bug fixes) will permanently end on **August 31st, 2025**. 18 | 19 | We encourage all users to begin planning their migration to the [Google Generative AI SDK](https://github.com/googleapis/go-genai) to ensure continued access to the latest capabilities and support. 20 | -------------------------------------------------------------------------------- /devtools/generate_discovery_client.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -e 2 | # Copyright 2024 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # This script generates the "discovery" client for the GenerativeLanguage API. 17 | # It is needed for file upload, which GAPIC clients don't support. 18 | 19 | # Run this tool from the `genai` directory of this repository. 
# The repo github.com/googleapis/google-api-go-client (corresponding to the Go import
# path google.golang.org/api) contains a program that generates a Go client from
# a discovery doc. It also contains all the clients generated from public discovery
# docs, but the generativelanguage doc isn't public. In fact, retrieving it requires
# an API key. We also don't want to put the discovery client in that repo, because
# we don't want it to be public either; that would only confuse users.

if [[ $GEMINI_API_KEY = '' ]]; then
  echo >&2 "need to set GEMINI_API_KEY"
  exit 1
fi

# Install the code generator for discovery clients.
go install google.golang.org/api/google-api-go-generator@latest

# Download the discovery document.
# --fail makes curl exit non-zero on an HTTP error status, so `bash -e` aborts
# here instead of handing an error page to the generator; --show-error keeps
# the reason visible despite --silent.
docfile=/tmp/gl.json
curl --silent --show-error --fail \
  'https://generativelanguage.googleapis.com/$discovery/rest?version=v1beta&key='"$GEMINI_API_KEY" > "$docfile"

# Generate the client. Write it to the internal directory so it is not exposed to users.
google-api-go-generator -api_json_file "$docfile" \
  -gendir internal \
  -internal_pkg github.com/google/generative-ai-go/genai/internal \
  -gensupport_pkg github.com/google/generative-ai-go/genai/internal/gensupport

# Replace the generated file's license header with the proper one for this repo:
# drop its first four lines and prepend license.txt.
# `tail -n +5` prints from line 5 onward. The obsolete `tail +5` spelling is not
# accepted by POSIX-conforming tails (current GNU coreutils, for example, treats
# `+5` as a file name). `sponge` absorbs all of its input before writing, which
# lets the pipeline overwrite $file in place.
file=internal/generativelanguage/v1beta/generativelanguage-gen.go
cat license.txt <(tail -n +5 "$file") | sponge "$file"

-------------------------------------------------------------------------------- /devtools/pre-push-hook.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh -e 2 | # Copyright 2024 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # This script performs some checks. 17 | # Install as a pre-push hook from the repo root with: 18 | # cp devtools/pre-push-hook.sh .git/hooks/pre-push 19 | 20 | go test -short ./... 21 | go vet ./... 22 | 23 | # Check that the version in the code matches the latest version tag. 24 | version_file=genai/internal/version.go 25 | latest_tag=$(git tag -l 'v*' | sort -V | tail -1) 26 | code_version=v$(awk '/^const Version/ {print substr($4, 2, length($4)-2)}' $version_file) 27 | 28 | if [[ $latest_tag == $code_version ]]; then 29 | exit 0 30 | fi 31 | 32 | echo "version $code_version in $version_file does not match latest tag $latest_tag." 33 | exit 1 34 | -------------------------------------------------------------------------------- /genai/caching.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package genai 16 | 17 | import ( 18 | "context" 19 | "errors" 20 | "fmt" 21 | "time" 22 | 23 | gl "cloud.google.com/go/ai/generativelanguage/apiv1beta" 24 | pb "cloud.google.com/go/ai/generativelanguage/apiv1beta/generativelanguagepb" 25 | "google.golang.org/api/iterator" 26 | durationpb "google.golang.org/protobuf/types/known/durationpb" 27 | fieldmaskpb "google.golang.org/protobuf/types/known/fieldmaskpb" 28 | timestamppb "google.golang.org/protobuf/types/known/timestamppb" 29 | ) 30 | 31 | type cacheClient = gl.CacheClient 32 | 33 | var ( 34 | newCacheClient = gl.NewCacheClient 35 | newCacheRESTClient = gl.NewCacheRESTClient 36 | ) 37 | 38 | // GenerativeModelFromCachedContent returns a [GenerativeModel] that uses the given [CachedContent]. 39 | // The argument should come from a call to [Client.CreateCachedContent] or [Client.GetCachedContent]. 40 | func (c *Client) GenerativeModelFromCachedContent(cc *CachedContent) *GenerativeModel { 41 | return &GenerativeModel{ 42 | c: c, 43 | fullName: cc.Model, 44 | CachedContentName: cc.Name, 45 | } 46 | } 47 | 48 | // CreateCachedContent creates a new CachedContent. 49 | // The argument should contain a model name and some data to be cached, which can include 50 | // contents, a system instruction, tools and/or tool configuration. It can also 51 | // include an expiration time or TTL. But it should not include a name; the system 52 | // will generate one. 53 | // 54 | // The return value will contain the name, which should be used to refer to the CachedContent 55 | // in other API calls. It will also hold various metadata like expiration and creation time. 56 | // It will not contain any of the actual content provided as input. 57 | // 58 | // You can use the return value to create a model with [Client.GenerativeModelFromCachedContent]. 
59 | // Or you can set [GenerativeModel.CachedContentName] to the name of the CachedContent, in which 60 | // case you must ensure that the model provided in this call matches the name in the [GenerativeModel]. 61 | func (c *Client) CreateCachedContent(ctx context.Context, cc *CachedContent) (*CachedContent, error) { 62 | if cc.Name != "" { 63 | return nil, errors.New("genai.CreateCachedContent: do not provide a name; one will be generated") 64 | } 65 | pcc := cc.toProto() 66 | pcc.Model = Ptr(fullModelName(cc.Model)) 67 | req := &pb.CreateCachedContentRequest{ 68 | CachedContent: pcc, 69 | } 70 | debugPrint(req) 71 | return c.cachedContentFromProto(c.cc.CreateCachedContent(ctx, req)) 72 | } 73 | 74 | // GetCachedContent retrieves the CachedContent with the given name. 75 | func (c *Client) GetCachedContent(ctx context.Context, name string) (*CachedContent, error) { 76 | return c.cachedContentFromProto(c.cc.GetCachedContent(ctx, &pb.GetCachedContentRequest{Name: name})) 77 | } 78 | 79 | // DeleteCachedContent deletes the CachedContent with the given name. 80 | func (c *Client) DeleteCachedContent(ctx context.Context, name string) error { 81 | return c.cc.DeleteCachedContent(ctx, &pb.DeleteCachedContentRequest{Name: name}) 82 | } 83 | 84 | // CachedContentToUpdate specifies which fields of a CachedContent to modify in a call to 85 | // [Client.UpdateCachedContent]. 86 | type CachedContentToUpdate struct { 87 | // If non-nil, update the expire time or TTL. 88 | Expiration *ExpireTimeOrTTL 89 | } 90 | 91 | // UpdateCachedContent modifies the [CachedContent] according to the values 92 | // of the [CachedContentToUpdate] struct. 93 | // It returns the modified CachedContent. 94 | // 95 | // The argument CachedContent must have its Name field populated. 96 | // If its UpdateTime field is non-zero, it will be compared with the update time 97 | // of the stored CachedContent and the call will fail if they differ. 
98 | // This avoids a race condition when two updates are attempted concurrently. 99 | // All other fields of the argument CachedContent are ignored. 100 | func (c *Client) UpdateCachedContent(ctx context.Context, cc *CachedContent, ccu *CachedContentToUpdate) (*CachedContent, error) { 101 | if ccu == nil || ccu.Expiration == nil { 102 | return nil, errors.New("genai.UpdateCachedContent: no update specified") 103 | } 104 | cc2 := &CachedContent{ 105 | Name: cc.Name, 106 | UpdateTime: cc.UpdateTime, 107 | Expiration: *ccu.Expiration, 108 | } 109 | mask := "expire_time" 110 | if ccu.Expiration.ExpireTime.IsZero() { 111 | mask = "ttl" 112 | } 113 | req := &pb.UpdateCachedContentRequest{ 114 | CachedContent: cc2.toProto(), 115 | UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{mask}}, 116 | } 117 | debugPrint(req) 118 | return c.cachedContentFromProto(c.cc.UpdateCachedContent(ctx, req)) 119 | } 120 | 121 | // ListCachedContents lists all the CachedContents associated with the project and location. 122 | func (c *Client) ListCachedContents(ctx context.Context) *CachedContentIterator { 123 | return &CachedContentIterator{ 124 | it: c.cc.ListCachedContents(ctx, &pb.ListCachedContentsRequest{}), 125 | } 126 | } 127 | 128 | // A CachedContentIterator iterates over CachedContents. 129 | type CachedContentIterator struct { 130 | it *gl.CachedContentIterator 131 | } 132 | 133 | // Next returns the next result. Its second return value is iterator.Done if there are no more 134 | // results. Once Next returns Done, all subsequent calls will return Done. 135 | func (it *CachedContentIterator) Next() (*CachedContent, error) { 136 | m, err := it.it.Next() 137 | if err != nil { 138 | return nil, err 139 | } 140 | return (CachedContent{}).fromProto(m), nil 141 | } 142 | 143 | // PageInfo supports pagination. See the google.golang.org/api/iterator package for details. 
144 | func (it *CachedContentIterator) PageInfo() *iterator.PageInfo { 145 | return it.it.PageInfo() 146 | } 147 | 148 | func (c *Client) cachedContentFromProto(pcc *pb.CachedContent, err error) (*CachedContent, error) { 149 | if err != nil { 150 | return nil, err 151 | } 152 | cc := (CachedContent{}).fromProto(pcc) 153 | return cc, nil 154 | } 155 | 156 | // ExpireTimeOrTTL describes the time when a resource expires. 157 | // If ExpireTime is non-zero, it is the expiration time. 158 | // Otherwise, the expiration time is the value of TTL ("time to live") added 159 | // to the current time. 160 | type ExpireTimeOrTTL struct { 161 | ExpireTime time.Time 162 | TTL time.Duration 163 | } 164 | 165 | // populateCachedContentTo populates some fields of p from v. 166 | func populateCachedContentTo(p *pb.CachedContent, v *CachedContent) { 167 | exp := v.Expiration 168 | if !exp.ExpireTime.IsZero() { 169 | p.Expiration = &pb.CachedContent_ExpireTime{ 170 | ExpireTime: timestamppb.New(exp.ExpireTime), 171 | } 172 | } else if exp.TTL != 0 { 173 | p.Expiration = &pb.CachedContent_Ttl{ 174 | Ttl: durationpb.New(exp.TTL), 175 | } 176 | } 177 | // If both fields of v.Expiration are zero, leave p.Expiration unset. 178 | } 179 | 180 | // populateCachedContentFrom populates some fields of v from p. 
181 | func populateCachedContentFrom(v *CachedContent, p *pb.CachedContent) { 182 | if p.Expiration == nil { 183 | return 184 | } 185 | switch e := p.Expiration.(type) { 186 | case *pb.CachedContent_ExpireTime: 187 | v.Expiration.ExpireTime = pvTimeFromProto(e.ExpireTime) 188 | case *pb.CachedContent_Ttl: 189 | v.Expiration.TTL = e.Ttl.AsDuration() 190 | default: 191 | panic(fmt.Sprintf("unknown type of CachedContent.Expiration: %T", p.Expiration)) 192 | } 193 | } 194 | -------------------------------------------------------------------------------- /genai/caching_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package genai 16 | 17 | import ( 18 | "context" 19 | "path/filepath" 20 | "strings" 21 | "testing" 22 | "time" 23 | 24 | pb "cloud.google.com/go/ai/generativelanguage/apiv1beta/generativelanguagepb" 25 | "github.com/google/go-cmp/cmp" 26 | "github.com/google/go-cmp/cmp/cmpopts" 27 | "google.golang.org/api/iterator" 28 | durationpb "google.golang.org/protobuf/types/known/durationpb" 29 | timestamppb "google.golang.org/protobuf/types/known/timestamppb" 30 | ) 31 | 32 | func TestPopulateCachedContent(t *testing.T) { 33 | tm := time.Date(2030, 1, 1, 0, 0, 0, 0, time.UTC) 34 | cmpOpt := cmpopts.IgnoreUnexported( 35 | timestamppb.Timestamp{}, 36 | durationpb.Duration{}, 37 | ) 38 | for _, test := range []struct { 39 | proto *pb.CachedContent 40 | veneer *CachedContent 41 | }{ 42 | {&pb.CachedContent{}, &CachedContent{}}, 43 | { 44 | &pb.CachedContent{Expiration: &pb.CachedContent_ExpireTime{ExpireTime: timestamppb.New(tm)}}, 45 | &CachedContent{Expiration: ExpireTimeOrTTL{ExpireTime: tm}}, 46 | }, 47 | { 48 | &pb.CachedContent{Expiration: &pb.CachedContent_Ttl{Ttl: durationpb.New(time.Hour)}}, 49 | &CachedContent{Expiration: ExpireTimeOrTTL{TTL: time.Hour}}, 50 | }, 51 | } { 52 | var gotp pb.CachedContent 53 | populateCachedContentTo(&gotp, test.veneer) 54 | if g, w := gotp.Expiration, test.proto.Expiration; !cmp.Equal(g, w, cmpOpt) { 55 | t.Errorf("from %v to proto: got %v, want %v", test.veneer.Expiration, g, w) 56 | } 57 | 58 | var gotv CachedContent 59 | populateCachedContentFrom(&gotv, test.proto) 60 | if g, w := gotv.Expiration, test.veneer.Expiration; !cmp.Equal(g, w) { 61 | t.Errorf("from %v to veneer: got %v, want %v", test.proto.Expiration, g, w) 62 | } 63 | } 64 | } 65 | 66 | func testCaching(t *testing.T, client *Client) { 67 | ctx := context.Background() 68 | const model = "gemini-1.5-flash-001" 69 | 70 | file := uploadFile(t, ctx, client, filepath.Join("testdata", "earth.mp4")) 71 | 72 | t.Run("CRUD", func(t *testing.T) { 73 | must := func(cc 
*CachedContent, err error) *CachedContent { 74 | t.Helper() 75 | if err != nil { 76 | t.Fatal(err) 77 | } 78 | return cc 79 | } 80 | 81 | want := &CachedContent{ 82 | Model: "models/" + model, 83 | UsageMetadata: &CachedContentUsageMetadata{TotalTokenCount: 36876}, 84 | } 85 | 86 | compare := func(got *CachedContent, expireTime time.Time) { 87 | t.Helper() 88 | want.Expiration.ExpireTime = expireTime 89 | if got.CreateTime.IsZero() { 90 | t.Error("missing CreateTime") 91 | } 92 | if got.UpdateTime.IsZero() { 93 | t.Error("missing UpdateTime") 94 | } 95 | if diff := cmp.Diff(want, got, 96 | cmpopts.EquateApproxTime(10*time.Second), 97 | cmpopts.IgnoreFields(CachedContent{}, "Name", "CreateTime", "UpdateTime")); diff != "" { 98 | t.Errorf("mismatch (-want, +got):\n%s", diff) 99 | } 100 | } 101 | 102 | ttl := 30 * time.Minute 103 | wantExpireTime := time.Now().Add(ttl) 104 | // Replicate the file content multiple times to reach the minimum token threshold 105 | // for cached content. 106 | fd := FileData{MIMEType: "text/plain", URI: file.URI} 107 | parts := make([]Part, 25) 108 | for i := range parts { 109 | parts[i] = fd 110 | } 111 | argcc := &CachedContent{ 112 | Model: model, 113 | Expiration: ExpireTimeOrTTL{TTL: ttl}, 114 | Contents: []*Content{NewUserContent(parts...)}, 115 | } 116 | cc := must(client.CreateCachedContent(ctx, argcc)) 117 | compare(cc, wantExpireTime) 118 | name := cc.Name 119 | cc2 := must(client.GetCachedContent(ctx, name)) 120 | compare(cc2, wantExpireTime) 121 | gotList := listAll(t, client.ListCachedContents(ctx)) 122 | var cc3 *CachedContent 123 | for _, cc := range gotList { 124 | if cc.Name == name { 125 | cc3 = cc 126 | break 127 | } 128 | } 129 | if cc3 == nil { 130 | t.Fatal("did not find created in list") 131 | } 132 | compare(cc3, wantExpireTime) 133 | 134 | // Update using expire time. 
135 | newExpireTime := cc3.Expiration.ExpireTime.Add(15 * time.Minute) 136 | cc4 := must(client.UpdateCachedContent(ctx, cc3, &CachedContentToUpdate{ 137 | Expiration: &ExpireTimeOrTTL{ExpireTime: newExpireTime}, 138 | })) 139 | compare(cc4, newExpireTime) 140 | 141 | t.Run("update-ttl", func(t *testing.T) { 142 | // Update using TTL. 143 | cc5 := must(client.UpdateCachedContent(ctx, cc4, &CachedContentToUpdate{ 144 | Expiration: &ExpireTimeOrTTL{TTL: ttl}, 145 | })) 146 | compare(cc5, time.Now().Add(ttl)) 147 | }) 148 | 149 | if err := client.DeleteCachedContent(ctx, name); err != nil { 150 | t.Fatal(err) 151 | } 152 | 153 | if err := client.DeleteCachedContent(ctx, "bad name"); err == nil { 154 | t.Fatal("want error, got nil") 155 | } 156 | }) 157 | t.Run("use", func(t *testing.T) { 158 | txt := strings.Repeat("George Washington was the first president of the United States. ", 3000) 159 | argcc := &CachedContent{ 160 | Model: model, 161 | Contents: []*Content{NewUserContent(Text(txt))}, 162 | } 163 | cc, err := client.CreateCachedContent(ctx, argcc) 164 | if err != nil { 165 | t.Fatal(err) 166 | } 167 | defer client.DeleteCachedContent(ctx, cc.Name) 168 | tokenCount := cc.UsageMetadata.TotalTokenCount 169 | m := client.GenerativeModelFromCachedContent(cc) 170 | t.Run("generation", func(t *testing.T) { 171 | res, err := m.GenerateContent(ctx, Text("Who was the first US president?")) 172 | if err != nil { 173 | t.Fatal(err) 174 | } 175 | got := responseString(res) 176 | const want = "Washington" 177 | if !strings.Contains(got, want) { 178 | t.Errorf("got %q, want string containing %q", got, want) 179 | } 180 | if g, w := res.UsageMetadata.CachedContentTokenCount, tokenCount; g != w { 181 | t.Errorf("CachedContentTokenCount: got %d, want %d", g, w) 182 | } 183 | }) 184 | t.Run("count", func(t *testing.T) { 185 | t.Skip("not yet implemented") 186 | gotRes, err := m.CountTokens(ctx, Text("Who Was the first US president?")) 187 | if err != nil { 188 | t.Fatal(err) 189 
| } 190 | wantRes := &CountTokensResponse{ 191 | TotalTokens: 8, 192 | CachedContentTokenCount: tokenCount, 193 | } 194 | if !cmp.Equal(gotRes, wantRes) { 195 | t.Errorf("got %+v, want %+v", gotRes, wantRes) 196 | } 197 | }) 198 | }) 199 | } 200 | 201 | func listAll(t *testing.T, iter *CachedContentIterator) []*CachedContent { 202 | var ccs []*CachedContent 203 | for { 204 | cc, err := iter.Next() 205 | if err == iterator.Done { 206 | break 207 | } 208 | if err != nil { 209 | t.Fatal(err) 210 | } 211 | ccs = append(ccs, cc) 212 | } 213 | return ccs 214 | } 215 | -------------------------------------------------------------------------------- /genai/chat.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package genai 16 | 17 | import ( 18 | "context" 19 | ) 20 | 21 | // A ChatSession provides interactive chat. 22 | type ChatSession struct { 23 | m *GenerativeModel 24 | History []*Content 25 | } 26 | 27 | // StartChat starts a chat session. 28 | func (m *GenerativeModel) StartChat() *ChatSession { 29 | return &ChatSession{m: m} 30 | } 31 | 32 | // SendMessage sends a request to the model as part of a chat session. 
33 | func (cs *ChatSession) SendMessage(ctx context.Context, parts ...Part) (*GenerateContentResponse, error) { 34 | // Call the underlying client with the entire history plus the argument Content. 35 | cs.History = append(cs.History, NewUserContent(parts...)) 36 | req, err := cs.m.newGenerateContentRequest(cs.History...) 37 | if err != nil { 38 | return nil, err 39 | } 40 | req.GenerationConfig.CandidateCount = Ptr[int32](1) 41 | resp, err := cs.m.generateContent(ctx, req) 42 | if err != nil { 43 | return nil, err 44 | } 45 | cs.addToHistory(resp.Candidates) 46 | return resp, nil 47 | } 48 | 49 | // SendMessageStream is like SendMessage, but with a streaming request. 50 | func (cs *ChatSession) SendMessageStream(ctx context.Context, parts ...Part) *GenerateContentResponseIterator { 51 | cs.History = append(cs.History, NewUserContent(parts...)) 52 | req, err := cs.m.newGenerateContentRequest(cs.History...) 53 | if err != nil { 54 | return &GenerateContentResponseIterator{err: err} 55 | } 56 | req.GenerationConfig.CandidateCount = Ptr[int32](1) 57 | streamClient, err := cs.m.c.gc.StreamGenerateContent(ctx, req) 58 | return &GenerateContentResponseIterator{ 59 | sc: streamClient, 60 | err: err, 61 | cs: cs, 62 | } 63 | } 64 | 65 | // By default, use the first candidate for history. The user can modify that if they want. 66 | func (cs *ChatSession) addToHistory(cands []*Candidate) bool { 67 | if len(cands) > 0 { 68 | c := cands[0].Content 69 | if c == nil { 70 | return false 71 | } 72 | c.Role = roleModel 73 | cs.History = append(cs.History, copySanitizedModelContent(c)) 74 | return true 75 | } 76 | return false 77 | } 78 | 79 | // copySanitizedModelContent creates a (shallow) copy of c with role set to 80 | // model and empty text parts removed. 
81 | func copySanitizedModelContent(c *Content) *Content { 82 | newc := &Content{Role: roleModel} 83 | for _, part := range c.Parts { 84 | if t, ok := part.(Text); !ok || len(string(t)) > 0 { 85 | newc.Parts = append(newc.Parts, part) 86 | } 87 | } 88 | return newc 89 | } 90 | -------------------------------------------------------------------------------- /genai/config.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Configuration for the protoveneer tool. 
16 | 17 | package: genai 18 | 19 | protoImportPath: cloud.google.com/go/ai/generativelanguage/apiv1beta/generativelanguagepb 20 | 21 | types: 22 | HarmCategory: 23 | protoPrefix: HarmCategory_HARM_CATEGORY_ 24 | docVerb: specifies 25 | 26 | SafetySetting_HarmBlockThreshold: 27 | name: HarmBlockThreshold 28 | protoPrefix: SafetySetting_BLOCK_ 29 | veneerPrefix: HarmBlock 30 | docVerb: specifies 31 | valueNames: 32 | SafetySetting_HARM_BLOCK_THRESHOLD_UNSPECIFIED: HarmBlockUnspecified 33 | 34 | SafetyRating_HarmProbability: 35 | name: HarmProbability 36 | protoPrefix: SafetyRating_ 37 | docVerb: specifies 38 | valueNames: 39 | SafetyRating_HARM_PROBABILITY_UNSPECIFIED: HarmProbabilityUnspecified 40 | 41 | Candidate_FinishReason: 42 | name: FinishReason 43 | protoPrefix: Candidate_ 44 | 45 | GenerateContentResponse: 46 | doc: 'is the response from a GenerateContent or GenerateContentStream call.' 47 | 48 | GenerateContentResponse_PromptFeedback_BlockReason: 49 | name: BlockReason 50 | protoPrefix: GenerateContentResponse_PromptFeedback_ 51 | 52 | Content: 53 | fields: 54 | Parts: 55 | type: '[]Part' 56 | 57 | Blob: 58 | fields: 59 | MimeType: 60 | name: MIMEType 61 | docVerb: contains 62 | 63 | FileData: 64 | fields: 65 | MimeType: 66 | name: MIMEType 67 | doc: | 68 | The IANA standard MIME type of the source data. 69 | If present, this overrides the MIME type specified or inferred 70 | when the file was uploaded. 71 | The supported MIME types are documented on [this page]. 72 | 73 | [this page]: https://ai.google.dev/gemini-api/docs/prompting_with_media?lang=go#supported_file_formats 74 | FileUri: 75 | name: URI 76 | doc: 'The URI returned from UploadFile or GetFile.' 
77 | 78 | GenerationConfig: 79 | fields: 80 | ResponseMimeType: 81 | name: ResponseMIMEType 82 | 83 | SafetySetting: 84 | 85 | SafetyRating: 86 | docVerb: 'is the' 87 | 88 | CitationMetadata: 89 | 90 | CitationSource: 91 | docVerb: contains 92 | fields: 93 | Uri: 94 | name: URI 95 | License: 96 | type: string 97 | 98 | Candidate: 99 | fields: 100 | Index: 101 | type: int32 102 | GroundingAttributions: 103 | omit: true 104 | 105 | GenerateContentResponse_PromptFeedback: 106 | name: PromptFeedback 107 | docVerb: contains 108 | 109 | CountTokensResponse: 110 | 111 | TaskType: 112 | protoPrefix: TaskType 113 | valueNames: 114 | TaskType_TASK_TYPE_UNSPECIFIED: TaskTypeUnspecified 115 | 116 | EmbedContentResponse: 117 | BatchEmbedContentsResponse: 118 | 119 | ContentEmbedding: 120 | 121 | Model: 122 | name: ModelInfo 123 | doc: 'is information about a language model.' 124 | fields: 125 | BaseModelId: 126 | name: BaseModelID 127 | Temperature: 128 | type: float32 129 | TopP: 130 | type: float32 131 | TopK: 132 | type: int32 133 | 134 | # Types for function calling 135 | Tool: 136 | fields: 137 | FunctionDeclarations: 138 | doc: | 139 | Optional. A list of FunctionDeclarations available to the model that 140 | can be used for function calling. The model or system does not execute 141 | the function. Instead the defined function may be returned as a [FunctionCall] 142 | part with arguments to the client side for execution. The next conversation 143 | turn may contain a [FunctionResponse] with the role "function" generation 144 | context for the next model turn. 145 | ToolConfig: 146 | FunctionDeclaration: 147 | FunctionCall: 148 | FunctionResponse: 149 | Schema: 150 | 151 | Type: 152 | protoPrefix: Type_ 153 | veneerPrefix: '' 154 | 155 | FunctionCallingConfig: 156 | doc: 'holds configuration for function calling.' 
157 | 158 | FunctionCallingConfig_Mode: 159 | name: FunctionCallingMode 160 | protoPrefix: FunctionCallingConfig 161 | veneerPrefix: FunctionCalling 162 | valueNames: 163 | FunctionCallingConfig_MODE_UNSPECIFIED: FunctionCallingUnspecified 164 | 165 | File: 166 | populateToFrom: populateFileTo, populateFileFrom 167 | fields: 168 | Uri: 169 | name: URI 170 | MimeType: 171 | name: MIMEType 172 | Metadata: 173 | type: '*FileMetadata' 174 | noConvert: true 175 | doc: 'Metadata for the File.' 176 | 177 | VideoMetadata: 178 | fields: 179 | VideoDuration: 180 | name: Duration 181 | 182 | File_State: 183 | name: FileState 184 | docVerb: represents 185 | protoPrefix: File 186 | veneerPrefix: FileState 187 | valueNames: 188 | File_STATE_UNSPECIFIED: FileStateUnspecified 189 | 190 | GenerateContentResponse_UsageMetadata: 191 | name: UsageMetadata 192 | fields: 193 | PromptTokenCount: 194 | type: int32 195 | CandidatesTokenCount: 196 | type: int32 197 | TotalTokenCount: 198 | type: int32 199 | 200 | CachedContent: 201 | populateToFrom: populateCachedContentTo, populateCachedContentFrom 202 | fields: 203 | Expiration: 204 | type: ExpireTimeOrTTL 205 | noConvert: true 206 | Name: 207 | type: string 208 | Model: 209 | type: string 210 | DisplayName: 211 | type: string 212 | 213 | CachedContent_UsageMetadata: 214 | name: CachedContentUsageMetadata 215 | 216 | CodeExecution: 217 | ExecutableCode: 218 | CodeExecutionResult: 219 | 220 | ExecutableCode_Language: 221 | name: ExecutableCodeLanguage 222 | protoPrefix: ExecutableCode 223 | veneerPrefix: ExecutableCode 224 | 225 | CodeExecutionResult_Outcome: 226 | name: CodeExecutionResultOutcome 227 | protoPrefix: CodeExecutionResult 228 | veneerPrefix: CodeExecutionResult 229 | valueNames: 230 | CodeExecutionResult_OUTCOME_OK: CodeExecutionResultOutcomeOK 231 | 232 | # Omit everything not explicitly configured. 
233 | omitTypes: 234 | - '*' 235 | 236 | converters: 237 | Part: partToProto, partFromProto 238 | -------------------------------------------------------------------------------- /genai/content.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package genai 16 | 17 | import ( 18 | "fmt" 19 | 20 | pb "cloud.google.com/go/ai/generativelanguage/apiv1beta/generativelanguagepb" 21 | ) 22 | 23 | const ( 24 | roleUser = "user" 25 | roleModel = "model" 26 | ) 27 | 28 | // A Part is a piece of model content. 
29 | // A Part can be one of the following types: 30 | // - Text 31 | // - Blob 32 | // - FunctionCall 33 | // - FunctionResponse 34 | // - ExecutableCode 35 | // - CodeExecutionResult 36 | type Part interface { 37 | toPart() *pb.Part 38 | } 39 | 40 | func partToProto(p Part) *pb.Part { 41 | if p == nil { 42 | return nil 43 | } 44 | return p.toPart() 45 | } 46 | 47 | func partFromProto(p *pb.Part) Part { 48 | switch d := p.Data.(type) { 49 | case *pb.Part_Text: 50 | return Text(d.Text) 51 | case *pb.Part_InlineData: 52 | return Blob{ 53 | MIMEType: d.InlineData.MimeType, 54 | Data: d.InlineData.Data, 55 | } 56 | case *pb.Part_FunctionCall: 57 | return *(FunctionCall{}).fromProto(d.FunctionCall) 58 | 59 | case *pb.Part_FunctionResponse: 60 | panic("FunctionResponse unimplemented") 61 | 62 | case *pb.Part_ExecutableCode: 63 | return (ExecutableCode{}).fromProto(d.ExecutableCode) 64 | case *pb.Part_CodeExecutionResult: 65 | return (CodeExecutionResult{}).fromProto(d.CodeExecutionResult) 66 | default: 67 | panic(fmt.Errorf("unknown Part.Data type %T", p.Data)) 68 | } 69 | } 70 | 71 | // A Text is a piece of text, like a question or phrase. 72 | type Text string 73 | 74 | func (t Text) toPart() *pb.Part { 75 | return &pb.Part{ 76 | Data: &pb.Part_Text{Text: string(t)}, 77 | } 78 | } 79 | 80 | func (b Blob) toPart() *pb.Part { 81 | return &pb.Part{ 82 | Data: &pb.Part_InlineData{ 83 | InlineData: b.toProto(), 84 | }, 85 | } 86 | } 87 | 88 | // ImageData is a convenience function for creating an image 89 | // Blob for input to a model. 90 | // The format should be the second part of the MIME type, after "image/". 91 | // For example, for a PNG image, pass "png". 
92 | func ImageData(format string, data []byte) Blob { 93 | return Blob{ 94 | MIMEType: "image/" + format, 95 | Data: data, 96 | } 97 | } 98 | 99 | func (f FunctionCall) toPart() *pb.Part { 100 | return &pb.Part{ 101 | Data: &pb.Part_FunctionCall{ 102 | FunctionCall: f.toProto(), 103 | }, 104 | } 105 | } 106 | 107 | func (f FunctionResponse) toPart() *pb.Part { 108 | return &pb.Part{ 109 | Data: &pb.Part_FunctionResponse{ 110 | FunctionResponse: f.toProto(), 111 | }, 112 | } 113 | } 114 | 115 | func (fd FileData) toPart() *pb.Part { 116 | return &pb.Part{ 117 | Data: &pb.Part_FileData{ 118 | FileData: fd.toProto(), 119 | }, 120 | } 121 | } 122 | 123 | func (ec ExecutableCode) toPart() *pb.Part { 124 | return &pb.Part{ 125 | Data: &pb.Part_ExecutableCode{ 126 | ExecutableCode: ec.toProto(), 127 | }, 128 | } 129 | } 130 | 131 | func (c CodeExecutionResult) toPart() *pb.Part { 132 | return &pb.Part{ 133 | Data: &pb.Part_CodeExecutionResult{ 134 | CodeExecutionResult: c.toProto(), 135 | }, 136 | } 137 | } 138 | 139 | // Ptr returns a pointer to its argument. 140 | // It can be used to initialize pointer fields: 141 | // 142 | // model.Temperature = genai.Ptr[float32](0.1) 143 | func Ptr[T any](t T) *T { return &t } 144 | 145 | // SetCandidateCount sets the CandidateCount field. 146 | func (c *GenerationConfig) SetCandidateCount(x int32) { c.CandidateCount = &x } 147 | 148 | // SetMaxOutputTokens sets the MaxOutputTokens field. 149 | func (c *GenerationConfig) SetMaxOutputTokens(x int32) { c.MaxOutputTokens = &x } 150 | 151 | // SetTemperature sets the Temperature field. 152 | func (c *GenerationConfig) SetTemperature(x float32) { c.Temperature = &x } 153 | 154 | // SetTopP sets the TopP field. 155 | func (c *GenerationConfig) SetTopP(x float32) { c.TopP = &x } 156 | 157 | // SetTopK sets the TopK field. 158 | func (c *GenerationConfig) SetTopK(x int32) { c.TopK = &x } 159 | 160 | // FunctionCalls return all the FunctionCall parts in the candidate. 
161 | func (c *Candidate) FunctionCalls() []FunctionCall { 162 | if c.Content == nil { 163 | return nil 164 | } 165 | var fcs []FunctionCall 166 | for _, p := range c.Content.Parts { 167 | if fc, ok := p.(FunctionCall); ok { 168 | fcs = append(fcs, fc) 169 | } 170 | } 171 | return fcs 172 | } 173 | 174 | // NewUserContent returns a *Content with a "user" role set and one or more 175 | // parts. 176 | func NewUserContent(parts ...Part) *Content { 177 | content := &Content{Role: roleUser, Parts: []Part{}} 178 | for _, part := range parts { 179 | content.Parts = append(content.Parts, part) 180 | } 181 | return content 182 | } 183 | -------------------------------------------------------------------------------- /genai/debug.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | // This file contains debugging support functions. 16 | 17 | package genai 18 | 19 | import ( 20 | "fmt" 21 | "os" 22 | 23 | "google.golang.org/protobuf/encoding/prototext" 24 | "google.golang.org/protobuf/proto" 25 | ) 26 | 27 | // printRequests controls whether request protobufs are written to stderr. 
28 | var printRequests = false 29 | 30 | func debugPrint(m proto.Message) { 31 | if !printRequests { 32 | return 33 | } 34 | fmt.Fprintln(os.Stderr, "--------") 35 | fmt.Fprintf(os.Stderr, "%T\n", m) 36 | fmt.Fprint(os.Stderr, prototext.Format(m)) 37 | fmt.Fprintln(os.Stderr, "^^^^^^^^") 38 | } 39 | -------------------------------------------------------------------------------- /genai/doc.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | // # [Deprecated] Google AI Go SDK for the Gemini API 16 | // 17 | // With Gemini 2.0, we took the chance to create a single unified SDK for all developers who want 18 | // to use Google's GenAI models (Gemini, Veo, Imagen, etc). As part of that process, we took all of 19 | // the feedback from this SDK and what developers like about other SDKs in the ecosystem to create 20 | // the [Google Gen AI SDK]. 21 | // 22 | // The Gemini API docs are fully updated to show examples of the new Google Gen AI SDK: 23 | // [Get started]. 24 | // 25 | // We know how disruptive an SDK change can be and don't take this change lightly, but our goal 26 | // is to create an extremely simple and clear path for developers to build with our models so it 27 | // felt necessary to make this change. 
28 | // 29 | // Thank you for building with Gemini and [let us know] if you need any help! 30 | // 31 | // Please be advised that this repository is now considered legacy. For the latest features, performance 32 | // improvements, and active development, we strongly recommend migrating to the official 33 | // [Google Gen AI SDK]. 34 | // 35 | // Support Plan for this Repository: 36 | // 37 | // - Limited Maintenance: Development is now restricted to critical bug fixes only. No new features will be added. 38 | // - Purpose: This limited support aims to provide stability for users while they transition to the new SDK. 39 | // - End-of-Life Date: All support for this repository (including bug fixes) will permanently end on August 31st, 2025. 40 | // 41 | // We encourage all users to begin planning their migration to the [Google Gen AI SDK] to ensure 42 | // continued access to the latest capabilities and support. 43 | // 44 | // 45 | // # Getting started 46 | // 47 | // NOTE: This client uses the v1beta version of the API. 48 | // 49 | // Reading the [examples] is the best way to learn how to use this package. 50 | // 51 | // # Authorization 52 | // 53 | // You will need an API key to use the service. 54 | // See the [setup tutorial] for details.
55 | // 56 | // # Errors 57 | // 58 | // [examples]: https://pkg.go.dev/github.com/google/generative-ai-go/genai#pkg-examples 59 | // [setup tutorial]: https://ai.google.dev/tutorials/setup 60 | // [Google Gen AI SDK]: https://github.com/googleapis/go-genai 61 | // [Get started]: https://ai.google.dev/gemini-api/docs/quickstart?lang=go 62 | // [let us know]: https://discuss.ai.google.dev/c/gemini-api/4 63 | // [Google Generative AI SDK for Go]: https://github.com/googleapis/go-genai 64 | // [Google Generative AI SDK]: https://github.com/googleapis/go-genai 65 | package genai 66 | -------------------------------------------------------------------------------- /genai/embed.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package genai 16 | 17 | import ( 18 | "context" 19 | 20 | pb "cloud.google.com/go/ai/generativelanguage/apiv1beta/generativelanguagepb" 21 | ) 22 | 23 | // EmbeddingModel creates a new instance of the named embedding model. 24 | // Example name: "embedding-001" or "models/embedding-001". 25 | func (c *Client) EmbeddingModel(name string) *EmbeddingModel { 26 | return &EmbeddingModel{ 27 | c: c, 28 | name: name, 29 | fullName: fullModelName(name), 30 | } 31 | } 32 | 33 | // EmbeddingModel is a model that computes embeddings. 34 | // Create one with [Client.EmbeddingModel]. 
35 | type EmbeddingModel struct { 36 | c *Client 37 | name string 38 | fullName string 39 | // TaskType describes how the embedding will be used. 40 | TaskType TaskType 41 | } 42 | 43 | // Name returns the name of the EmbeddingModel. 44 | func (m *EmbeddingModel) Name() string { 45 | return m.name 46 | } 47 | 48 | // EmbedContent returns an embedding for the list of parts. 49 | func (m *EmbeddingModel) EmbedContent(ctx context.Context, parts ...Part) (*EmbedContentResponse, error) { 50 | return m.EmbedContentWithTitle(ctx, "", parts...) 51 | } 52 | 53 | // EmbedContentWithTitle returns an embedding for the list of parts. 54 | // If the given title is non-empty, it is passed to the model and 55 | // the task type is set to TaskTypeRetrievalDocument. 56 | func (m *EmbeddingModel) EmbedContentWithTitle(ctx context.Context, title string, parts ...Part) (*EmbedContentResponse, error) { 57 | req := newEmbedContentRequest(m.fullName, m.TaskType, title, parts) 58 | res, err := m.c.gc.EmbedContent(ctx, req) 59 | if err != nil { 60 | return nil, err 61 | } 62 | return (EmbedContentResponse{}).fromProto(res), nil 63 | } 64 | 65 | func newEmbedContentRequest(model string, tt TaskType, title string, parts []Part) *pb.EmbedContentRequest { 66 | req := &pb.EmbedContentRequest{ 67 | Model: model, 68 | Content: NewUserContent(parts...).toProto(), 69 | } 70 | // A non-empty title overrides the task type. 71 | if title != "" { 72 | req.Title = &title 73 | tt = TaskTypeRetrievalDocument 74 | } 75 | if tt != TaskTypeUnspecified { 76 | taskType := pb.TaskType(tt) 77 | req.TaskType = &taskType 78 | } 79 | debugPrint(req) 80 | return req 81 | } 82 | 83 | // An EmbeddingBatch holds a collection of embedding requests. 84 | type EmbeddingBatch struct { 85 | tt TaskType 86 | req *pb.BatchEmbedContentsRequest 87 | } 88 | 89 | // NewBatch returns a new, empty EmbeddingBatch with the same TaskType as the model. 
90 | // Make multiple calls to [EmbeddingBatch.AddContent] or [EmbeddingBatch.AddContentWithTitle]. 91 | // Then pass the EmbeddingBatch to [EmbeddingModel.BatchEmbedContents] to get 92 | // all the embeddings in a single call to the model. 93 | func (m *EmbeddingModel) NewBatch() *EmbeddingBatch { 94 | return &EmbeddingBatch{ 95 | tt: m.TaskType, 96 | req: &pb.BatchEmbedContentsRequest{ 97 | Model: m.fullName, 98 | }, 99 | } 100 | } 101 | 102 | // AddContent adds content made of the given parts to the batch and returns b, so calls can be chained. 103 | func (b *EmbeddingBatch) AddContent(parts ...Part) *EmbeddingBatch { 104 | b.AddContentWithTitle("", parts...) 105 | return b 106 | } 107 | 108 | // AddContentWithTitle adds content made of the given parts to the batch with a title and returns b. A non-empty title sets that request's task type to TaskTypeRetrievalDocument. 109 | func (b *EmbeddingBatch) AddContentWithTitle(title string, parts ...Part) *EmbeddingBatch { 110 | b.req.Requests = append(b.req.Requests, newEmbedContentRequest(b.req.Model, b.tt, title, parts)) 111 | return b 112 | } 113 | 114 | // BatchEmbedContents returns the embeddings for all the contents in the batch. 115 | func (m *EmbeddingModel) BatchEmbedContents(ctx context.Context, b *EmbeddingBatch) (*BatchEmbedContentsResponse, error) { 116 | res, err := m.c.gc.BatchEmbedContents(ctx, b.req) 117 | if err != nil { 118 | return nil, err 119 | } 120 | return (BatchEmbedContentsResponse{}).fromProto(res), nil 121 | } 122 | 123 | // Info returns information about the model. 124 | func (m *EmbeddingModel) Info(ctx context.Context) (*ModelInfo, error) { 125 | return m.c.modelInfo(ctx, m.fullName) 126 | } 127 | -------------------------------------------------------------------------------- /genai/files.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package genai 16 | 17 | //go:generate ../devtools/generate_discovery_client.sh 18 | 19 | import ( 20 | "context" 21 | "io" 22 | "os" 23 | "strings" 24 | 25 | gl "cloud.google.com/go/ai/generativelanguage/apiv1beta" 26 | pb "cloud.google.com/go/ai/generativelanguage/apiv1beta/generativelanguagepb" 27 | gld "github.com/google/generative-ai-go/genai/internal/generativelanguage/v1beta" // discovery client 28 | "google.golang.org/api/googleapi" 29 | "google.golang.org/api/iterator" 30 | ) 31 | 32 | // UploadFileOptions are options for [Client.UploadFile]. 33 | type UploadFileOptions struct { 34 | // DisplayName is a more readable name for the file. 35 | DisplayName string 36 | 37 | // MIMEType is the IANA standard MIME type of the file. It will be stored with the file as metadata. 38 | // If omitted, the service will try to infer it. You may instead wish to use 39 | // [http.DetectContentType]. 40 | // The supported MIME types are documented on [this page]. 41 | // 42 | // [this page]: https://ai.google.dev/gemini-api/docs/document-processing?lang=go#technical-details 43 | MIMEType string 44 | } 45 | 46 | // UploadFile copies the contents of the given io.Reader to file storage associated 47 | // with the service, and returns information about the resulting file. 48 | // 49 | // The name is a relatively short, unique identifier for the file (rather than a typical 50 | // filename). 51 | // Typically it should be left empty, in which case a unique name will be generated.
52 | // Otherwise, it can contain up to 40 characters that are lowercase 53 | // alphanumeric or dashes (-), not starting or ending with a dash. 54 | // To generate your own unique names, consider a cryptographic hash algorithm like SHA-1. 55 | // The string "files/" is prepended to the name if it does not contain a '/'. 56 | // 57 | // Use the returned file's URI field with a [FileData] Part to use it for generation. 58 | // 59 | // It is an error to upload a file that already exists. 60 | func (c *Client) UploadFile(ctx context.Context, name string, r io.Reader, opts *UploadFileOptions) (*File, error) { 61 | if name != "" { 62 | name = userNameToServiceName(name) 63 | } 64 | req := &gld.CreateFileRequest{ 65 | File: &gld.File{Name: name}, 66 | } 67 | if opts != nil && opts.DisplayName != "" { 68 | req.File.DisplayName = opts.DisplayName 69 | } 70 | call := c.ds.Media.Upload(req) 71 | var mopts []googleapi.MediaOption 72 | if opts != nil && opts.MIMEType != "" { 73 | mopts = append(mopts, googleapi.ContentType(opts.MIMEType)) 74 | } 75 | call.Media(r, mopts...) 76 | res, err := call.Context(ctx).Do() // bind ctx so cancellation and deadlines also apply to the upload itself 77 | if err != nil { 78 | return nil, err 79 | } 80 | // Don't return the result, because it contains a file as represented by the 81 | // discovery client and we'd have to write code to convert it to this package's 82 | // File type. 83 | // Instead, make a GetFile call to get the proto file, which our generated code can convert. 84 | return c.GetFile(ctx, res.File.Name) 85 | } 86 | 87 | // UploadFileFromPath is a convenience method wrapping [Client.UploadFile]. It takes 88 | // a path to read the file from, and uses a default auto-generated ID for the 89 | // uploaded file.
90 | func (c *Client) UploadFileFromPath(ctx context.Context, path string, opts *UploadFileOptions) (*File, error) { 91 | osf, err := os.Open(path) 92 | if err != nil { 93 | return nil, err 94 | } 95 | defer osf.Close() 96 | 97 | return c.UploadFile(ctx, "", osf, opts) 98 | } 99 | 100 | // GetFile returns the named file. A name without a '/' is prefixed with "files/" before calling the service. 101 | func (c *Client) GetFile(ctx context.Context, name string) (*File, error) { 102 | req := &pb.GetFileRequest{Name: userNameToServiceName(name)} 103 | debugPrint(req) 104 | pf, err := c.fc.GetFile(ctx, req) 105 | if err != nil { 106 | return nil, err 107 | } 108 | return (File{}).fromProto(pf), nil 109 | } 110 | 111 | // DeleteFile deletes the file with the given name. As with GetFile, a name without a '/' is prefixed with "files/". 112 | // It is an error to delete a file that does not exist. 113 | func (c *Client) DeleteFile(ctx context.Context, name string) error { 114 | req := &pb.DeleteFileRequest{Name: userNameToServiceName(name)} 115 | debugPrint(req) 116 | return c.fc.DeleteFile(ctx, req) 117 | } 118 | 119 | // userNameToServiceName converts a name supplied by the user to a name required by the service: names without a '/' get a "files/" prefix. 120 | func userNameToServiceName(name string) string { 121 | if strings.ContainsRune(name, '/') { 122 | return name 123 | } 124 | return "files/" + name 125 | } 126 | 127 | // ListFiles returns an iterator over the uploaded files. 128 | func (c *Client) ListFiles(ctx context.Context) *FileIterator { 129 | return &FileIterator{ 130 | it: c.fc.ListFiles(ctx, &pb.ListFilesRequest{}), 131 | } 132 | } 133 | 134 | // A FileIterator iterates over Files. 135 | type FileIterator struct { 136 | it *gl.FileIterator 137 | } 138 | 139 | // Next returns the next result. Its second return value is iterator.Done if there are no more 140 | // results. Once Next returns Done, all subsequent calls will return Done.
141 | func (it *FileIterator) Next() (*File, error) { 142 | m, err := it.it.Next() 143 | if err != nil { 144 | return nil, err 145 | } 146 | return (File{}).fromProto(m), nil 147 | } 148 | 149 | // PageInfo supports pagination. See the google.golang.org/api/iterator package for details. 150 | func (it *FileIterator) PageInfo() *iterator.PageInfo { 151 | return it.it.PageInfo() 152 | } 153 | 154 | // FileMetadata holds metadata about a file. 155 | type FileMetadata struct { 156 | // Video is set if the file contains video. 157 | Video *VideoMetadata 158 | } 159 | 160 | // populateFileTo copies the metadata of f into the Metadata oneof of p. 161 | // p.Metadata is cleared first, so it stays nil unless f has video metadata. 162 | func populateFileTo(p *pb.File, f *File) { 163 | p.Metadata = nil 164 | if f == nil || f.Metadata == nil { 165 | return 166 | } 167 | if f.Metadata.Video != nil { 168 | p.Metadata = &pb.File_VideoMetadata{ 169 | VideoMetadata: f.Metadata.Video.toProto(), 170 | } 171 | } 172 | } 173 | 174 | // populateFileFrom copies the Metadata oneof of p into f.Metadata. 175 | // f.Metadata is cleared first; metadata kinds other than video are ignored. 176 | func populateFileFrom(f *File, p *pb.File) { 177 | f.Metadata = nil 178 | if p == nil || p.Metadata == nil { 179 | return 180 | } 181 | switch m := p.Metadata.(type) { 182 | case *pb.File_VideoMetadata: 183 | f.Metadata = &FileMetadata{ 184 | Video: (VideoMetadata{}).fromProto(m.VideoMetadata), 185 | } 186 | default: 187 | // Ignore other metadata kinds. 188 | // TODO: signal a problem 189 | } 190 | } 191 | -------------------------------------------------------------------------------- /genai/files_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package genai 16 | 17 | import ( 18 | "testing" 19 | "time" 20 | 21 | pb "cloud.google.com/go/ai/generativelanguage/apiv1beta/generativelanguagepb" 22 | "github.com/google/go-cmp/cmp" 23 | "github.com/google/go-cmp/cmp/cmpopts" 24 | "google.golang.org/protobuf/types/known/durationpb" 25 | // google.golang.org/protobuf/proto 26 | ) 27 | // TestPopulateFile checks populateFileTo and populateFileFrom in both directions, for an empty File and for one with video metadata. 28 | func TestPopulateFile(t *testing.T) { 29 | f1 := &File{} 30 | p1 := &pb.File{} 31 | f2 := &File{Metadata: &FileMetadata{ 32 | Video: &VideoMetadata{Duration: time.Minute}, 33 | }} 34 | p2 := &pb.File{ 35 | Metadata: &pb.File_VideoMetadata{ 36 | VideoMetadata: &pb.VideoMetadata{ 37 | VideoDuration: durationpb.New(time.Minute), 38 | }, 39 | }, 40 | } 41 | 42 | for _, test := range []struct { 43 | f *File 44 | p *pb.File 45 | }{ 46 | {f1, p1}, 47 | {f2, p2}, 48 | } { 49 | var pgot pb.File 50 | populateFileTo(&pgot, test.f) 51 | if !cmp.Equal(&pgot, test.p, cmpopts.IgnoreUnexported(pb.File{}, pb.VideoMetadata{}, durationpb.Duration{})) { 52 | t.Errorf("got %+v, want %+v", &pgot, test.p) 53 | } 54 | 55 | var fgot File 56 | populateFileFrom(&fgot, test.p) 57 | if !cmp.Equal(&fgot, test.f) { 58 | t.Errorf("got %+v, want %+v", &fgot, test.f) 59 | } 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /genai/generate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # Copyright 2023 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may
not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | version=$(awk '$1 == "cloud.google.com/go/ai" {print $2}' ../go.mod) 17 | 18 | # This script runs under /bin/sh, so it uses POSIX [ ] tests rather than bash-only [[ ]]. 19 | if [ -z "$version" ]; then 20 | echo >&2 "could not get version of cloud.google.com/go/ai from ../go.mod" 21 | exit 1 22 | fi 23 | 24 | # Ask the go command for the module cache location (honors GOMODCACHE and GOPATH) instead of assuming ~/go/pkg/mod. 25 | dir=$(go env GOMODCACHE)/cloud.google.com/go/ai@$version/generativelanguage/apiv1beta/generativelanguagepb 26 | 27 | if [ ! -d "$dir" ]; then 28 | echo >&2 "$dir does not exist or is not a directory" 29 | exit 1 30 | fi 31 | 32 | echo "generating from $dir" 33 | protoveneer -license license.txt config.yaml "$dir" 34 | 35 | -------------------------------------------------------------------------------- /genai/internal/cmd/gen-examples/gen-examples.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
14 | 15 | // This code generator takes examples from the internal/samples directory 16 | // and copies them to "official" examples in genai/example_test.go, while 17 | // removing snippet comments (between [START...] and [END...]) that are used 18 | // for website documentation purposes. 19 | // It's invoked with a go:generate directive in the source file. 20 | 21 | package main 22 | 23 | import ( 24 | "flag" 25 | "fmt" 26 | "go/ast" 27 | "go/format" 28 | "go/parser" 29 | "go/token" 30 | "log" 31 | "os" 32 | "strings" 33 | ) 34 | 35 | func main() { 36 | inPath := flag.String("in", "", "input file path") 37 | outPath := flag.String("out", "", "output file path") 38 | flag.Parse() 39 | 40 | if len(*inPath) == 0 || len(*outPath) == 0 { 41 | log.Fatalf("got empty -in (%v) or -out (%v)", *inPath, *outPath) 42 | } 43 | 44 | inFile, err := os.Open(*inPath) 45 | if err != nil { 46 | log.Fatal(err) 47 | } 48 | defer inFile.Close() 49 | 50 | fset := token.NewFileSet() 51 | file, err := parser.ParseFile(fset, *inPath, inFile, parser.ParseComments) 52 | if err != nil { 53 | log.Fatal(err) 54 | } 55 | 56 | for _, cgroup := range file.Comments { 57 | sanitizeCommentGroup(cgroup) 58 | } 59 | 60 | outFile, err := os.Create(*outPath) 61 | if err != nil { 62 | log.Fatal(err) 63 | } 64 | 65 | fmt.Fprintln(outFile, strings.TrimLeft(preamble, "\r\n")) 66 | if err := format.Node(outFile, fset, file); err != nil { 67 | log.Fatal(err) 68 | } 69 | // Close explicitly instead of deferring, so that a failed final write is reported. 70 | if err := outFile.Close(); err != nil { 71 | log.Fatal(err) 72 | } 73 | } 74 | 75 | // preamble marks the output as generated. It follows the standard 76 | // "// Code generated ... DO NOT EDIT." form so that Go tooling recognizes it. 77 | const preamble = ` 78 | // Code generated from internal/samples/docs-snippets_test.go by gen-examples. DO NOT EDIT. 79 | ` 80 | 81 | // printCommentGroup dumps cg's comments to stdout. Nothing in this file calls it; 82 | // it appears to be kept as a debugging aid. 83 | func printCommentGroup(cg *ast.CommentGroup) { 84 | fmt.Printf("-- comment group %p\n", cg) 85 | for _, c := range cg.List { 86 | fmt.Println(c.Slash, c.Text) 87 | } 88 | } 89 | 90 | // sanitizeCommentGroup removes comment blocks between [START... and [END... 91 | // (including these lines), and also any go:generate directives - it modifies cg.
92 | func sanitizeCommentGroup(cg *ast.CommentGroup) { 93 | var nl []*ast.Comment 94 | excludeBlock := false 95 | for _, commentLine := range cg.List { 96 | if strings.Contains(commentLine.Text, "[START") { 97 | excludeBlock = true 98 | } else if strings.Contains(commentLine.Text, "[END") { 99 | excludeBlock = false 100 | } else if !excludeBlock { 101 | // Outside a snippet-marker block, keep every comment except go:generate directives. 102 | if !strings.Contains(commentLine.Text, "go:generate") { 103 | nl = append(nl, commentLine) 104 | } 105 | } 106 | } 107 | cg.List = nl 108 | } 109 | -------------------------------------------------------------------------------- /genai/internal/gensupport/README: -------------------------------------------------------------------------------- 1 | This directory was copied from github.com/googleapis/google-api-go-client/internal/gensupport. 2 | It is needed for the discovery client in ../generativelanguage. 3 | 4 | To update, first clone github.com/googleapis/google-api-go-client 5 | into a directory we will call DIR below. 6 | Then, from the repo root: 7 | ``` 8 | rm genai/internal/gensupport/*.go 9 | cp $DIR/internal/gensupport/*.go genai/internal/gensupport 10 | ``` 11 | Then edit the params.go and resumable.go files to replace the reference to `internal.Version` 12 | with the literal string from $DIR/internal/version.go, and remove the import of `internal`. 13 | -------------------------------------------------------------------------------- /genai/internal/gensupport/buffer.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package gensupport 16 | 17 | import ( 18 | "bytes" 19 | "io" 20 | 21 | "google.golang.org/api/googleapi" 22 | ) 23 | 24 | // MediaBuffer buffers data from an io.Reader to support uploading media in 25 | // retryable chunks. It should be created with NewMediaBuffer. 26 | type MediaBuffer struct { 27 | media io.Reader 28 | 29 | chunk []byte // The current chunk which is pending upload. The capacity is the chunk size. 30 | err error // Any error generated when populating chunk by reading media. 31 | 32 | // The absolute position of chunk in the underlying media. 33 | off int64 34 | } 35 | 36 | // NewMediaBuffer initializes a MediaBuffer. 37 | func NewMediaBuffer(media io.Reader, chunkSize int) *MediaBuffer { 38 | return &MediaBuffer{media: media, chunk: make([]byte, 0, chunkSize)} 39 | } 40 | 41 | // Chunk returns the current buffered chunk, the offset in the underlying media 42 | // from which the chunk is drawn, and the size of the chunk. 43 | // Successive calls to Chunk return the same chunk between calls to Next. 44 | func (mb *MediaBuffer) Chunk() (chunk io.Reader, off int64, size int, err error) { 45 | // There may already be data in chunk if Next has not been called since the previous call to Chunk. 46 | if mb.err == nil && len(mb.chunk) == 0 { 47 | mb.err = mb.loadChunk() 48 | } 49 | return bytes.NewReader(mb.chunk), mb.off, len(mb.chunk), mb.err 50 | } 51 | 52 | // loadChunk will read from media into chunk, up to the capacity of chunk.
53 | func (mb *MediaBuffer) loadChunk() error { 54 | bufSize := cap(mb.chunk) 55 | mb.chunk = mb.chunk[:bufSize] 56 | 57 | read := 0 58 | var err error 59 | for err == nil && read < bufSize { 60 | var n int 61 | n, err = mb.media.Read(mb.chunk[read:]) 62 | read += n 63 | } 64 | mb.chunk = mb.chunk[:read] 65 | return err 66 | } 67 | 68 | // Next advances to the next chunk, which will be returned by the next call to Chunk. 69 | // Calls to Next without a corresponding prior call to Chunk will have no effect. 70 | func (mb *MediaBuffer) Next() { 71 | mb.off += int64(len(mb.chunk)) 72 | mb.chunk = mb.chunk[0:0] 73 | } 74 | 75 | type readerTyper struct { 76 | io.Reader 77 | googleapi.ContentTyper 78 | } 79 | 80 | // ReaderAtToReader adapts a ReaderAt to be used as a Reader. 81 | // If ra implements googleapi.ContentTyper, then the returned reader 82 | // will also implement googleapi.ContentTyper, delegating to ra. 83 | func ReaderAtToReader(ra io.ReaderAt, size int64) io.Reader { 84 | r := io.NewSectionReader(ra, 0, size) 85 | if typer, ok := ra.(googleapi.ContentTyper); ok { 86 | return readerTyper{r, typer} 87 | } 88 | return r 89 | } 90 | -------------------------------------------------------------------------------- /genai/internal/gensupport/buffer_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
14 | 15 | package gensupport 16 | 17 | import ( 18 | "bytes" 19 | "io" 20 | "reflect" 21 | "testing" 22 | "testing/iotest" 23 | 24 | "google.golang.org/api/googleapi" 25 | ) 26 | 27 | // getChunkAsString reads a chunk from mb, but does not call Next. 28 | func getChunkAsString(t *testing.T, mb *MediaBuffer) (string, error) { 29 | chunk, _, size, err := mb.Chunk() 30 | 31 | buf, e := io.ReadAll(chunk) 32 | if e != nil { 33 | t.Fatalf("Failed reading chunk: %v", e) 34 | } 35 | if size != len(buf) { 36 | t.Fatalf("reported chunk size doesn't match actual chunk size: got: %v; want: %v", size, len(buf)) 37 | } 38 | return string(buf), err 39 | } 40 | 41 | func TestChunking(t *testing.T) { 42 | type testCase struct { 43 | data string // the data to read from the Reader 44 | finalErr error // error to return after data has been read 45 | chunkSize int 46 | wantChunks []string 47 | } 48 | 49 | for _, singleByteReads := range []bool{true, false} { 50 | for _, tc := range []testCase{ 51 | { 52 | data: "abcdefg", 53 | finalErr: nil, 54 | chunkSize: 3, 55 | wantChunks: []string{"abc", "def", "g"}, 56 | }, 57 | { 58 | data: "abcdefg", 59 | finalErr: nil, 60 | chunkSize: 1, 61 | wantChunks: []string{"a", "b", "c", "d", "e", "f", "g"}, 62 | }, 63 | { 64 | data: "abcdefg", 65 | finalErr: nil, 66 | chunkSize: 7, 67 | wantChunks: []string{"abcdefg"}, 68 | }, 69 | { 70 | data: "abcdefg", 71 | finalErr: nil, 72 | chunkSize: 8, 73 | wantChunks: []string{"abcdefg"}, 74 | }, 75 | { 76 | data: "abcdefg", 77 | finalErr: io.ErrUnexpectedEOF, 78 | chunkSize: 3, 79 | wantChunks: []string{"abc", "def", "g"}, 80 | }, 81 | { 82 | data: "abcdefg", 83 | finalErr: io.ErrUnexpectedEOF, 84 | chunkSize: 8, 85 | wantChunks: []string{"abcdefg"}, 86 | }, 87 | } { 88 | var r io.Reader = &errReader{buf: []byte(tc.data), err: tc.finalErr} 89 | 90 | if singleByteReads { 91 | r = iotest.OneByteReader(r) 92 | } 93 | 94 | mb := NewMediaBuffer(r, tc.chunkSize) 95 | var gotErr error 96 | got := []string{} 97 |
for { 98 | chunk, err := getChunkAsString(t, mb) 99 | if len(chunk) != 0 { 100 | got = append(got, string(chunk)) 101 | } 102 | if err != nil { 103 | gotErr = err 104 | break 105 | } 106 | mb.Next() 107 | } 108 | 109 | if !reflect.DeepEqual(got, tc.wantChunks) { 110 | t.Errorf("Failed reading buffer: got: %v; want:%v", got, tc.wantChunks) 111 | } 112 | 113 | expectedErr := tc.finalErr 114 | if expectedErr == nil { 115 | expectedErr = io.EOF 116 | } 117 | if gotErr != expectedErr { 118 | t.Errorf("Reading buffer error: got: %v; want: %v", gotErr, expectedErr) 119 | } 120 | } 121 | } 122 | } 123 | 124 | func TestChunkCanBeReused(t *testing.T) { 125 | er := &errReader{buf: []byte("abcdefg")} 126 | mb := NewMediaBuffer(er, 3) 127 | 128 | // expectChunk reads a chunk and checks that it got what was wanted. 129 | expectChunk := func(want string, wantErr error) { 130 | got, err := getChunkAsString(t, mb) 131 | if err != wantErr { 132 | t.Errorf("error reading buffer: got: %v; want: %v", err, wantErr) 133 | } 134 | if !reflect.DeepEqual(got, want) { 135 | t.Errorf("Failed reading buffer: got: %q; want:%q", got, want) 136 | } 137 | } 138 | expectChunk("abc", nil) 139 | // On second call, should get same chunk again. 140 | expectChunk("abc", nil) 141 | mb.Next() 142 | expectChunk("def", nil) 143 | expectChunk("def", nil) 144 | mb.Next() 145 | expectChunk("g", io.EOF) 146 | expectChunk("g", io.EOF) 147 | mb.Next() 148 | expectChunk("", io.EOF) 149 | } 150 | 151 | func TestPos(t *testing.T) { 152 | er := &errReader{buf: []byte("abcdefg")} 153 | mb := NewMediaBuffer(er, 3) 154 | 155 | expectChunkAtOffset := func(want int64, wantErr error) { 156 | _, off, _, err := mb.Chunk() 157 | if err != wantErr { 158 | t.Errorf("error reading buffer: got: %v; want: %v", err, wantErr) 159 | } 160 | if got := off; got != want { 161 | t.Errorf("resumable buffer Pos: got: %v; want: %v", got, want) 162 | } 163 | } 164 | 165 | // We expect the first chunk to be at offset 0.
166 | expectChunkAtOffset(0, nil) 167 | // Fetching the same chunk should return the same offset. 168 | expectChunkAtOffset(0, nil) 169 | 170 | // Calling Next multiple times should only cause off to advance by 3, since off is not advanced until 171 | // the chunk is actually read. 172 | mb.Next() 173 | mb.Next() 174 | expectChunkAtOffset(3, nil) 175 | 176 | mb.Next() 177 | 178 | // Load the final 1-byte chunk. 179 | expectChunkAtOffset(6, io.EOF) 180 | 181 | // Next will advance 1 byte. But there are no more chunks, so off will not increase beyond 7. 182 | mb.Next() 183 | expectChunkAtOffset(7, io.EOF) 184 | mb.Next() 185 | expectChunkAtOffset(7, io.EOF) 186 | } 187 | 188 | // bytes.Reader implements both Reader and ReaderAt. The following types 189 | // implement various combinations of Reader, ReaderAt and ContentTyper, by 190 | // wrapping bytes.Reader. All implement at least ReaderAt, so they can be 191 | // passed to ReaderAtToReader. The following table summarizes which types 192 | // implement which interfaces: 193 | // 194 | // ReaderAt Reader ContentTyper 195 | // reader x x 196 | // typerReader x x x 197 | // readerAt x 198 | // typerReaderAt x x 199 | 200 | // reader implements Reader, in addition to ReaderAt. 201 | type reader struct { 202 | r *bytes.Reader 203 | } 204 | 205 | func (r *reader) ReadAt(b []byte, off int64) (n int, err error) { 206 | return r.r.ReadAt(b, off) 207 | } 208 | 209 | func (r *reader) Read(b []byte) (n int, err error) { 210 | return r.r.Read(b) 211 | } 212 | 213 | // typerReader implements Reader and ContentTyper, in addition to ReaderAt.
214 | type typerReader struct { 215 | r *bytes.Reader 216 | } 217 | 218 | func (tr *typerReader) ReadAt(b []byte, off int64) (n int, err error) { 219 | return tr.r.ReadAt(b, off) 220 | } 221 | 222 | func (tr *typerReader) Read(b []byte) (n int, err error) { 223 | return tr.r.Read(b) 224 | } 225 | 226 | func (tr *typerReader) ContentType() string { 227 | return "ctype" 228 | } 229 | 230 | // readerAt implements only ReaderAt. 231 | type readerAt struct { 232 | r *bytes.Reader 233 | } 234 | 235 | func (ra *readerAt) ReadAt(b []byte, off int64) (n int, err error) { 236 | return ra.r.ReadAt(b, off) 237 | } 238 | 239 | // typerReaderAt implements ContentTyper, in addition to ReaderAt. 240 | type typerReaderAt struct { 241 | r *bytes.Reader 242 | } 243 | 244 | func (tra *typerReaderAt) ReadAt(b []byte, off int64) (n int, err error) { 245 | return tra.r.ReadAt(b, off) 246 | } 247 | 248 | func (tra *typerReaderAt) ContentType() string { 249 | return "ctype" 250 | } 251 | 252 | func TestAdapter(t *testing.T) { 253 | data := "abc" 254 | 255 | checkConversion := func(to io.Reader, wantTyper bool) { 256 | if _, ok := to.(googleapi.ContentTyper); ok != wantTyper { 257 | t.Errorf("reader implements typer?
got: %v; want: %v", ok, wantTyper) 258 | } 259 | if typer, ok := to.(googleapi.ContentTyper); ok && typer.ContentType() != "ctype" { 260 | t.Errorf("content type: got: %s; want: ctype", typer.ContentType()) 261 | } 262 | buf, err := io.ReadAll(to) 263 | if err != nil { 264 | t.Errorf("error reading data: %v", err) 265 | return 266 | } 267 | if !bytes.Equal(buf, []byte(data)) { 268 | t.Errorf("failed reading data: got: %s; want: %s", buf, data) 269 | } 270 | } 271 | 272 | type testCase struct { 273 | from io.ReaderAt 274 | wantTyper bool 275 | } 276 | for _, tc := range []testCase{ 277 | { 278 | from: &reader{bytes.NewReader([]byte(data))}, 279 | wantTyper: false, 280 | }, 281 | { 282 | // Reader and ContentTyper 283 | from: &typerReader{bytes.NewReader([]byte(data))}, 284 | wantTyper: true, 285 | }, 286 | { 287 | // ReaderAt 288 | from: &readerAt{bytes.NewReader([]byte(data))}, 289 | wantTyper: false, 290 | }, 291 | { 292 | // ReaderAt and ContentTyper 293 | from: &typerReaderAt{bytes.NewReader([]byte(data))}, 294 | wantTyper: true, 295 | }, 296 | } { 297 | to := ReaderAtToReader(tc.from, int64(len(data))) 298 | checkConversion(to, tc.wantTyper) 299 | // tc.from is a ReaderAt, and should be treated like one, even 300 | // if it also implements Reader. Specifically, it can be 301 | // reused and read from the beginning. 302 | to = ReaderAtToReader(tc.from, int64(len(data))) 303 | checkConversion(to, tc.wantTyper) 304 | } 305 | } 306 | -------------------------------------------------------------------------------- /genai/internal/gensupport/doc.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | // Package gensupport is an internal implementation detail used by code 16 | // generated by the google-api-go-generator tool. 17 | // 18 | // This package may be modified at any time without regard for backwards 19 | // compatibility. It should not be used directly by API users. 20 | package gensupport 21 | -------------------------------------------------------------------------------- /genai/internal/gensupport/error.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package gensupport 16 | 17 | import ( 18 | "errors" 19 | 20 | "github.com/googleapis/gax-go/v2/apierror" 21 | "google.golang.org/api/googleapi" 22 | ) 23 | 24 | // WrapError creates an [apierror.APIError] from err, wraps it in err, and 25 | // returns err.
If err is not a [googleapi.Error] (or a 26 | // [google.golang.org/grpc/status.Status]), it returns err without modification. 27 | func WrapError(err error) error { 28 | var herr *googleapi.Error 29 | apiError, ok := apierror.ParseError(err, false) 30 | if ok && errors.As(err, &herr) { 31 | herr.Wrap(apiError) 32 | } 33 | return err 34 | } 35 | -------------------------------------------------------------------------------- /genai/internal/gensupport/error_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package gensupport 16 | 17 | import ( 18 | "errors" 19 | "testing" 20 | 21 | "github.com/google/go-cmp/cmp" 22 | "github.com/googleapis/gax-go/v2/apierror" 23 | "google.golang.org/api/googleapi" 24 | "google.golang.org/genproto/googleapis/rpc/errdetails" 25 | "google.golang.org/protobuf/proto" 26 | ) 27 | 28 | func TestWrapError(t *testing.T) { 29 | // The error format v2 for Google JSON REST APIs, per https://cloud.google.com/apis/design/errors#http_mapping.
30 | jsonErrStr := "{\"error\":{\"details\":[{\"@type\":\"type.googleapis.com/google.rpc.ErrorInfo\", \"reason\":\"just because\", \"domain\":\"tests\"}]}}" 31 | hae := &googleapi.Error{ 32 | Body: jsonErrStr, 33 | } 34 | err := WrapError(hae) 35 | 36 | var aerr *apierror.APIError 37 | if ok := errors.As(err, &aerr); !ok { 38 | t.Errorf("got false, want true") 39 | } 40 | 41 | httpErrInfo := &errdetails.ErrorInfo{Reason: "just because", Domain: "tests"} 42 | details := apierror.ErrDetails{ErrorInfo: httpErrInfo} 43 | if diff := cmp.Diff(aerr.Details(), details, cmp.Comparer(proto.Equal)); diff != "" { 44 | t.Errorf("got(-), want(+),: \n%s", diff) 45 | } 46 | if s := aerr.Reason(); s != "just because" { 47 | t.Errorf("Reason() got %s, want 'just because'", s) 48 | } 49 | if s := aerr.Domain(); s != "tests" { 50 | t.Errorf("Domain() got %s, want nil", s) 51 | } 52 | if err := aerr.Unwrap(); err != nil { 53 | t.Errorf("Unwrap() got %T, want nil", err) 54 | } 55 | if m := aerr.Metadata(); m != nil { 56 | t.Errorf("Metadata() got %v, want nil", m) 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /genai/internal/gensupport/json.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
14 | 15 | package gensupport 16 | 17 | import ( 18 | "encoding/json" 19 | "fmt" 20 | "reflect" 21 | "strings" 22 | ) 23 | 24 | // MarshalJSON returns a JSON encoding of schema containing only selected fields. 25 | // A field is selected if any of the following is true: 26 | // - it has a non-empty value 27 | // - its field name is present in forceSendFields and it is not a nil pointer or nil interface 28 | // - its field name is present in nullFields. 29 | // 30 | // The JSON key for each selected field is taken from the field's json: struct tag. 31 | func MarshalJSON(schema interface{}, forceSendFields, nullFields []string) ([]byte, error) { 32 | if len(forceSendFields) == 0 && len(nullFields) == 0 { 33 | return json.Marshal(schema) 34 | } 35 | 36 | mustInclude := make(map[string]bool) 37 | for _, f := range forceSendFields { 38 | mustInclude[f] = true 39 | } 40 | useNull := make(map[string]bool) 41 | useNullMaps := make(map[string]map[string]bool) 42 | for _, nf := range nullFields { 43 | parts := strings.SplitN(nf, ".", 2) 44 | field := parts[0] 45 | if len(parts) == 1 { 46 | useNull[field] = true 47 | } else { 48 | if useNullMaps[field] == nil { 49 | useNullMaps[field] = map[string]bool{} 50 | } 51 | useNullMaps[field][parts[1]] = true 52 | } 53 | } 54 | 55 | dataMap, err := schemaToMap(schema, mustInclude, useNull, useNullMaps) 56 | if err != nil { 57 | return nil, err 58 | } 59 | return json.Marshal(dataMap) 60 | } 61 | 62 | func schemaToMap(schema interface{}, mustInclude, useNull map[string]bool, useNullMaps map[string]map[string]bool) (map[string]interface{}, error) { 63 | m := make(map[string]interface{}) 64 | s := reflect.ValueOf(schema) 65 | st := s.Type() 66 | 67 | for i := 0; i < s.NumField(); i++ { 68 | jsonTag := st.Field(i).Tag.Get("json") 69 | if jsonTag == "" { 70 | continue 71 | } 72 | tag, err := parseJSONTag(jsonTag) 73 | if err != nil { 74 | return nil, err 75 | } 76 | if tag.ignore { 77 | continue 78 | } 79 | 80 | v := s.Field(i) 81 | f := 
st.Field(i) 82 | 83 | if useNull[f.Name] { 84 | if !isEmptyValue(v) { 85 | return nil, fmt.Errorf("field %q in NullFields has non-empty value", f.Name) 86 | } 87 | m[tag.apiName] = nil 88 | continue 89 | } 90 | 91 | if !includeField(v, f, mustInclude) { 92 | continue 93 | } 94 | 95 | // If map fields are explicitly set to null, use a map[string]interface{}. 96 | if f.Type.Kind() == reflect.Map && useNullMaps[f.Name] != nil { 97 | ms, ok := v.Interface().(map[string]string) 98 | if !ok { 99 | mi, err := initMapSlow(v, f.Name, useNullMaps) 100 | if err != nil { 101 | return nil, err 102 | } 103 | m[tag.apiName] = mi 104 | continue 105 | } 106 | mi := map[string]interface{}{} 107 | for k, v := range ms { 108 | mi[k] = v 109 | } 110 | for k := range useNullMaps[f.Name] { 111 | mi[k] = nil 112 | } 113 | m[tag.apiName] = mi 114 | continue 115 | } 116 | 117 | // nil maps are treated as empty maps. 118 | if f.Type.Kind() == reflect.Map && v.IsNil() { 119 | m[tag.apiName] = map[string]string{} 120 | continue 121 | } 122 | 123 | // nil slices are treated as empty slices. 124 | if f.Type.Kind() == reflect.Slice && v.IsNil() { 125 | m[tag.apiName] = []bool{} 126 | continue 127 | } 128 | 129 | if tag.stringFormat { 130 | m[tag.apiName] = formatAsString(v, f.Type.Kind()) 131 | } else { 132 | m[tag.apiName] = v.Interface() 133 | } 134 | } 135 | return m, nil 136 | } 137 | 138 | // initMapSlow uses reflection to build up a map object. This is slower than 139 | // the default behavior so it should be used only as a fallback. 
140 | func initMapSlow(rv reflect.Value, fieldName string, useNullMaps map[string]map[string]bool) (map[string]interface{}, error) { 141 | mi := map[string]interface{}{} 142 | iter := rv.MapRange() 143 | for iter.Next() { 144 | k, ok := iter.Key().Interface().(string) 145 | if !ok { 146 | return nil, fmt.Errorf("field %q has keys in NullFields but is not a map[string]any", fieldName) 147 | } 148 | v := iter.Value().Interface() 149 | mi[k] = v 150 | } 151 | for k := range useNullMaps[fieldName] { 152 | mi[k] = nil 153 | } 154 | return mi, nil 155 | } 156 | 157 | // formatAsString returns a string representation of v, dereferencing it first if possible. 158 | func formatAsString(v reflect.Value, kind reflect.Kind) string { 159 | if kind == reflect.Ptr && !v.IsNil() { 160 | v = v.Elem() 161 | } 162 | 163 | return fmt.Sprintf("%v", v.Interface()) 164 | } 165 | 166 | // jsonTag represents a restricted version of the struct tag format used by encoding/json. 167 | // It is used to describe the JSON encoding of fields in a Schema struct. 168 | type jsonTag struct { 169 | apiName string 170 | stringFormat bool 171 | ignore bool 172 | } 173 | 174 | // parseJSONTag parses a restricted version of the struct tag format used by encoding/json. 175 | // The format of the tag must match that generated by the Schema.writeSchemaStruct method 176 | // in the api generator. 
177 | func parseJSONTag(val string) (jsonTag, error) { 178 | if val == "-" { 179 | return jsonTag{ignore: true}, nil 180 | } 181 | 182 | var tag jsonTag 183 | 184 | i := strings.Index(val, ",") 185 | if i == -1 || val[:i] == "" { 186 | return tag, fmt.Errorf("malformed json tag: %s", val) 187 | } 188 | 189 | tag = jsonTag{ 190 | apiName: val[:i], 191 | } 192 | 193 | switch val[i+1:] { 194 | case "omitempty": 195 | case "omitempty,string": 196 | tag.stringFormat = true 197 | default: 198 | return tag, fmt.Errorf("malformed json tag: %s", val) 199 | } 200 | 201 | return tag, nil 202 | } 203 | 204 | // Reports whether the struct field "f" with value "v" should be included in JSON output. 205 | func includeField(v reflect.Value, f reflect.StructField, mustInclude map[string]bool) bool { 206 | // The regular JSON encoding of a nil pointer is "null", which means "delete this field". 207 | // Therefore, we could enable field deletion by honoring pointer fields' presence in the mustInclude set. 208 | // However, many fields are not pointers, so there would be no way to delete these fields. 209 | // Rather than partially supporting field deletion, we ignore mustInclude for nil pointer fields. 210 | // Deletion will be handled by a separate mechanism. 211 | if f.Type.Kind() == reflect.Ptr && v.IsNil() { 212 | return false 213 | } 214 | 215 | // The "any" type is represented as an interface{}. If this interface 216 | // is nil, there is no reasonable representation to send. We ignore 217 | // these fields, for the same reasons as given above for pointers. 218 | if f.Type.Kind() == reflect.Interface && v.IsNil() { 219 | return false 220 | } 221 | 222 | return mustInclude[f.Name] || !isEmptyValue(v) 223 | } 224 | 225 | // isEmptyValue reports whether v is the empty value for its type. This 226 | // implementation is based on that of the encoding/json package, but its 227 | // correctness does not depend on it being identical. 
What's important is that 228 | // this function return false in situations where v should not be sent as part 229 | // of a PATCH operation. 230 | func isEmptyValue(v reflect.Value) bool { 231 | switch v.Kind() { 232 | case reflect.Array, reflect.Map, reflect.Slice, reflect.String: 233 | return v.Len() == 0 234 | case reflect.Bool: 235 | return !v.Bool() 236 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 237 | return v.Int() == 0 238 | case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: 239 | return v.Uint() == 0 240 | case reflect.Float32, reflect.Float64: 241 | return v.Float() == 0 242 | case reflect.Interface, reflect.Ptr: 243 | return v.IsNil() 244 | } 245 | return false 246 | } 247 | -------------------------------------------------------------------------------- /genai/internal/gensupport/json_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package gensupport 16 | 17 | import ( 18 | "encoding/json" 19 | "reflect" 20 | "testing" 21 | 22 | "google.golang.org/api/googleapi" 23 | ) 24 | 25 | type CustomType struct { 26 | Foo string `json:"foo,omitempty"` 27 | } 28 | 29 | type schema struct { 30 | // Basic types 31 | B bool `json:"b,omitempty"` 32 | F float64 `json:"f,omitempty"` 33 | I int64 `json:"i,omitempty"` 34 | Istr int64 `json:"istr,omitempty,string"` 35 | Str string `json:"str,omitempty"` 36 | 37 | // Pointers to basic types 38 | PB *bool `json:"pb,omitempty"` 39 | PF *float64 `json:"pf,omitempty"` 40 | PI *int64 `json:"pi,omitempty"` 41 | PIStr *int64 `json:"pistr,omitempty,string"` 42 | PStr *string `json:"pstr,omitempty"` 43 | 44 | // Other types 45 | Int64s googleapi.Int64s `json:"i64s,omitempty"` 46 | S []int `json:"s,omitempty"` 47 | M map[string]string `json:"m,omitempty"` 48 | Any interface{} `json:"any,omitempty"` 49 | Child *child `json:"child,omitempty"` 50 | MapToAnyArray map[string][]interface{} `json:"maptoanyarray,omitempty"` 51 | MapToCustomType map[string]CustomType `json:"maptocustomtype,omitempty"` 52 | 53 | ForceSendFields []string `json:"-"` 54 | NullFields []string `json:"-"` 55 | } 56 | 57 | type child struct { 58 | B bool `json:"childbool,omitempty"` 59 | } 60 | 61 | type testCase struct { 62 | s schema 63 | want string 64 | } 65 | 66 | func TestBasics(t *testing.T) { 67 | for _, tc := range []testCase{ 68 | { 69 | s: schema{}, 70 | want: `{}`, 71 | }, 72 | { 73 | s: schema{ 74 | ForceSendFields: []string{"B", "F", "I", "Istr", "Str", "PB", "PF", "PI", "PIStr", "PStr"}, 75 | }, 76 | want: `{"b":false,"f":0.0,"i":0,"istr":"0","str":""}`, 77 | }, 78 | { 79 | s: schema{ 80 | NullFields: []string{"B", "F", "I", "Istr", "Str", "PB", "PF", "PI", "PIStr", "PStr"}, 81 | }, 82 | want: `{"b":null,"f":null,"i":null,"istr":null,"str":null,"pb":null,"pf":null,"pi":null,"pistr":null,"pstr":null}`, 83 | }, 84 | { 85 | s: schema{ 86 | B: true, 87 | F: 1.2, 88 | I: 1, 89 | Istr: 
2, 90 | Str: "a", 91 | PB: googleapi.Bool(true), 92 | PF: googleapi.Float64(1.2), 93 | PI: googleapi.Int64(int64(1)), 94 | PIStr: googleapi.Int64(int64(2)), 95 | PStr: googleapi.String("a"), 96 | }, 97 | want: `{"b":true,"f":1.2,"i":1,"istr":"2","str":"a","pb":true,"pf":1.2,"pi":1,"pistr":"2","pstr":"a"}`, 98 | }, 99 | { 100 | s: schema{ 101 | B: false, 102 | F: 0.0, 103 | I: 0, 104 | Istr: 0, 105 | Str: "", 106 | PB: googleapi.Bool(false), 107 | PF: googleapi.Float64(0.0), 108 | PI: googleapi.Int64(int64(0)), 109 | PIStr: googleapi.Int64(int64(0)), 110 | PStr: googleapi.String(""), 111 | }, 112 | want: `{"pb":false,"pf":0.0,"pi":0,"pistr":"0","pstr":""}`, 113 | }, 114 | { 115 | s: schema{ 116 | B: false, 117 | F: 0.0, 118 | I: 0, 119 | Istr: 0, 120 | Str: "", 121 | PB: googleapi.Bool(false), 122 | PF: googleapi.Float64(0.0), 123 | PI: googleapi.Int64(int64(0)), 124 | PIStr: googleapi.Int64(int64(0)), 125 | PStr: googleapi.String(""), 126 | ForceSendFields: []string{"B", "F", "I", "Istr", "Str", "PB", "PF", "PI", "PIStr", "PStr"}, 127 | }, 128 | want: `{"b":false,"f":0.0,"i":0,"istr":"0","str":"","pb":false,"pf":0.0,"pi":0,"pistr":"0","pstr":""}`, 129 | }, 130 | { 131 | s: schema{ 132 | B: false, 133 | F: 0.0, 134 | I: 0, 135 | Istr: 0, 136 | Str: "", 137 | PB: googleapi.Bool(false), 138 | PF: googleapi.Float64(0.0), 139 | PI: googleapi.Int64(int64(0)), 140 | PIStr: googleapi.Int64(int64(0)), 141 | PStr: googleapi.String(""), 142 | NullFields: []string{"B", "F", "I", "Istr", "Str"}, 143 | }, 144 | want: `{"b":null,"f":null,"i":null,"istr":null,"str":null,"pb":false,"pf":0.0,"pi":0,"pistr":"0","pstr":""}`, 145 | }, 146 | } { 147 | checkMarshalJSON(t, tc) 148 | } 149 | } 150 | 151 | func TestSliceFields(t *testing.T) { 152 | for _, tc := range []testCase{ 153 | { 154 | s: schema{}, 155 | want: `{}`, 156 | }, 157 | { 158 | s: schema{S: []int{}, Int64s: googleapi.Int64s{}}, 159 | want: `{}`, 160 | }, 161 | { 162 | s: schema{S: []int{1}, Int64s: googleapi.Int64s{1}}, 
163 | want: `{"s":[1],"i64s":["1"]}`, 164 | }, 165 | { 166 | s: schema{ 167 | ForceSendFields: []string{"S", "Int64s"}, 168 | }, 169 | want: `{"s":[],"i64s":[]}`, 170 | }, 171 | { 172 | s: schema{ 173 | S: []int{}, 174 | Int64s: googleapi.Int64s{}, 175 | ForceSendFields: []string{"S", "Int64s"}, 176 | }, 177 | want: `{"s":[],"i64s":[]}`, 178 | }, 179 | { 180 | s: schema{ 181 | S: []int{1}, 182 | Int64s: googleapi.Int64s{1}, 183 | ForceSendFields: []string{"S", "Int64s"}, 184 | }, 185 | want: `{"s":[1],"i64s":["1"]}`, 186 | }, 187 | { 188 | s: schema{ 189 | NullFields: []string{"S", "Int64s"}, 190 | }, 191 | want: `{"s":null,"i64s":null}`, 192 | }, 193 | } { 194 | checkMarshalJSON(t, tc) 195 | } 196 | } 197 | 198 | func TestMapField(t *testing.T) { 199 | for _, tc := range []testCase{ 200 | { 201 | s: schema{}, 202 | want: `{}`, 203 | }, 204 | { 205 | s: schema{M: make(map[string]string)}, 206 | want: `{}`, 207 | }, 208 | { 209 | s: schema{M: map[string]string{"a": "b"}}, 210 | want: `{"m":{"a":"b"}}`, 211 | }, 212 | { 213 | s: schema{ 214 | ForceSendFields: []string{"M"}, 215 | }, 216 | want: `{"m":{}}`, 217 | }, 218 | { 219 | s: schema{ 220 | NullFields: []string{"M"}, 221 | }, 222 | want: `{"m":null}`, 223 | }, 224 | { 225 | s: schema{ 226 | M: make(map[string]string), 227 | ForceSendFields: []string{"M"}, 228 | }, 229 | want: `{"m":{}}`, 230 | }, 231 | { 232 | s: schema{ 233 | M: make(map[string]string), 234 | NullFields: []string{"M"}, 235 | }, 236 | want: `{"m":null}`, 237 | }, 238 | { 239 | s: schema{ 240 | M: map[string]string{"a": "b"}, 241 | ForceSendFields: []string{"M"}, 242 | }, 243 | want: `{"m":{"a":"b"}}`, 244 | }, 245 | { 246 | s: schema{ 247 | M: map[string]string{"a": "b"}, 248 | NullFields: []string{"M.a", "M."}, 249 | }, 250 | want: `{"m": {"a": null, "":null}}`, 251 | }, 252 | { 253 | s: schema{ 254 | M: map[string]string{"a": "b"}, 255 | NullFields: []string{"M.c"}, 256 | }, 257 | want: `{"m": {"a": "b", "c": null}}`, 258 | }, 259 | { 260 | s: 
schema{ 261 | NullFields: []string{"M.a"}, 262 | ForceSendFields: []string{"M"}, 263 | }, 264 | want: `{"m": {"a": null}}`, 265 | }, 266 | { 267 | s: schema{ 268 | NullFields: []string{"M.a"}, 269 | }, 270 | want: `{}`, 271 | }, 272 | { 273 | s: schema{ 274 | MapToCustomType: map[string]CustomType{ 275 | "a": {Foo: "foo"}, 276 | }, 277 | NullFields: []string{"MapToCustomType.b"}, 278 | }, 279 | want: `{"maptocustomtype": {"a": {"foo": "foo"}, "b": null}}`, 280 | }, 281 | } { 282 | checkMarshalJSON(t, tc) 283 | } 284 | } 285 | 286 | func TestMapToAnyArray(t *testing.T) { 287 | for _, tc := range []testCase{ 288 | { 289 | s: schema{}, 290 | want: `{}`, 291 | }, 292 | { 293 | s: schema{MapToAnyArray: make(map[string][]interface{})}, 294 | want: `{}`, 295 | }, 296 | { 297 | s: schema{ 298 | MapToAnyArray: map[string][]interface{}{ 299 | "a": {2, "b"}, 300 | }, 301 | }, 302 | want: `{"maptoanyarray":{"a":[2, "b"]}}`, 303 | }, 304 | { 305 | s: schema{ 306 | MapToAnyArray: map[string][]interface{}{ 307 | "a": nil, 308 | }, 309 | }, 310 | want: `{"maptoanyarray":{"a": null}}`, 311 | }, 312 | { 313 | s: schema{ 314 | MapToAnyArray: map[string][]interface{}{ 315 | "a": {nil}, 316 | }, 317 | }, 318 | want: `{"maptoanyarray":{"a":[null]}}`, 319 | }, 320 | { 321 | s: schema{ 322 | ForceSendFields: []string{"MapToAnyArray"}, 323 | }, 324 | want: `{"maptoanyarray":{}}`, 325 | }, 326 | { 327 | s: schema{ 328 | NullFields: []string{"MapToAnyArray"}, 329 | }, 330 | want: `{"maptoanyarray":null}`, 331 | }, 332 | { 333 | s: schema{ 334 | MapToAnyArray: make(map[string][]interface{}), 335 | ForceSendFields: []string{"MapToAnyArray"}, 336 | }, 337 | want: `{"maptoanyarray":{}}`, 338 | }, 339 | { 340 | s: schema{ 341 | MapToAnyArray: map[string][]interface{}{ 342 | "a": {2, "b"}, 343 | }, 344 | ForceSendFields: []string{"MapToAnyArray"}, 345 | }, 346 | want: `{"maptoanyarray":{"a":[2, "b"]}}`, 347 | }, 348 | } { 349 | checkMarshalJSON(t, tc) 350 | } 351 | } 352 | 353 | type anyType 
struct { 354 | Field int 355 | } 356 | 357 | func (a anyType) MarshalJSON() ([]byte, error) { 358 | return []byte(`"anyType value"`), nil 359 | } 360 | 361 | func TestAnyField(t *testing.T) { 362 | // ForceSendFields has no effect on nil interfaces and interfaces that contain nil pointers. 363 | var nilAny *anyType 364 | for _, tc := range []testCase{ 365 | { 366 | s: schema{}, 367 | want: `{}`, 368 | }, 369 | { 370 | s: schema{Any: nilAny}, 371 | want: `{"any": null}`, 372 | }, 373 | { 374 | s: schema{Any: &anyType{}}, 375 | want: `{"any":"anyType value"}`, 376 | }, 377 | { 378 | s: schema{Any: anyType{}}, 379 | want: `{"any":"anyType value"}`, 380 | }, 381 | { 382 | s: schema{ 383 | ForceSendFields: []string{"Any"}, 384 | }, 385 | want: `{}`, 386 | }, 387 | { 388 | s: schema{ 389 | NullFields: []string{"Any"}, 390 | }, 391 | want: `{"any":null}`, 392 | }, 393 | { 394 | s: schema{ 395 | Any: nilAny, 396 | ForceSendFields: []string{"Any"}, 397 | }, 398 | want: `{"any": null}`, 399 | }, 400 | { 401 | s: schema{ 402 | Any: &anyType{}, 403 | ForceSendFields: []string{"Any"}, 404 | }, 405 | want: `{"any":"anyType value"}`, 406 | }, 407 | { 408 | s: schema{ 409 | Any: anyType{}, 410 | ForceSendFields: []string{"Any"}, 411 | }, 412 | want: `{"any":"anyType value"}`, 413 | }, 414 | } { 415 | checkMarshalJSON(t, tc) 416 | } 417 | } 418 | 419 | func TestSubschema(t *testing.T) { 420 | // Subschemas are always stored as pointers, so ForceSendFields has no effect on them. 
421 | for _, tc := range []testCase{ 422 | { 423 | s: schema{}, 424 | want: `{}`, 425 | }, 426 | { 427 | s: schema{ 428 | ForceSendFields: []string{"Child"}, 429 | }, 430 | want: `{}`, 431 | }, 432 | { 433 | s: schema{ 434 | NullFields: []string{"Child"}, 435 | }, 436 | want: `{"child":null}`, 437 | }, 438 | { 439 | s: schema{Child: &child{}}, 440 | want: `{"child":{}}`, 441 | }, 442 | { 443 | s: schema{ 444 | Child: &child{}, 445 | ForceSendFields: []string{"Child"}, 446 | }, 447 | want: `{"child":{}}`, 448 | }, 449 | { 450 | s: schema{Child: &child{B: true}}, 451 | want: `{"child":{"childbool":true}}`, 452 | }, 453 | 454 | { 455 | s: schema{ 456 | Child: &child{B: true}, 457 | ForceSendFields: []string{"Child"}, 458 | }, 459 | want: `{"child":{"childbool":true}}`, 460 | }, 461 | } { 462 | checkMarshalJSON(t, tc) 463 | } 464 | } 465 | 466 | // checkMarshalJSON verifies that calling schemaToMap on tc.s yields a result which is equivalent to tc.want. 467 | func checkMarshalJSON(t *testing.T, tc testCase) { 468 | doCheckMarshalJSON(t, tc.s, tc.s.ForceSendFields, tc.s.NullFields, tc.want) 469 | if len(tc.s.ForceSendFields) == 0 && len(tc.s.NullFields) == 0 { 470 | // verify that the code path used when ForceSendFields and NullFields 471 | // are non-empty produces the same output as the fast path that is used 472 | // when they are empty. 473 | doCheckMarshalJSON(t, tc.s, []string{"dummy"}, []string{"dummy"}, tc.want) 474 | } 475 | } 476 | 477 | func doCheckMarshalJSON(t *testing.T, s schema, forceSendFields, nullFields []string, wantJSON string) { 478 | encoded, err := MarshalJSON(s, forceSendFields, nullFields) 479 | if err != nil { 480 | t.Fatalf("encoding json:\n got err: %v", err) 481 | } 482 | 483 | // The expected and obtained JSON can differ in field ordering, so unmarshal before comparing. 
484 | var got interface{} 485 | var want interface{} 486 | err = json.Unmarshal(encoded, &got) 487 | if err != nil { 488 | t.Fatalf("decoding json:\n got err: %v", err) 489 | } 490 | err = json.Unmarshal([]byte(wantJSON), &want) 491 | if err != nil { 492 | t.Fatalf("decoding json:\n got err: %v", err) 493 | } 494 | if !reflect.DeepEqual(got, want) { 495 | t.Errorf("schemaToMap:\ngot :%v\nwant: %v", got, want) 496 | } 497 | } 498 | 499 | func TestParseJSONTag(t *testing.T) { 500 | for _, tc := range []struct { 501 | tag string 502 | want jsonTag 503 | }{ 504 | { 505 | tag: "-", 506 | want: jsonTag{ignore: true}, 507 | }, { 508 | tag: "name,omitempty", 509 | want: jsonTag{apiName: "name"}, 510 | }, { 511 | tag: "name,omitempty,string", 512 | want: jsonTag{apiName: "name", stringFormat: true}, 513 | }, 514 | } { 515 | got, err := parseJSONTag(tc.tag) 516 | if err != nil { 517 | t.Fatalf("parsing json:\n got err: %v\ntag: %q", err, tc.tag) 518 | } 519 | if !reflect.DeepEqual(got, tc.want) { 520 | t.Errorf("parseJSONTage:\ngot :%v\nwant:%v", got, tc.want) 521 | } 522 | } 523 | } 524 | 525 | func TestParseMalformedJSONTag(t *testing.T) { 526 | for _, tag := range []string{ 527 | "", 528 | "name", 529 | "name,", 530 | "name,blah", 531 | "name,blah,string", 532 | ",omitempty", 533 | ",omitempty,string", 534 | "name,omitempty,string,blah", 535 | } { 536 | _, err := parseJSONTag(tag) 537 | if err == nil { 538 | t.Fatalf("parsing json: expected err, got nil for tag: %v", tag) 539 | } 540 | } 541 | } 542 | -------------------------------------------------------------------------------- /genai/internal/gensupport/jsonfloat.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package gensupport 16 | 17 | import ( 18 | "encoding/json" 19 | "errors" 20 | "fmt" 21 | "math" 22 | ) 23 | 24 | // JSONFloat64 is a float64 that supports proper unmarshaling of special float 25 | // values in JSON, according to 26 | // https://developers.google.com/protocol-buffers/docs/proto3#json. Although 27 | // that is a proto-to-JSON spec, it applies to all Google APIs. 28 | // 29 | // The jsonpb package 30 | // (https://github.com/golang/protobuf/blob/master/jsonpb/jsonpb.go) has 31 | // similar functionality, but only for direct translation from proto messages 32 | // to JSON. 
33 | type JSONFloat64 float64 34 | 35 | func (f *JSONFloat64) UnmarshalJSON(data []byte) error { 36 | var ff float64 37 | if err := json.Unmarshal(data, &ff); err == nil { 38 | *f = JSONFloat64(ff) 39 | return nil 40 | } 41 | var s string 42 | if err := json.Unmarshal(data, &s); err == nil { 43 | switch s { 44 | case "NaN": 45 | ff = math.NaN() 46 | case "Infinity": 47 | ff = math.Inf(1) 48 | case "-Infinity": 49 | ff = math.Inf(-1) 50 | default: 51 | return fmt.Errorf("google.golang.org/api/internal: bad float string %q", s) 52 | } 53 | *f = JSONFloat64(ff) 54 | return nil 55 | } 56 | return errors.New("google.golang.org/api/internal: data not float or string") 57 | } 58 | -------------------------------------------------------------------------------- /genai/internal/gensupport/jsonfloat_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package gensupport 16 | 17 | import ( 18 | "encoding/json" 19 | "math" 20 | "testing" 21 | ) 22 | 23 | func TestJSONFloat(t *testing.T) { 24 | for _, test := range []struct { 25 | in string 26 | want float64 27 | }{ 28 | {"0", 0}, 29 | {"-10", -10}, 30 | {"1e23", 1e23}, 31 | {`"Infinity"`, math.Inf(1)}, 32 | {`"-Infinity"`, math.Inf(-1)}, 33 | {`"NaN"`, math.NaN()}, 34 | } { 35 | var f64 JSONFloat64 36 | if err := json.Unmarshal([]byte(test.in), &f64); err != nil { 37 | t.Fatal(err) 38 | } 39 | got := float64(f64) 40 | if got != test.want && math.IsNaN(got) != math.IsNaN(test.want) { 41 | t.Errorf("%s: got %f, want %f", test.in, got, test.want) 42 | } 43 | } 44 | } 45 | 46 | func TestJSONFloatErrors(t *testing.T) { 47 | var f64 JSONFloat64 48 | for _, in := range []string{"", "a", `"Inf"`, `"-Inf"`, `"nan"`, `"nana"`} { 49 | if err := json.Unmarshal([]byte(in), &f64); err == nil { 50 | t.Errorf("%q: got nil, want error", in) 51 | } 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /genai/internal/gensupport/media.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package gensupport 16 | 17 | import ( 18 | "bytes" 19 | "fmt" 20 | "io" 21 | "mime" 22 | "mime/multipart" 23 | "net/http" 24 | "net/textproto" 25 | "strings" 26 | "sync" 27 | "time" 28 | 29 | gax "github.com/googleapis/gax-go/v2" 30 | "google.golang.org/api/googleapi" 31 | ) 32 | 33 | type typeReader struct { 34 | io.Reader 35 | typ string 36 | } 37 | 38 | // multipartReader combines the contents of multiple readers to create a multipart/related HTTP body. 39 | // Close must be called if reads from the multipartReader are abandoned before reaching EOF. 40 | type multipartReader struct { 41 | pr *io.PipeReader 42 | ctype string 43 | mu sync.Mutex 44 | pipeOpen bool 45 | } 46 | 47 | // boundary optionally specifies the MIME boundary 48 | func newMultipartReader(parts []typeReader, boundary string) *multipartReader { 49 | mp := &multipartReader{pipeOpen: true} 50 | var pw *io.PipeWriter 51 | mp.pr, pw = io.Pipe() 52 | mpw := multipart.NewWriter(pw) 53 | if boundary != "" { 54 | mpw.SetBoundary(boundary) 55 | } 56 | mp.ctype = "multipart/related; boundary=" + mpw.Boundary() 57 | go func() { 58 | for _, part := range parts { 59 | w, err := mpw.CreatePart(typeHeader(part.typ)) 60 | if err != nil { 61 | mpw.Close() 62 | pw.CloseWithError(fmt.Errorf("googleapi: CreatePart failed: %v", err)) 63 | return 64 | } 65 | _, err = io.Copy(w, part.Reader) 66 | if err != nil { 67 | mpw.Close() 68 | pw.CloseWithError(fmt.Errorf("googleapi: Copy failed: %v", err)) 69 | return 70 | } 71 | } 72 | 73 | mpw.Close() 74 | pw.Close() 75 | }() 76 | return mp 77 | } 78 | 79 | func (mp *multipartReader) Read(data []byte) (n int, err error) { 80 | return mp.pr.Read(data) 81 | } 82 | 83 | func (mp *multipartReader) Close() error { 84 | mp.mu.Lock() 85 | if !mp.pipeOpen { 86 | mp.mu.Unlock() 87 | return nil 88 | } 89 | mp.pipeOpen = false 90 | mp.mu.Unlock() 91 | return mp.pr.Close() 92 | } 93 | 94 | // CombineBodyMedia combines a json body with media content to create a 
multipart/related HTTP body. 95 | // It returns a ReadCloser containing the combined body, and the overall "multipart/related" content type, with random boundary. 96 | // 97 | // The caller must call Close on the returned ReadCloser if reads are abandoned before reaching EOF. 98 | func CombineBodyMedia(body io.Reader, bodyContentType string, media io.Reader, mediaContentType string) (io.ReadCloser, string) { 99 | return combineBodyMedia(body, bodyContentType, media, mediaContentType, "") 100 | } 101 | 102 | // combineBodyMedia is CombineBodyMedia but with an optional mimeBoundary field. 103 | func combineBodyMedia(body io.Reader, bodyContentType string, media io.Reader, mediaContentType, mimeBoundary string) (io.ReadCloser, string) { 104 | mp := newMultipartReader([]typeReader{ 105 | {body, bodyContentType}, 106 | {media, mediaContentType}, 107 | }, mimeBoundary) 108 | return mp, mp.ctype 109 | } 110 | 111 | func typeHeader(contentType string) textproto.MIMEHeader { 112 | h := make(textproto.MIMEHeader) 113 | if contentType != "" { 114 | h.Set("Content-Type", contentType) 115 | } 116 | return h 117 | } 118 | 119 | // PrepareUpload determines whether the data in the supplied reader should be 120 | // uploaded in a single request, or in sequential chunks. 121 | // chunkSize is the size of the chunk that media should be split into. 122 | // 123 | // If chunkSize is zero, media is returned as the first value, and the other 124 | // two return values are nil, true. 125 | // 126 | // Otherwise, a MediaBuffer is returned, along with a bool indicating whether the 127 | // contents of media fit in a single chunk. 128 | // 129 | // After PrepareUpload has been called, media should no longer be used: the 130 | // media content should be accessed via one of the return values. 
131 | func PrepareUpload(media io.Reader, chunkSize int) (r io.Reader, mb *MediaBuffer, singleChunk bool) { 132 | if chunkSize == 0 { // do not chunk 133 | return media, nil, true 134 | } 135 | mb = NewMediaBuffer(media, chunkSize) 136 | _, _, _, err := mb.Chunk() 137 | // If err is io.EOF, we can upload this in a single request. Otherwise, err is 138 | // either nil or a non-EOF error. If it is the latter, then the next call to 139 | // mb.Chunk will return the same error. Returning a MediaBuffer ensures that this 140 | // error will be handled at some point. 141 | return nil, mb, err == io.EOF 142 | } 143 | 144 | // MediaInfo holds information for media uploads. It is intended for use by generated 145 | // code only. 146 | type MediaInfo struct { 147 | // At most one of Media and MediaBuffer will be set. 148 | media io.Reader 149 | buffer *MediaBuffer 150 | singleChunk bool 151 | mType string 152 | size int64 // mediaSize, if known. Used only for calls to progressUpdater_. 153 | progressUpdater googleapi.ProgressUpdater 154 | chunkRetryDeadline time.Duration 155 | } 156 | 157 | // NewInfoFromMedia should be invoked from the Media method of a call. It returns a 158 | // MediaInfo populated with chunk size and content type, and a reader or MediaBuffer 159 | // if needed. 160 | func NewInfoFromMedia(r io.Reader, options []googleapi.MediaOption) *MediaInfo { 161 | mi := &MediaInfo{} 162 | opts := googleapi.ProcessMediaOptions(options) 163 | if !opts.ForceEmptyContentType { 164 | mi.mType = opts.ContentType 165 | if mi.mType == "" { 166 | r, mi.mType = gax.DetermineContentType(r) 167 | } 168 | } 169 | mi.chunkRetryDeadline = opts.ChunkRetryDeadline 170 | mi.media, mi.buffer, mi.singleChunk = PrepareUpload(r, opts.ChunkSize) 171 | return mi 172 | } 173 | 174 | // NewInfoFromResumableMedia should be invoked from the ResumableMedia method of a 175 | // call. It returns a MediaInfo using the given reader, size and media type. 
176 | func NewInfoFromResumableMedia(r io.ReaderAt, size int64, mediaType string) *MediaInfo { 177 | rdr := ReaderAtToReader(r, size) 178 | mType := mediaType 179 | if mType == "" { 180 | rdr, mType = gax.DetermineContentType(rdr) 181 | } 182 | 183 | return &MediaInfo{ 184 | size: size, 185 | mType: mType, 186 | buffer: NewMediaBuffer(rdr, googleapi.DefaultUploadChunkSize), 187 | media: nil, 188 | singleChunk: false, 189 | } 190 | } 191 | 192 | // SetProgressUpdater sets the progress updater for the media info. 193 | func (mi *MediaInfo) SetProgressUpdater(pu googleapi.ProgressUpdater) { 194 | if mi != nil { 195 | mi.progressUpdater = pu 196 | } 197 | } 198 | 199 | // UploadType determines the type of upload: a single request, or a resumable 200 | // series of requests. 201 | func (mi *MediaInfo) UploadType() string { 202 | if mi.singleChunk { 203 | return "multipart" 204 | } 205 | return "resumable" 206 | } 207 | 208 | // UploadRequest sets up an HTTP request for media upload. It adds headers 209 | // as necessary, and returns a replacement for the body and a function for http.Request.GetBody. 210 | func (mi *MediaInfo) UploadRequest(reqHeaders http.Header, body io.Reader) (newBody io.Reader, getBody func() (io.ReadCloser, error), cleanup func()) { 211 | cleanup = func() {} 212 | if mi == nil { 213 | return body, nil, cleanup 214 | } 215 | var media io.Reader 216 | if mi.media != nil { 217 | // This only happens when the caller has turned off chunking. In that 218 | // case, we write all of media in a single non-retryable request. 219 | media = mi.media 220 | } else if mi.singleChunk { 221 | // The data fits in a single chunk, which has now been read into the MediaBuffer. 222 | // We obtain that chunk so we can write it in a single request. The request can 223 | // be retried because the data is stored in the MediaBuffer. 
224 | media, _, _, _ = mi.buffer.Chunk() 225 | } 226 | toCleanup := []io.Closer{} 227 | if media != nil { 228 | fb := readerFunc(body) 229 | fm := readerFunc(media) 230 | combined, ctype := CombineBodyMedia(body, "application/json", media, mi.mType) 231 | toCleanup = append(toCleanup, combined) 232 | if fb != nil && fm != nil { 233 | getBody = func() (io.ReadCloser, error) { 234 | rb := io.NopCloser(fb()) 235 | rm := io.NopCloser(fm()) 236 | var mimeBoundary string 237 | if _, params, err := mime.ParseMediaType(ctype); err == nil { 238 | mimeBoundary = params["boundary"] 239 | } 240 | r, _ := combineBodyMedia(rb, "application/json", rm, mi.mType, mimeBoundary) 241 | toCleanup = append(toCleanup, r) 242 | return r, nil 243 | } 244 | } 245 | reqHeaders.Set("Content-Type", ctype) 246 | body = combined 247 | } 248 | if mi.buffer != nil && mi.mType != "" && !mi.singleChunk { 249 | // This happens when initiating a resumable upload session. 250 | // The initial request contains a JSON body rather than media. 251 | // It can be retried with a getBody function that re-creates the request body. 252 | fb := readerFunc(body) 253 | if fb != nil { 254 | getBody = func() (io.ReadCloser, error) { 255 | rb := io.NopCloser(fb()) 256 | toCleanup = append(toCleanup, rb) 257 | return rb, nil 258 | } 259 | } 260 | reqHeaders.Set("X-Upload-Content-Type", mi.mType) 261 | } 262 | // Ensure that any bodies created in getBody are cleaned up. 263 | cleanup = func() { 264 | for _, closer := range toCleanup { 265 | _ = closer.Close() 266 | } 267 | } 268 | return body, getBody, cleanup 269 | } 270 | 271 | // readerFunc returns a function that always returns an io.Reader that has the same 272 | // contents as r, provided that can be done without consuming r. Otherwise, it 273 | // returns nil. 274 | // See http.NewRequest (in net/http/request.go). 
275 | func readerFunc(r io.Reader) func() io.Reader { 276 | switch r := r.(type) { 277 | case *bytes.Buffer: 278 | buf := r.Bytes() 279 | return func() io.Reader { return bytes.NewReader(buf) } 280 | case *bytes.Reader: 281 | snapshot := *r 282 | return func() io.Reader { r := snapshot; return &r } 283 | case *strings.Reader: 284 | snapshot := *r 285 | return func() io.Reader { r := snapshot; return &r } 286 | default: 287 | return nil 288 | } 289 | } 290 | 291 | // ResumableUpload returns an appropriately configured ResumableUpload value if the 292 | // upload is resumable, or nil otherwise. 293 | func (mi *MediaInfo) ResumableUpload(locURI string) *ResumableUpload { 294 | if mi == nil || mi.singleChunk { 295 | return nil 296 | } 297 | return &ResumableUpload{ 298 | URI: locURI, 299 | Media: mi.buffer, 300 | MediaType: mi.mType, 301 | Callback: func(curr int64) { 302 | if mi.progressUpdater != nil { 303 | mi.progressUpdater(curr, mi.size) 304 | } 305 | }, 306 | ChunkRetryDeadline: mi.chunkRetryDeadline, 307 | } 308 | } 309 | 310 | // SetGetBody sets the GetBody field of req to f. This was once needed 311 | // to gracefully support Go 1.7 and earlier which didn't have that 312 | // field. 313 | // 314 | // Deprecated: the code generator no longer uses this as of 315 | // 2019-02-19. Nothing else should be calling this anyway, but we 316 | // won't delete this immediately; it will be deleted in as early as 6 317 | // months. 318 | func SetGetBody(req *http.Request, f func() (io.ReadCloser, error)) { 319 | req.GetBody = f 320 | } 321 | -------------------------------------------------------------------------------- /genai/internal/gensupport/media_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package gensupport 16 | 17 | import ( 18 | "bytes" 19 | cryptorand "crypto/rand" 20 | "io" 21 | mathrand "math/rand" 22 | "net/http" 23 | "strings" 24 | "testing" 25 | "time" 26 | 27 | "google.golang.org/api/googleapi" 28 | ) 29 | 30 | func TestNewInfoFromMedia(t *testing.T) { 31 | const textType = "text/plain; charset=utf-8" 32 | for _, test := range []struct { 33 | desc string 34 | r io.Reader 35 | opts []googleapi.MediaOption 36 | wantType string 37 | wantMedia, wantBuffer, wantSingleChunk bool 38 | wantDeadline time.Duration 39 | }{ 40 | { 41 | desc: "an empty reader results in a MediaBuffer with a single, empty chunk", 42 | r: new(bytes.Buffer), 43 | opts: nil, 44 | wantType: textType, 45 | wantBuffer: true, 46 | wantSingleChunk: true, 47 | }, 48 | { 49 | desc: "ContentType is observed", 50 | r: new(bytes.Buffer), 51 | opts: []googleapi.MediaOption{googleapi.ContentType("xyz")}, 52 | wantType: "xyz", 53 | wantBuffer: true, 54 | wantSingleChunk: true, 55 | }, 56 | { 57 | desc: "ChunkRetryDeadline is observed", 58 | r: new(bytes.Buffer), 59 | opts: []googleapi.MediaOption{googleapi.ChunkRetryDeadline(time.Second)}, 60 | wantType: textType, 61 | wantBuffer: true, 62 | wantSingleChunk: true, 63 | wantDeadline: time.Second, 64 | }, 65 | { 66 | desc: "chunk size of zero: don't use a MediaBuffer; upload as a single chunk", 67 | r: strings.NewReader("12345"), 68 | opts: []googleapi.MediaOption{googleapi.ChunkSize(0)}, 69 | wantType: textType, 70 | wantMedia: true, 71 | wantSingleChunk: true, 72 | }, 73 | { 74 | 
desc: "chunk size > data size: MediaBuffer with single chunk", 75 | r: strings.NewReader("12345"), 76 | opts: []googleapi.MediaOption{googleapi.ChunkSize(100)}, 77 | wantType: textType, 78 | wantBuffer: true, 79 | wantSingleChunk: true, 80 | }, 81 | { 82 | desc: "chunk size == data size: MediaBuffer with single chunk", 83 | r: &nullReader{googleapi.MinUploadChunkSize}, 84 | opts: []googleapi.MediaOption{googleapi.ChunkSize(1)}, 85 | wantType: "application/octet-stream", 86 | wantBuffer: true, 87 | wantSingleChunk: true, 88 | }, 89 | { 90 | desc: "chunk size < data size: MediaBuffer, not single chunk", 91 | // Note that ChunkSize = 1 is rounded up to googleapi.MinUploadChunkSize. 92 | r: &nullReader{2 * googleapi.MinUploadChunkSize}, 93 | opts: []googleapi.MediaOption{googleapi.ChunkSize(1)}, 94 | wantType: "application/octet-stream", 95 | wantBuffer: true, 96 | wantSingleChunk: false, 97 | }, 98 | } { 99 | 100 | mi := NewInfoFromMedia(test.r, test.opts) 101 | if got, want := mi.mType, test.wantType; got != want { 102 | t.Errorf("%s: type: got %q, want %q", test.desc, got, want) 103 | } 104 | if got, want := (mi.media != nil), test.wantMedia; got != want { 105 | t.Errorf("%s: media non-nil: got %t, want %t", test.desc, got, want) 106 | } 107 | if got, want := (mi.buffer != nil), test.wantBuffer; got != want { 108 | t.Errorf("%s: buffer non-nil: got %t, want %t", test.desc, got, want) 109 | } 110 | if got, want := mi.singleChunk, test.wantSingleChunk; got != want { 111 | t.Errorf("%s: singleChunk: got %t, want %t", test.desc, got, want) 112 | } 113 | if got, want := mi.chunkRetryDeadline, test.wantDeadline; got != want { 114 | t.Errorf("%s: chunkRetryDeadline: got %v, want %v", test.desc, got, want) 115 | } 116 | } 117 | } 118 | 119 | func TestUploadRequest(t *testing.T) { 120 | for _, test := range []struct { 121 | desc string 122 | r io.Reader 123 | chunkSize int 124 | wantContentType string 125 | wantUploadType string 126 | }{ 127 | { 128 | desc: "chunk size of 
zero: don't use a MediaBuffer; upload as a single chunk", 129 | r: strings.NewReader("12345"), 130 | chunkSize: 0, 131 | wantContentType: "multipart/related;", 132 | }, 133 | { 134 | desc: "chunk size > data size: MediaBuffer with single chunk", 135 | r: strings.NewReader("12345"), 136 | chunkSize: 100, 137 | wantContentType: "multipart/related;", 138 | }, 139 | { 140 | desc: "chunk size == data size: MediaBuffer with single chunk", 141 | r: &nullReader{googleapi.MinUploadChunkSize}, 142 | chunkSize: 1, 143 | wantContentType: "multipart/related;", 144 | }, 145 | { 146 | desc: "chunk size < data size: MediaBuffer, not single chunk", 147 | // Note that ChunkSize = 1 is rounded up to googleapi.MinUploadChunkSize. 148 | r: &nullReader{2 * googleapi.MinUploadChunkSize}, 149 | chunkSize: 1, 150 | wantUploadType: "application/octet-stream", 151 | }, 152 | } { 153 | mi := NewInfoFromMedia(test.r, []googleapi.MediaOption{googleapi.ChunkSize(test.chunkSize)}) 154 | h := http.Header{} 155 | mi.UploadRequest(h, new(bytes.Buffer)) 156 | if got, want := h.Get("Content-Type"), test.wantContentType; !strings.HasPrefix(got, want) { 157 | t.Errorf("%s: Content-Type: got %q, want prefix %q", test.desc, got, want) 158 | } 159 | if got, want := h.Get("X-Upload-Content-Type"), test.wantUploadType; got != want { 160 | t.Errorf("%s: X-Upload-Content-Type: got %q, want %q", test.desc, got, want) 161 | } 162 | } 163 | } 164 | 165 | func TestUploadRequestGetBody(t *testing.T) { 166 | // Test that a single chunk results in a getBody function that is non-nil, and 167 | // that produces the same content as the original body. 168 | 169 | // Restore the crypto/rand.Reader mocked out below. 
170 | defer func(old io.Reader) { cryptorand.Reader = old }(cryptorand.Reader) 171 | 172 | for i, test := range []struct { 173 | desc string 174 | r io.Reader 175 | chunkSize int 176 | wantGetBody bool 177 | }{ 178 | { 179 | desc: "chunk size of zero: no getBody", 180 | r: &nullReader{10}, 181 | chunkSize: 0, 182 | wantGetBody: false, 183 | }, 184 | { 185 | desc: "chunk size == data size: 1 chunk, getBody", 186 | r: &nullReader{googleapi.MinUploadChunkSize}, 187 | chunkSize: 1, 188 | wantGetBody: true, 189 | }, 190 | { 191 | desc: "chunk size < data size: MediaBuffer, >1 chunk, getBody", 192 | // Note that ChunkSize = 1 is rounded up to googleapi.MinUploadChunkSize. 193 | r: &nullReader{2 * googleapi.MinUploadChunkSize}, 194 | chunkSize: 1, 195 | wantGetBody: true, 196 | }, 197 | } { 198 | cryptorand.Reader = mathrand.New(mathrand.NewSource(int64(i))) 199 | 200 | mi := NewInfoFromMedia(test.r, []googleapi.MediaOption{googleapi.ChunkSize(test.chunkSize)}) 201 | r, getBody, _ := mi.UploadRequest(http.Header{}, bytes.NewBuffer([]byte("body"))) 202 | if got, want := (getBody != nil), test.wantGetBody; got != want { 203 | t.Errorf("%s: getBody: got %t, want %t", test.desc, got, want) 204 | continue 205 | } 206 | if getBody == nil { 207 | continue 208 | } 209 | want, err := io.ReadAll(r) 210 | if err != nil { 211 | t.Fatal(err) 212 | } 213 | for i := 0; i < 3; i++ { 214 | rc, err := getBody() 215 | if err != nil { 216 | t.Fatal(err) 217 | } 218 | got, err := io.ReadAll(rc) 219 | if err != nil { 220 | t.Fatal(err) 221 | } 222 | if !bytes.Equal(got, want) { 223 | t.Errorf("%s, %d:\ngot:\n%s\nwant:\n%s", test.desc, i, string(got), string(want)) 224 | } 225 | } 226 | } 227 | } 228 | 229 | func TestResumableUpload(t *testing.T) { 230 | for _, test := range []struct { 231 | desc string 232 | r io.Reader 233 | chunkSize int 234 | wantUploadType string 235 | wantResumableUpload bool 236 | chunkRetryDeadline time.Duration 237 | }{ 238 | { 239 | desc: "chunk size of zero: don't 
use a MediaBuffer; upload as a single chunk", 240 | r: strings.NewReader("12345"), 241 | chunkSize: 0, 242 | wantUploadType: "multipart", 243 | wantResumableUpload: false, 244 | }, 245 | { 246 | desc: "chunk size > data size: MediaBuffer with single chunk", 247 | r: strings.NewReader("12345"), 248 | chunkSize: 100, 249 | wantUploadType: "multipart", 250 | wantResumableUpload: false, 251 | }, 252 | { 253 | desc: "chunk size == data size: MediaBuffer with single chunk", 254 | // (Because nullReader returns EOF with the last bytes.) 255 | r: &nullReader{googleapi.MinUploadChunkSize}, 256 | chunkSize: googleapi.MinUploadChunkSize, 257 | wantUploadType: "multipart", 258 | wantResumableUpload: false, 259 | }, 260 | { 261 | desc: "chunk size < data size: MediaBuffer, not single chunk", 262 | // Note that ChunkSize = 1 is rounded up to googleapi.MinUploadChunkSize. 263 | r: &nullReader{2 * googleapi.MinUploadChunkSize}, 264 | chunkSize: 1, 265 | wantUploadType: "resumable", 266 | wantResumableUpload: true, 267 | }, 268 | { 269 | desc: "confirm that ChunkRetryDeadline is carried to ResumableUpload", 270 | r: &nullReader{2 * googleapi.MinUploadChunkSize}, 271 | chunkSize: 1, 272 | wantUploadType: "resumable", 273 | wantResumableUpload: true, 274 | chunkRetryDeadline: 1 * time.Second, 275 | }, 276 | } { 277 | opts := []googleapi.MediaOption{googleapi.ChunkSize(test.chunkSize)} 278 | if test.chunkRetryDeadline != 0 { 279 | opts = append(opts, googleapi.ChunkRetryDeadline(test.chunkRetryDeadline)) 280 | } 281 | mi := NewInfoFromMedia(test.r, opts) 282 | if got, want := mi.UploadType(), test.wantUploadType; got != want { 283 | t.Errorf("%s: upload type: got %q, want %q", test.desc, got, want) 284 | } 285 | if got, want := mi.ResumableUpload("") != nil, test.wantResumableUpload; got != want { 286 | t.Errorf("%s: resumable upload non-nil: got %t, want %t", test.desc, got, want) 287 | } 288 | if test.chunkRetryDeadline != 0 { 289 | if got := mi.ResumableUpload(""); got != nil { 290 
| if got.ChunkRetryDeadline != test.chunkRetryDeadline { 291 | t.Errorf("%s: ChunkRetryDeadline: got %v, want %v", test.desc, got.ChunkRetryDeadline, test.chunkRetryDeadline) 292 | } 293 | } else { 294 | t.Errorf("%s: test case invalid; resumable upload is nil", test.desc) 295 | } 296 | } 297 | } 298 | } 299 | 300 | // A nullReader simulates reading a fixed number of bytes. 301 | type nullReader struct { 302 | remain int 303 | } 304 | 305 | // Read doesn't touch buf, but it does reduce the amount of bytes remaining 306 | // by len(buf). 307 | func (r *nullReader) Read(buf []byte) (int, error) { 308 | n := len(buf) 309 | if r.remain < n { 310 | n = r.remain 311 | } 312 | r.remain -= n 313 | var err error 314 | if r.remain == 0 { 315 | err = io.EOF 316 | } 317 | return n, err 318 | } 319 | -------------------------------------------------------------------------------- /genai/internal/gensupport/params.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package gensupport 16 | 17 | import ( 18 | "net/http" 19 | "net/url" 20 | 21 | "google.golang.org/api/googleapi" 22 | ) 23 | 24 | // URLParams is a simplified replacement for url.Values 25 | // that safely builds up URL parameters for encoding. 26 | type URLParams map[string][]string 27 | 28 | // Get returns the first value for the given key, or "". 
29 | func (u URLParams) Get(key string) string { 30 | vs := u[key] 31 | if len(vs) == 0 { 32 | return "" 33 | } 34 | return vs[0] 35 | } 36 | 37 | // Set sets the key to value. 38 | // It replaces any existing values. 39 | func (u URLParams) Set(key, value string) { 40 | u[key] = []string{value} 41 | } 42 | 43 | // SetMulti sets the key to an array of values. 44 | // It replaces any existing values. 45 | // Note that values must not be modified after calling SetMulti 46 | // so the caller is responsible for making a copy if necessary. 47 | func (u URLParams) SetMulti(key string, values []string) { 48 | u[key] = values 49 | } 50 | 51 | // Encode encodes the values into “URL encoded” form 52 | // ("bar=baz&foo=quux") sorted by key. 53 | func (u URLParams) Encode() string { 54 | return url.Values(u).Encode() 55 | } 56 | 57 | // SetOptions sets the URL params and any additional `CallOption` or 58 | // `MultiCallOption` passed in. 59 | func SetOptions(u URLParams, opts ...googleapi.CallOption) { 60 | for _, o := range opts { 61 | m, ok := o.(googleapi.MultiCallOption) 62 | if ok { 63 | u.SetMulti(m.GetMulti()) 64 | continue 65 | } 66 | u.Set(o.Get()) 67 | } 68 | } 69 | 70 | // SetHeaders sets common headers for all requests. The keyvals header pairs 71 | // should have a corresponding value for every key provided. If there is an odd 72 | // number of keyvals this method will panic. 
73 | func SetHeaders(userAgent, contentType string, userHeaders http.Header, keyvals ...string) http.Header { 74 | reqHeaders := make(http.Header) 75 | reqHeaders.Set("x-goog-api-client", "gl-go/"+GoVersion()+" gdcl/"+"0.179.0") 76 | for i := 0; i < len(keyvals); i = i + 2 { 77 | reqHeaders.Set(keyvals[i], keyvals[i+1]) 78 | } 79 | reqHeaders.Set("User-Agent", userAgent) 80 | if contentType != "" { 81 | reqHeaders.Set("Content-Type", contentType) 82 | } 83 | for k, v := range userHeaders { 84 | reqHeaders[k] = v 85 | } 86 | return reqHeaders 87 | } 88 | -------------------------------------------------------------------------------- /genai/internal/gensupport/params_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package gensupport 16 | 17 | import ( 18 | "net/http" 19 | "testing" 20 | 21 | "google.golang.org/api/googleapi" 22 | ) 23 | 24 | func TestSetOptionsGetMulti(t *testing.T) { 25 | co := googleapi.QueryParameter("key", "foo", "bar") 26 | urlParams := make(URLParams) 27 | SetOptions(urlParams, co) 28 | if got, want := urlParams.Encode(), "key=foo&key=bar"; got != want { 29 | t.Fatalf("URLParams.Encode() = %q, want %q", got, want) 30 | } 31 | } 32 | 33 | func TestSetHeaders(t *testing.T) { 34 | userAgent := "google-api-go-client/123" 35 | contentType := "application/json" 36 | userHeaders := make(http.Header) 37 | userHeaders.Set("baz", "300") 38 | got := SetHeaders(userAgent, contentType, userHeaders, "foo", "100", "bar", "200") 39 | 40 | if len(got) != 6 { 41 | t.Fatalf("SetHeaders() = %q, want len(6)", got) 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /genai/internal/gensupport/resumable.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package gensupport 16 | 17 | import ( 18 | "context" 19 | "errors" 20 | "fmt" 21 | "io" 22 | "net/http" 23 | "strings" 24 | "sync" 25 | "time" 26 | 27 | "github.com/google/uuid" 28 | ) 29 | 30 | // ResumableUpload is used by the generated APIs to provide resumable uploads. 
31 | // It is not used by developers directly. 32 | type ResumableUpload struct { 33 | Client *http.Client 34 | // URI is the resumable resource destination provided by the server after specifying "&uploadType=resumable". 35 | URI string 36 | UserAgent string // User-Agent for header of the request 37 | // Media is the object being uploaded. 38 | Media *MediaBuffer 39 | // MediaType defines the media type, e.g. "image/jpeg". 40 | MediaType string 41 | 42 | mu sync.Mutex // guards progress 43 | progress int64 // number of bytes uploaded so far 44 | 45 | // Callback is an optional function that will be periodically called with the cumulative number of bytes uploaded. 46 | Callback func(int64) 47 | 48 | // Retry optionally configures retries for requests made against the upload. 49 | Retry *RetryConfig 50 | 51 | // ChunkRetryDeadline configures the per-chunk deadline after which no further 52 | // retries should happen. 53 | ChunkRetryDeadline time.Duration 54 | 55 | // Track current request invocation ID and attempt count for retry metrics 56 | // and idempotency headers. 57 | invocationID string 58 | attempts int 59 | } 60 | 61 | // Progress returns the number of bytes uploaded at this point. 62 | func (rx *ResumableUpload) Progress() int64 { 63 | rx.mu.Lock() 64 | defer rx.mu.Unlock() 65 | return rx.progress 66 | } 67 | 68 | // doUploadRequest performs a single HTTP request to upload data. 69 | // off specifies the offset in rx.Media from which data is drawn. 70 | // size is the number of bytes in data. 71 | // final specifies whether data is the final chunk to be uploaded. 
72 | func (rx *ResumableUpload) doUploadRequest(ctx context.Context, data io.Reader, off, size int64, final bool) (*http.Response, error) { 73 | req, err := http.NewRequest("POST", rx.URI, data) 74 | if err != nil { 75 | return nil, err 76 | } 77 | 78 | req.ContentLength = size 79 | var contentRange string 80 | if final { 81 | if size == 0 { 82 | contentRange = fmt.Sprintf("bytes */%v", off) 83 | } else { 84 | contentRange = fmt.Sprintf("bytes %v-%v/%v", off, off+size-1, off+size) 85 | } 86 | } else { 87 | contentRange = fmt.Sprintf("bytes %v-%v/*", off, off+size-1) 88 | } 89 | req.Header.Set("Content-Range", contentRange) 90 | req.Header.Set("Content-Type", rx.MediaType) 91 | req.Header.Set("User-Agent", rx.UserAgent) 92 | 93 | // TODO(b/274504690): Consider dropping gccl-invocation-id key since it 94 | // duplicates the X-Goog-Gcs-Idempotency-Token header (added in v0.115.0). 95 | baseXGoogHeader := "gl-go/" + GoVersion() + " gdcl/" + "0.179.0" 96 | invocationHeader := fmt.Sprintf("gccl-invocation-id/%s gccl-attempt-count/%d", rx.invocationID, rx.attempts) 97 | req.Header.Set("X-Goog-Api-Client", strings.Join([]string{baseXGoogHeader, invocationHeader}, " ")) 98 | 99 | // Set idempotency token header which is used by GCS uploads. 100 | req.Header.Set("X-Goog-Gcs-Idempotency-Token", rx.invocationID) 101 | 102 | // Google's upload endpoint uses status code 308 for a 103 | // different purpose than the "308 Permanent Redirect" 104 | // since-standardized in RFC 7238. Because of the conflict in 105 | // semantics, Google added this new request header which 106 | // causes it to not use "308" and instead reply with 200 OK 107 | // and sets the upload-specific "X-HTTP-Status-Code-Override: 108 | // 308" response header. 
109 | req.Header.Set("X-GUploader-No-308", "yes") 110 | 111 | return SendRequest(ctx, rx.Client, req) 112 | } 113 | 114 | func statusResumeIncomplete(resp *http.Response) bool { 115 | // This is how the server signals "status resume incomplete" 116 | // when X-GUploader-No-308 is set to "yes": 117 | return resp != nil && resp.Header.Get("X-Http-Status-Code-Override") == "308" 118 | } 119 | 120 | // reportProgress calls a user-supplied callback to report upload progress. 121 | // If old==updated, the callback is not called. 122 | func (rx *ResumableUpload) reportProgress(old, updated int64) { 123 | if updated-old == 0 { 124 | return 125 | } 126 | rx.mu.Lock() 127 | rx.progress = updated 128 | rx.mu.Unlock() 129 | if rx.Callback != nil { 130 | rx.Callback(updated) 131 | } 132 | } 133 | 134 | // transferChunk performs a single HTTP request to upload a single chunk from rx.Media. 135 | func (rx *ResumableUpload) transferChunk(ctx context.Context) (*http.Response, error) { 136 | chunk, off, size, err := rx.Media.Chunk() 137 | 138 | done := err == io.EOF 139 | if !done && err != nil { 140 | return nil, err 141 | } 142 | 143 | res, err := rx.doUploadRequest(ctx, chunk, off, int64(size), done) 144 | if err != nil { 145 | return res, err 146 | } 147 | 148 | // We sent "X-GUploader-No-308: yes" (see comment elsewhere in 149 | // this file), so we don't expect to get a 308. 150 | if res.StatusCode == 308 { 151 | return nil, errors.New("unexpected 308 response status code") 152 | } 153 | 154 | if res.StatusCode == http.StatusOK { 155 | rx.reportProgress(off, off+int64(size)) 156 | } 157 | 158 | if statusResumeIncomplete(res) { 159 | rx.Media.Next() 160 | } 161 | return res, nil 162 | } 163 | 164 | // Upload starts the process of a resumable upload with a cancellable context. 165 | // It retries using the provided back off strategy until cancelled or the 166 | // strategy indicates to stop retrying. 
// It is called from the auto-generated API code and is not visible to the user.
// Before sending an HTTP request, Upload calls any registered hook functions,
// and calls the returned functions after the request returns (see send.go).
// rx is private to the auto-generated API code.
// Exactly one of resp or err will be nil. If resp is non-nil, the caller must call resp.Body.Close.
//
// Each chunk is retried independently. Retries stop when errorFunc reports the
// outcome as non-retryable, when ctx is done, or when the per-chunk deadline
// (rx.ChunkRetryDeadline, else defaultRetryDeadline) expires. A chunk answered
// with "resume incomplete" is followed immediately by the next chunk.
func (rx *ResumableUpload) Upload(ctx context.Context) (resp *http.Response, err error) {
	// There are a couple of cases where it's possible for err and resp to both
	// be non-nil. However, we expose a simpler contract to our callers: exactly
	// one of resp and err will be non-nil. This means that any response body
	// must be closed here before returning a non-nil error.
	prepareReturn := func(resp *http.Response, err error) (*http.Response, error) {
		if err != nil {
			if resp != nil && resp.Body != nil {
				resp.Body.Close()
			}
			return nil, err
		}
		// This case is very unlikely but possible only if rx.ChunkRetryDeadline is
		// set to a very small value, in which case no requests will be sent before
		// the deadline. Return an error to avoid causing a panic.
		if resp == nil {
			return nil, fmt.Errorf("upload request to %v not sent, choose larger value for ChunkRetryDeadline", rx.URI)
		}
		return resp, nil
	}
	// Configure retryable error criteria.
	errorFunc := rx.Retry.errorFunc()

	// Configure per-chunk retry deadline.
	var retryDeadline time.Duration
	if rx.ChunkRetryDeadline != 0 {
		retryDeadline = rx.ChunkRetryDeadline
	} else {
		retryDeadline = defaultRetryDeadline
	}

	// Send all chunks.
	for {
		// The first attempt for each chunk starts with a zero pause.
		var pause time.Duration

		// Each chunk gets its own initialized-at-zero backoff and invocation ID.
		bo := rx.Retry.backoff()
		quitAfterTimer := time.NewTimer(retryDeadline)
		rx.attempts = 1
		rx.invocationID = uuid.New().String()

		// Retry loop for a single chunk.
		for {
			// Wait out the backoff pause, unless ctx is done or the chunk
			// deadline fires first. resp and err still hold the previous
			// attempt's result (nil on the first attempt).
			pauseTimer := time.NewTimer(pause)
			select {
			case <-ctx.Done():
				quitAfterTimer.Stop()
				pauseTimer.Stop()
				if err == nil {
					err = ctx.Err()
				}
				return prepareReturn(resp, err)
			case <-pauseTimer.C:
			case <-quitAfterTimer.C:
				pauseTimer.Stop()
				// NOTE(review): if the previous attempt was retried with
				// err == nil (i.e. errorFunc treats that status as retryable;
				// retry.go is not visible here to confirm), its Body was
				// already closed below, yet prepareReturn hands this resp back
				// as a success. Confirm callers tolerate a closed Body here.
				return prepareReturn(resp, err)
			}
			pauseTimer.Stop()

			// Check for context cancellation or timeout once more. If more than one
			// case in the select statement above was satisfied at the same time, Go
			// will choose one arbitrarily.
			// That can cause an operation to go through even if the context was
			// canceled before or the timeout was reached.
			select {
			case <-ctx.Done():
				quitAfterTimer.Stop()
				if err == nil {
					err = ctx.Err()
				}
				return prepareReturn(resp, err)
			case <-quitAfterTimer.C:
				return prepareReturn(resp, err)
			default:
			}

			resp, err = rx.transferChunk(ctx)

			// A nil resp (transport or Chunk error) is reported to errorFunc as
			// status 0.
			var status int
			if resp != nil {
				status = resp.StatusCode
			}

			// Check if we should retry the request.
			if !errorFunc(status, err) {
				quitAfterTimer.Stop()
				break
			}

			// Retrying: count the attempt, compute the next pause, and release
			// the discarded response's body. resp itself stays set, so the
			// early-return paths above may see it.
			rx.attempts++
			pause = bo.Pause()
			if resp != nil && resp.Body != nil {
				resp.Body.Close()
			}
		}

		// If the chunk was uploaded successfully, but there's still
		// more to go, upload the next chunk without any delay.
		if statusResumeIncomplete(resp) {
			resp.Body.Close()
			continue
		}

		return prepareReturn(resp, err)
	}
}

--------------------------------------------------------------------------------
/genai/internal/gensupport/resumable_test.go:
--------------------------------------------------------------------------------
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package gensupport

import (
	"context"
	"fmt"
	"io"
	"net/http"
	"reflect"
	"strings"
	"testing"
	"time"
)

// unexpectedReader fails every Read. It is used as a response body that the
// code under test is expected to close without reading.
type unexpectedReader struct{}

func (unexpectedReader) Read([]byte) (int, error) {
	return 0, fmt.Errorf("unexpected read in test")
}

// event is an expected request/response pair
type event struct {
	// the byte range header that should be present in a request.
	byteRange string
	// the http status code to send in response.
	responseStatus int
}

// interruptibleTransport is configured with a canned set of requests/responses.
// It records the incoming data, unless the corresponding event is configured to return
// http.StatusServiceUnavailable.
type interruptibleTransport struct {
	events []event  // remaining expected requests, consumed front to back
	buf    []byte   // concatenated request bodies that were recorded
	bodies bodyTracker
}

// bodyTracker keeps track of response bodies that have not been closed.
52 | type bodyTracker map[io.ReadCloser]struct{} 53 | 54 | func (bt bodyTracker) Add(body io.ReadCloser) { 55 | bt[body] = struct{}{} 56 | } 57 | 58 | func (bt bodyTracker) Close(body io.ReadCloser) { 59 | delete(bt, body) 60 | } 61 | 62 | type trackingCloser struct { 63 | io.Reader 64 | tracker bodyTracker 65 | } 66 | 67 | func (tc *trackingCloser) Close() error { 68 | tc.tracker.Close(tc) 69 | return nil 70 | } 71 | 72 | func (tc *trackingCloser) Open() { 73 | tc.tracker.Add(tc) 74 | } 75 | 76 | func (t *interruptibleTransport) RoundTrip(req *http.Request) (*http.Response, error) { 77 | if len(t.events) == 0 { 78 | panic("ran out of events, but got a request") 79 | } 80 | ev := t.events[0] 81 | t.events = t.events[1:] 82 | if got, want := req.Header.Get("Content-Range"), ev.byteRange; got != want { 83 | return nil, fmt.Errorf("byte range: got %s; want %s", got, want) 84 | } 85 | 86 | if ev.responseStatus != http.StatusServiceUnavailable { 87 | buf, err := io.ReadAll(req.Body) 88 | if err != nil { 89 | return nil, fmt.Errorf("error reading from request data: %v", err) 90 | } 91 | t.buf = append(t.buf, buf...) 92 | } 93 | 94 | tc := &trackingCloser{unexpectedReader{}, t.bodies} 95 | tc.Open() 96 | h := http.Header{} 97 | status := ev.responseStatus 98 | 99 | // Support "X-GUploader-No-308" like Google: 100 | if status == 308 && req.Header.Get("X-GUploader-No-308") == "yes" { 101 | status = 200 102 | h.Set("X-Http-Status-Code-Override", "308") 103 | } 104 | 105 | res := &http.Response{ 106 | StatusCode: status, 107 | Header: h, 108 | Body: tc, 109 | } 110 | return res, nil 111 | } 112 | 113 | // progressRecorder records updates, and calls f for every invocation of ProgressUpdate. 
114 | type progressRecorder struct { 115 | updates []int64 116 | f func() 117 | } 118 | 119 | func (pr *progressRecorder) ProgressUpdate(current int64) { 120 | pr.updates = append(pr.updates, current) 121 | if pr.f != nil { 122 | pr.f() 123 | } 124 | } 125 | 126 | func TestInterruptedTransferChunks(t *testing.T) { 127 | type testCase struct { 128 | name string 129 | data string 130 | chunkSize int 131 | events []event 132 | wantProgress []int64 133 | } 134 | 135 | for _, tc := range []testCase{ 136 | { 137 | name: "large", 138 | data: strings.Repeat("a", 300), 139 | chunkSize: 90, 140 | events: []event{ 141 | {"bytes 0-89/*", http.StatusServiceUnavailable}, 142 | {"bytes 0-89/*", 308}, 143 | {"bytes 90-179/*", 308}, 144 | {"bytes 180-269/*", http.StatusServiceUnavailable}, 145 | {"bytes 180-269/*", 308}, 146 | {"bytes 270-299/300", 200}, 147 | }, 148 | wantProgress: []int64{90, 180, 270, 300}, 149 | }, 150 | { 151 | name: "small", 152 | data: strings.Repeat("a", 20), 153 | chunkSize: 10, 154 | events: []event{ 155 | {"bytes 0-9/*", http.StatusServiceUnavailable}, 156 | {"bytes 0-9/*", 308}, 157 | {"bytes 10-19/*", http.StatusServiceUnavailable}, 158 | {"bytes 10-19/*", 308}, 159 | // 0 byte final request demands a byte range with leading asterix. 
160 | {"bytes */20", http.StatusServiceUnavailable}, 161 | {"bytes */20", 200}, 162 | }, 163 | wantProgress: []int64{10, 20}, 164 | }, 165 | } { 166 | t.Run(tc.name, func(t *testing.T) { 167 | media := strings.NewReader(tc.data) 168 | 169 | tr := &interruptibleTransport{ 170 | buf: make([]byte, 0, len(tc.data)), 171 | events: tc.events, 172 | bodies: bodyTracker{}, 173 | } 174 | 175 | pr := progressRecorder{} 176 | rx := &ResumableUpload{ 177 | Client: &http.Client{Transport: tr}, 178 | Media: NewMediaBuffer(media, tc.chunkSize), 179 | MediaType: "text/plain", 180 | Callback: pr.ProgressUpdate, 181 | } 182 | 183 | oldBackoff := backoff 184 | backoff = func() Backoff { return new(NoPauseBackoff) } 185 | defer func() { backoff = oldBackoff }() 186 | 187 | res, err := rx.Upload(context.Background()) 188 | if err == nil { 189 | res.Body.Close() 190 | } 191 | if err != nil || res == nil || res.StatusCode != http.StatusOK { 192 | if res == nil { 193 | t.Fatalf("Upload not successful, res=nil: %v", err) 194 | } else { 195 | t.Fatalf("Upload not successful, statusCode=%v, err=%v", res.StatusCode, err) 196 | } 197 | } 198 | if !reflect.DeepEqual(tr.buf, []byte(tc.data)) { 199 | t.Fatalf("transferred contents:\ngot %s\nwant %s", tr.buf, tc.data) 200 | } 201 | 202 | if !reflect.DeepEqual(pr.updates, tc.wantProgress) { 203 | t.Fatalf("progress updates: got %v, want %v", pr.updates, tc.wantProgress) 204 | } 205 | 206 | if len(tr.events) > 0 { 207 | t.Fatalf("did not observe all expected events. 
leftover events: %v", tr.events) 208 | } 209 | if len(tr.bodies) > 0 { 210 | t.Errorf("unclosed request bodies: %v", tr.bodies) 211 | } 212 | }) 213 | } 214 | } 215 | 216 | func TestCancelUploadFast(t *testing.T) { 217 | const ( 218 | chunkSize = 90 219 | mediaSize = 300 220 | ) 221 | media := strings.NewReader(strings.Repeat("a", mediaSize)) 222 | 223 | tr := &interruptibleTransport{ 224 | buf: make([]byte, 0, mediaSize), 225 | } 226 | 227 | pr := progressRecorder{} 228 | rx := &ResumableUpload{ 229 | Client: &http.Client{Transport: tr}, 230 | Media: NewMediaBuffer(media, chunkSize), 231 | MediaType: "text/plain", 232 | Callback: pr.ProgressUpdate, 233 | } 234 | 235 | oldBackoff := backoff 236 | backoff = func() Backoff { return new(NoPauseBackoff) } 237 | defer func() { backoff = oldBackoff }() 238 | 239 | ctx, cancelFunc := context.WithCancel(context.Background()) 240 | cancelFunc() // stop the upload that hasn't started yet 241 | res, err := rx.Upload(ctx) 242 | if err != context.Canceled { 243 | t.Fatalf("Upload err: got: %v; want: context cancelled", err) 244 | } 245 | if res != nil { 246 | t.Fatalf("Upload result: got: %v; want: nil", res) 247 | } 248 | if pr.updates != nil { 249 | t.Errorf("progress updates: got %v; want: nil", pr.updates) 250 | } 251 | } 252 | 253 | func TestCancelUploadBasic(t *testing.T) { 254 | const ( 255 | chunkSize = 90 256 | mediaSize = 300 257 | ) 258 | media := strings.NewReader(strings.Repeat("a", mediaSize)) 259 | 260 | tr := &interruptibleTransport{ 261 | buf: make([]byte, 0, mediaSize), 262 | events: []event{ 263 | {"bytes 0-89/*", http.StatusServiceUnavailable}, 264 | {"bytes 0-89/*", 308}, 265 | {"bytes 90-179/*", 308}, 266 | {"bytes 180-269/*", 308}, // Upload should be cancelled before this event. 
267 | }, 268 | bodies: bodyTracker{}, 269 | } 270 | 271 | ctx, cancelFunc := context.WithCancel(context.Background()) 272 | numUpdates := 0 273 | 274 | pr := progressRecorder{f: func() { 275 | numUpdates++ 276 | if numUpdates >= 2 { 277 | cancelFunc() 278 | } 279 | }} 280 | 281 | rx := &ResumableUpload{ 282 | Client: &http.Client{Transport: tr}, 283 | Media: NewMediaBuffer(media, chunkSize), 284 | MediaType: "text/plain", 285 | Callback: pr.ProgressUpdate, 286 | } 287 | 288 | oldBackoff := backoff 289 | backoff = func() Backoff { return new(PauseOneSecond) } 290 | defer func() { backoff = oldBackoff }() 291 | 292 | res, err := rx.Upload(ctx) 293 | if err != context.Canceled { 294 | t.Fatalf("Upload err: got: %v; want: context cancelled", err) 295 | } 296 | if res != nil { 297 | t.Fatalf("Upload result: got: %v; want: nil", res) 298 | } 299 | if got, want := tr.buf, []byte(strings.Repeat("a", chunkSize*2)); !reflect.DeepEqual(got, want) { 300 | t.Fatalf("transferred contents:\ngot %s\nwant %s", got, want) 301 | } 302 | if got, want := pr.updates, []int64{chunkSize, chunkSize * 2}; !reflect.DeepEqual(got, want) { 303 | t.Fatalf("progress updates: got %v; want: %v", got, want) 304 | } 305 | if len(tr.bodies) > 0 { 306 | t.Errorf("unclosed request bodies: %v", tr.bodies) 307 | } 308 | } 309 | 310 | func TestRetry_EachChunkHasItsOwnRetryDeadline(t *testing.T) { 311 | const ( 312 | chunkSize = 90 313 | mediaSize = 300 314 | ) 315 | media := strings.NewReader(strings.Repeat("a", mediaSize)) 316 | 317 | // This transport returns multiple errors on both the first chunk and third 318 | // chunk of the upload. If the timeout were not reset between chunks, the 319 | // errors on the third chunk would not retry and cause a failure. 
320 | tr := &interruptibleTransport{ 321 | buf: make([]byte, 0, mediaSize), 322 | events: []event{ 323 | {"bytes 0-89/*", http.StatusServiceUnavailable}, 324 | // cum: 1s sleep 325 | {"bytes 0-89/*", http.StatusServiceUnavailable}, 326 | // cum: 2s sleep 327 | {"bytes 0-89/*", http.StatusServiceUnavailable}, 328 | // cum: 3s sleep 329 | {"bytes 0-89/*", http.StatusServiceUnavailable}, 330 | // cum: 4s sleep 331 | {"bytes 0-89/*", 308}, 332 | // cum: 1s sleep <-- resets because it's a new chunk 333 | {"bytes 90-179/*", 308}, 334 | // cum: 1s sleep <-- resets because it's a new chunk 335 | {"bytes 180-269/*", http.StatusServiceUnavailable}, 336 | // cum: 1s sleep on later chunk 337 | {"bytes 180-269/*", http.StatusServiceUnavailable}, 338 | // cum: 2s sleep on later chunk 339 | {"bytes 180-269/*", 308}, 340 | // cum: 3s sleep <-- resets because it's a new chunk 341 | {"bytes 270-299/300", 200}, 342 | }, 343 | bodies: bodyTracker{}, 344 | } 345 | 346 | rx := &ResumableUpload{ 347 | Client: &http.Client{Transport: tr}, 348 | Media: NewMediaBuffer(media, chunkSize), 349 | MediaType: "text/plain", 350 | Callback: func(int64) {}, 351 | ChunkRetryDeadline: 5 * time.Second, 352 | } 353 | 354 | oldBackoff := backoff 355 | backoff = func() Backoff { return new(PauseOneSecond) } 356 | defer func() { backoff = oldBackoff }() 357 | 358 | resCode := make(chan int, 1) 359 | go func() { 360 | resp, err := rx.Upload(context.Background()) 361 | if err != nil { 362 | t.Error(err) 363 | return 364 | } 365 | resCode <- resp.StatusCode 366 | }() 367 | 368 | select { 369 | case <-time.After(15 * time.Second): 370 | t.Fatal("timed out waiting for Upload to complete") 371 | case got := <-resCode: 372 | if want := http.StatusOK; got != want { 373 | t.Fatalf("want %d, got %d", want, got) 374 | } 375 | } 376 | } 377 | -------------------------------------------------------------------------------- /genai/internal/gensupport/retry.go: 
-------------------------------------------------------------------------------- 1 | // Copyright 2024 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package gensupport 16 | 17 | import ( 18 | "errors" 19 | "io" 20 | "net" 21 | "strings" 22 | "time" 23 | 24 | "github.com/googleapis/gax-go/v2" 25 | "google.golang.org/api/googleapi" 26 | ) 27 | 28 | // Backoff is an interface around gax.Backoff's Pause method, allowing tests to provide their 29 | // own implementation. 30 | type Backoff interface { 31 | Pause() time.Duration 32 | } 33 | 34 | // These are declared as global variables so that tests can overwrite them. 35 | var ( 36 | // Default per-chunk deadline for resumable uploads. 37 | defaultRetryDeadline = 32 * time.Second 38 | // Default backoff timer. 39 | backoff = func() Backoff { 40 | return &gax.Backoff{Initial: 100 * time.Millisecond} 41 | } 42 | // syscallRetryable is a platform-specific hook, specified in retryable_linux.go 43 | syscallRetryable func(error) bool = func(err error) bool { return false } 44 | ) 45 | 46 | const ( 47 | // statusTooManyRequests is returned by the storage API if the 48 | // per-project limits have been temporarily exceeded. The request 49 | // should be retried. 
50 | // https://cloud.google.com/storage/docs/json_api/v1/status-codes#standardcodes 51 | statusTooManyRequests = 429 52 | 53 | // statusRequestTimeout is returned by the storage API if the 54 | // upload connection was broken. The request should be retried. 55 | statusRequestTimeout = 408 56 | ) 57 | 58 | // shouldRetry indicates whether an error is retryable for the purposes of this 59 | // package, unless a ShouldRetry func is specified by the RetryConfig instead. 60 | // It follows guidance from 61 | // https://cloud.google.com/storage/docs/exponential-backoff . 62 | func shouldRetry(status int, err error) bool { 63 | if 500 <= status && status <= 599 { 64 | return true 65 | } 66 | if status == statusTooManyRequests || status == statusRequestTimeout { 67 | return true 68 | } 69 | if err == io.ErrUnexpectedEOF { 70 | return true 71 | } 72 | // Transient network errors should be retried. 73 | if syscallRetryable(err) { 74 | return true 75 | } 76 | if err, ok := err.(interface{ Temporary() bool }); ok { 77 | if err.Temporary() { 78 | return true 79 | } 80 | } 81 | var opErr *net.OpError 82 | if errors.As(err, &opErr) { 83 | if strings.Contains(opErr.Error(), "use of closed network connection") { 84 | // TODO: check against net.ErrClosed (go 1.16+) instead of string 85 | return true 86 | } 87 | } 88 | 89 | // If Go 1.13 error unwrapping is available, use this to examine wrapped 90 | // errors. 91 | if err, ok := err.(interface{ Unwrap() error }); ok { 92 | return shouldRetry(status, err.Unwrap()) 93 | } 94 | return false 95 | } 96 | 97 | // RetryConfig allows configuration of backoff timing and retryable errors. 98 | type RetryConfig struct { 99 | Backoff *gax.Backoff 100 | ShouldRetry func(err error) bool 101 | } 102 | 103 | // Get a new backoff object based on the configured values. 
104 | func (r *RetryConfig) backoff() Backoff { 105 | if r == nil || r.Backoff == nil { 106 | return backoff() 107 | } 108 | return &gax.Backoff{ 109 | Initial: r.Backoff.Initial, 110 | Max: r.Backoff.Max, 111 | Multiplier: r.Backoff.Multiplier, 112 | } 113 | } 114 | 115 | // This is kind of hacky; it is necessary because ShouldRetry expects to 116 | // handle HTTP errors via googleapi.Error, but the error has not yet been 117 | // wrapped with a googleapi.Error at this layer, and the ErrorFunc type 118 | // in the manual layer does not pass in a status explicitly as it does 119 | // here. So, we must wrap error status codes in a googleapi.Error so that 120 | // ShouldRetry can parse this correctly. 121 | func (r *RetryConfig) errorFunc() func(status int, err error) bool { 122 | if r == nil || r.ShouldRetry == nil { 123 | return shouldRetry 124 | } 125 | return func(status int, err error) bool { 126 | if status >= 400 { 127 | return r.ShouldRetry(&googleapi.Error{Code: status}) 128 | } 129 | return r.ShouldRetry(err) 130 | } 131 | } 132 | -------------------------------------------------------------------------------- /genai/internal/gensupport/retryable_linux.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | //go:build linux 16 | // +build linux 17 | 18 | package gensupport 19 | 20 | import "syscall" 21 | 22 | func init() { 23 | // Initialize syscallRetryable to return true on transient socket-level 24 | // errors. These errors are specific to Linux. 25 | syscallRetryable = func(err error) bool { return err == syscall.ECONNRESET || err == syscall.ECONNREFUSED } 26 | } 27 | -------------------------------------------------------------------------------- /genai/internal/gensupport/send.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package gensupport 16 | 17 | import ( 18 | "context" 19 | "encoding/json" 20 | "errors" 21 | "fmt" 22 | "net/http" 23 | "strings" 24 | "time" 25 | 26 | "github.com/google/uuid" 27 | "github.com/googleapis/gax-go/v2" 28 | "github.com/googleapis/gax-go/v2/callctx" 29 | ) 30 | 31 | // Use this error type to return an error which allows introspection of both 32 | // the context error and the error from the service. 
33 | type wrappedCallErr struct { 34 | ctxErr error 35 | wrappedErr error 36 | } 37 | 38 | func (e wrappedCallErr) Error() string { 39 | return fmt.Sprintf("retry failed with %v; last error: %v", e.ctxErr, e.wrappedErr) 40 | } 41 | 42 | func (e wrappedCallErr) Unwrap() error { 43 | return e.wrappedErr 44 | } 45 | 46 | // Is allows errors.Is to match the error from the call as well as context 47 | // sentinel errors. 48 | func (e wrappedCallErr) Is(target error) bool { 49 | return errors.Is(e.ctxErr, target) || errors.Is(e.wrappedErr, target) 50 | } 51 | 52 | // SendRequest sends a single HTTP request using the given client. 53 | // If ctx is non-nil, it calls all hooks, then sends the request with 54 | // req.WithContext, then calls any functions returned by the hooks in 55 | // reverse order. 56 | func SendRequest(ctx context.Context, client *http.Client, req *http.Request) (*http.Response, error) { 57 | // Add headers set in context metadata. 58 | if ctx != nil { 59 | headers := callctx.HeadersFromContext(ctx) 60 | for k, vals := range headers { 61 | for _, v := range vals { 62 | req.Header.Add(k, v) 63 | } 64 | } 65 | } 66 | 67 | // Disallow Accept-Encoding because it interferes with the automatic gzip handling 68 | // done by the default http.Transport. See https://github.com/google/google-api-go-client/issues/219. 69 | if _, ok := req.Header["Accept-Encoding"]; ok { 70 | return nil, errors.New("google api: custom Accept-Encoding headers not allowed") 71 | } 72 | if ctx == nil { 73 | return client.Do(req) 74 | } 75 | return send(ctx, client, req) 76 | } 77 | 78 | func send(ctx context.Context, client *http.Client, req *http.Request) (*http.Response, error) { 79 | if client == nil { 80 | client = http.DefaultClient 81 | } 82 | resp, err := client.Do(req.WithContext(ctx)) 83 | // If we got an error, and the context has been canceled, 84 | // the context's error is probably more useful. 
85 | if err != nil { 86 | select { 87 | case <-ctx.Done(): 88 | err = ctx.Err() 89 | default: 90 | } 91 | } 92 | return resp, err 93 | } 94 | 95 | // SendRequestWithRetry sends a single HTTP request using the given client, 96 | // with retries if a retryable error is returned. 97 | // If ctx is non-nil, it calls all hooks, then sends the request with 98 | // req.WithContext, then calls any functions returned by the hooks in 99 | // reverse order. 100 | func SendRequestWithRetry(ctx context.Context, client *http.Client, req *http.Request, retry *RetryConfig) (*http.Response, error) { 101 | // Add headers set in context metadata. 102 | if ctx != nil { 103 | headers := callctx.HeadersFromContext(ctx) 104 | for k, vals := range headers { 105 | for _, v := range vals { 106 | req.Header.Add(k, v) 107 | } 108 | } 109 | } 110 | 111 | // Disallow Accept-Encoding because it interferes with the automatic gzip handling 112 | // done by the default http.Transport. See https://github.com/google/google-api-go-client/issues/219. 113 | if _, ok := req.Header["Accept-Encoding"]; ok { 114 | return nil, errors.New("google api: custom Accept-Encoding headers not allowed") 115 | } 116 | if ctx == nil { 117 | return client.Do(req) 118 | } 119 | return sendAndRetry(ctx, client, req, retry) 120 | } 121 | 122 | func sendAndRetry(ctx context.Context, client *http.Client, req *http.Request, retry *RetryConfig) (*http.Response, error) { 123 | if client == nil { 124 | client = http.DefaultClient 125 | } 126 | 127 | var resp *http.Response 128 | var err error 129 | attempts := 1 130 | invocationID := uuid.New().String() 131 | baseXGoogHeader := req.Header.Get("X-Goog-Api-Client") 132 | 133 | // Loop to retry the request, up to the context deadline. 
134 | var pause time.Duration 135 | var bo Backoff 136 | if retry != nil && retry.Backoff != nil { 137 | bo = &gax.Backoff{ 138 | Initial: retry.Backoff.Initial, 139 | Max: retry.Backoff.Max, 140 | Multiplier: retry.Backoff.Multiplier, 141 | } 142 | } else { 143 | bo = backoff() 144 | } 145 | 146 | errorFunc := retry.errorFunc() 147 | 148 | for { 149 | t := time.NewTimer(pause) 150 | select { 151 | case <-ctx.Done(): 152 | t.Stop() 153 | // If we got an error and the context has been canceled, return an error acknowledging 154 | // both the context cancelation and the service error. 155 | if err != nil { 156 | return resp, wrappedCallErr{ctx.Err(), err} 157 | } 158 | return resp, ctx.Err() 159 | case <-t.C: 160 | } 161 | 162 | if ctx.Err() != nil { 163 | // Check for context cancellation once more. If more than one case in a 164 | // select is satisfied at the same time, Go will choose one arbitrarily. 165 | // That can cause an operation to go through even if the context was 166 | // canceled before. 167 | if err != nil { 168 | return resp, wrappedCallErr{ctx.Err(), err} 169 | } 170 | return resp, ctx.Err() 171 | } 172 | 173 | // Set retry metrics and idempotency headers for GCS. 174 | // TODO(b/274504690): Consider dropping gccl-invocation-id key since it 175 | // duplicates the X-Goog-Gcs-Idempotency-Token header (added in v0.115.0). 176 | invocationHeader := fmt.Sprintf("gccl-invocation-id/%s gccl-attempt-count/%d", invocationID, attempts) 177 | xGoogHeader := strings.Join([]string{invocationHeader, baseXGoogHeader}, " ") 178 | req.Header.Set("X-Goog-Api-Client", xGoogHeader) 179 | req.Header.Set("X-Goog-Gcs-Idempotency-Token", invocationID) 180 | 181 | resp, err = client.Do(req.WithContext(ctx)) 182 | 183 | var status int 184 | if resp != nil { 185 | status = resp.StatusCode 186 | } 187 | 188 | // Check if we can retry the request. 
A retry can only be done if the error 189 | // is retryable and the request body can be re-created using GetBody (this 190 | // will not be possible if the body was unbuffered). 191 | if req.GetBody == nil || !errorFunc(status, err) { 192 | break 193 | } 194 | attempts++ 195 | var errBody error 196 | req.Body, errBody = req.GetBody() 197 | if errBody != nil { 198 | break 199 | } 200 | 201 | pause = bo.Pause() 202 | if resp != nil && resp.Body != nil { 203 | resp.Body.Close() 204 | } 205 | } 206 | return resp, err 207 | } 208 | 209 | // DecodeResponse decodes the body of res into target. If there is no body, 210 | // target is unchanged. 211 | func DecodeResponse(target interface{}, res *http.Response) error { 212 | if res.StatusCode == http.StatusNoContent { 213 | return nil 214 | } 215 | return json.NewDecoder(res.Body).Decode(target) 216 | } 217 | -------------------------------------------------------------------------------- /genai/internal/gensupport/send_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package gensupport 16 | 17 | import ( 18 | "context" 19 | "errors" 20 | "fmt" 21 | "net/http" 22 | "testing" 23 | 24 | "github.com/google/go-cmp/cmp" 25 | "github.com/googleapis/gax-go/v2/callctx" 26 | ) 27 | 28 | func TestSendRequest(t *testing.T) { 29 | // Setting Accept-Encoding should give an error immediately. 30 | req, _ := http.NewRequest("GET", "url", nil) 31 | req.Header.Set("Accept-Encoding", "") 32 | _, err := SendRequest(context.Background(), nil, req) 33 | if err == nil { 34 | t.Error("got nil, want error") 35 | } 36 | } 37 | 38 | func TestSendRequestWithRetry(t *testing.T) { 39 | // Setting Accept-Encoding should give an error immediately. 40 | req, _ := http.NewRequest("GET", "url", nil) 41 | req.Header.Set("Accept-Encoding", "") 42 | _, err := SendRequestWithRetry(context.Background(), nil, req, nil) 43 | if err == nil { 44 | t.Error("got nil, want error") 45 | } 46 | } 47 | 48 | type headerRoundTripper struct { 49 | wantHeader http.Header 50 | } 51 | 52 | func (rt *headerRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { 53 | // Ignore x-goog headers sent by SendRequestWithRetry 54 | r.Header.Del("X-Goog-Api-Client") 55 | r.Header.Del("X-Goog-Gcs-Idempotency-Token") 56 | if diff := cmp.Diff(r.Header, rt.wantHeader); diff != "" { 57 | return nil, fmt.Errorf("headers don't match: %v", diff) 58 | } 59 | return &http.Response{StatusCode: 200}, nil 60 | } 61 | 62 | // Ensure that headers set via the context are passed through to the request as expected. 
63 | func TestSendRequestHeader(t *testing.T) { 64 | ctx := context.Background() 65 | ctx = callctx.SetHeaders(ctx, "foo", "100", "bar", "200") 66 | client := http.Client{ 67 | Transport: &headerRoundTripper{ 68 | wantHeader: map[string][]string{"Foo": {"100"}, "Bar": {"200"}}, 69 | }, 70 | } 71 | req, _ := http.NewRequest("GET", "url", nil) 72 | if _, err := SendRequest(ctx, &client, req); err != nil { 73 | t.Errorf("SendRequest: %v", err) 74 | } 75 | req2, _ := http.NewRequest("GET", "url", nil) 76 | if _, err := SendRequestWithRetry(ctx, &client, req2, nil); err != nil { 77 | t.Errorf("SendRequest: %v", err) 78 | } 79 | } 80 | 81 | type brokenRoundTripper struct{} 82 | 83 | func (t *brokenRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { 84 | return nil, errors.New("this should not happen") 85 | } 86 | 87 | func TestCanceledContextDoesNotPerformRequest(t *testing.T) { 88 | client := http.Client{ 89 | Transport: &brokenRoundTripper{}, 90 | } 91 | for i := 0; i < 1000; i++ { 92 | req, _ := http.NewRequest("GET", "url", nil) 93 | ctx, cancel := context.WithCancel(context.Background()) 94 | cancel() 95 | _, err := SendRequestWithRetry(ctx, &client, req, nil) 96 | if !errors.Is(err, context.Canceled) { 97 | t.Fatalf("got %v, want %v", err, context.Canceled) 98 | } 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /genai/internal/gensupport/util_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package gensupport 16 | 17 | import ( 18 | "io" 19 | "time" 20 | ) 21 | 22 | // errReader reads out of a buffer until it is empty, then returns the specified error. 23 | type errReader struct { 24 | buf []byte 25 | err error 26 | } 27 | 28 | func (er *errReader) Read(p []byte) (int, error) { 29 | if len(er.buf) == 0 { 30 | if er.err == nil { 31 | return 0, io.EOF 32 | } 33 | return 0, er.err 34 | } 35 | n := copy(p, er.buf) 36 | er.buf = er.buf[n:] 37 | return n, nil 38 | } 39 | 40 | // NoPauseBackoff implements backoff with infinite 0-length pauses. 41 | type NoPauseBackoff struct{} 42 | 43 | func (bo *NoPauseBackoff) Pause() time.Duration { return 0 } 44 | 45 | // PauseOneSecond implements backoff with infinite 1s pauses. 46 | type PauseOneSecond struct{} 47 | 48 | func (bo *PauseOneSecond) Pause() time.Duration { return time.Second } 49 | -------------------------------------------------------------------------------- /genai/internal/gensupport/version.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package gensupport 16 | 17 | import ( 18 | "runtime" 19 | "strings" 20 | "unicode" 21 | ) 22 | 23 | // GoVersion returns the Go runtime version. The returned string 24 | // has no whitespace. 25 | func GoVersion() string { 26 | return goVersion 27 | } 28 | 29 | var goVersion = goVer(runtime.Version()) 30 | 31 | const develPrefix = "devel +" 32 | 33 | func goVer(s string) string { 34 | if strings.HasPrefix(s, develPrefix) { 35 | s = s[len(develPrefix):] 36 | if p := strings.IndexFunc(s, unicode.IsSpace); p >= 0 { 37 | s = s[:p] 38 | } 39 | return s 40 | } 41 | 42 | if strings.HasPrefix(s, "go1") { 43 | s = s[2:] 44 | var prerelease string 45 | if p := strings.IndexFunc(s, notSemverRune); p >= 0 { 46 | s, prerelease = s[:p], s[p:] 47 | } 48 | if strings.HasSuffix(s, ".") { 49 | s += "0" 50 | } else if strings.Count(s, ".") < 2 { 51 | s += ".0" 52 | } 53 | if prerelease != "" { 54 | s += "-" + prerelease 55 | } 56 | return s 57 | } 58 | return "" 59 | } 60 | 61 | func notSemverRune(r rune) bool { 62 | return !strings.ContainsRune("0123456789.", r) 63 | } 64 | -------------------------------------------------------------------------------- /genai/internal/gensupport/version_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package gensupport 16 | 17 | import "testing" 18 | 19 | func TestGoVer(t *testing.T) { 20 | for _, tst := range []struct { 21 | in, want string 22 | }{ 23 | {"go1.8", "1.8.0"}, 24 | {"go1.7.3", "1.7.3"}, 25 | {"go1.8.typealias", "1.8.0-typealias"}, 26 | {"go1.8beta1", "1.8.0-beta1"}, 27 | {"go1.8rc2", "1.8.0-rc2"}, 28 | {"devel +824f981dd4b7 Tue Apr 29 21:41:54 2014 -0400", "824f981dd4b7"}, 29 | {"foo bar zipzap", ""}, 30 | } { 31 | if got := goVer(tst.in); got != tst.want { 32 | t.Errorf("goVer(%q) = %q, want %q", tst.in, got, tst.want) 33 | } 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /genai/internal/testhelpers/testhelpers.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package testhelpers 16 | 17 | import ( 18 | "log" 19 | "os" 20 | "path/filepath" 21 | ) 22 | 23 | // ModuleRootDir finds the location of the root directory of this repository: 24 | // the nearest directory, starting at the current working directory and walking 25 | // up through its parents, that contains a go.mod file. It calls log.Fatal if the 26 | // working directory cannot be determined or no such directory is found. 27 | // Note: typically Go tests can assume a fixed directory location, but some 28 | // tests/examples in this repository get copied and can run from multiple 29 | // directories, requiring the use of this function. 30 | func ModuleRootDir() string { 31 | dir, err := os.Getwd() 32 | if err != nil { 33 | log.Fatalf("os.Getwd: %v", err) 34 | } 35 | 36 | for { 37 | if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil { 38 | return dir 39 | } 40 | 41 | parentDir := filepath.Dir(dir) 42 | if parentDir == dir { 43 | log.Fatal("unable to find go.mod in the working directory or any of its parents") 44 | } 45 | dir = parentDir 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /genai/internal/version.go: -------------------------------------------------------------------------------- 1 | // Copyright 2022 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package internal 16 | 17 | // Version is the current tagged release of the library.
18 | const Version = "0.19.0" 19 | -------------------------------------------------------------------------------- /genai/license.txt: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /genai/list_models.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
14 | 15 | package genai 16 | 17 | import ( 18 | "context" 19 | 20 | gl "cloud.google.com/go/ai/generativelanguage/apiv1beta" 21 | pb "cloud.google.com/go/ai/generativelanguage/apiv1beta/generativelanguagepb" 22 | 23 | "google.golang.org/api/iterator" 24 | ) 25 | 26 | // ListModels returns an iterator over the models listed by the service. 27 | // It sends an empty pb.ListModelsRequest through the underlying model client. 28 | func (c *Client) ListModels(ctx context.Context) *ModelInfoIterator { 29 | return &ModelInfoIterator{ 30 | it: c.mc.ListModels(ctx, &pb.ListModelsRequest{}), 31 | } 32 | } 33 | 34 | // A ModelInfoIterator iterates over Models. 35 | type ModelInfoIterator struct { 36 | it *gl.ModelIterator 37 | } 38 | 39 | // Next returns the next result. Its second return value is iterator.Done if there are no more 40 | // results. Once Next returns Done, all subsequent calls will return Done. 41 | func (it *ModelInfoIterator) Next() (*ModelInfo, error) { 42 | m, err := it.it.Next() 43 | if err != nil { 44 | return nil, err 45 | } 46 | return (ModelInfo{}).fromProto(m), nil 47 | } 48 | 49 | // PageInfo supports pagination. See the google.golang.org/api/iterator package for details. 50 | func (it *ModelInfoIterator) PageInfo() *iterator.PageInfo { 51 | return it.it.PageInfo() 52 | } 53 | -------------------------------------------------------------------------------- /genai/option.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
14 | 15 | package genai 16 | 17 | import ( 18 | "google.golang.org/api/option" 19 | "google.golang.org/api/option/internaloption" 20 | ) 21 | 22 | // WithClientInfo sets request information identifying the 23 | // product that is calling this client. 24 | func WithClientInfo(key, value string) option.ClientOption { 25 | return &clientInfo{key: key, value: value} 26 | } 27 | 28 | // clientInfo is the option.ClientOption returned by WithClientInfo; it holds 29 | // the key/value pair. Embedding internaloption.EmbeddableAdapter is what lets 30 | // this package-defined type be used as an option.ClientOption. 31 | type clientInfo struct { 32 | internaloption.EmbeddableAdapter 33 | key, value string 34 | } 35 | 36 | // optionOfType returns the first value of opts that has type T, 37 | // along with true. If there is no option of that type, it returns 38 | // the zero value for T and false. 39 | func optionOfType[T option.ClientOption](opts []option.ClientOption) (T, bool) { 40 | for _, opt := range opts { 41 | if opt, ok := opt.(T); ok { 42 | return opt, true 43 | } 44 | } 45 | var z T 46 | return z, false 47 | } 48 | -------------------------------------------------------------------------------- /genai/testdata/1251.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/generative-ai-go/186f0111b75297bbb4edb3d9b93c21fc59c9b7e1/genai/testdata/1251.txt -------------------------------------------------------------------------------- /genai/testdata/Cajun_instruments.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/generative-ai-go/186f0111b75297bbb4edb3d9b93c21fc59c9b7e1/genai/testdata/Cajun_instruments.jpg -------------------------------------------------------------------------------- /genai/testdata/badencoding.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/generative-ai-go/186f0111b75297bbb4edb3d9b93c21fc59c9b7e1/genai/testdata/badencoding.txt -------------------------------------------------------------------------------- /genai/testdata/earth.mp4:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/generative-ai-go/186f0111b75297bbb4edb3d9b93c21fc59c9b7e1/genai/testdata/earth.mp4 -------------------------------------------------------------------------------- /genai/testdata/organ.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/generative-ai-go/186f0111b75297bbb4edb3d9b93c21fc59c9b7e1/genai/testdata/organ.jpg -------------------------------------------------------------------------------- /genai/testdata/personWorkingOnComputer.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/generative-ai-go/186f0111b75297bbb4edb3d9b93c21fc59c9b7e1/genai/testdata/personWorkingOnComputer.jpg -------------------------------------------------------------------------------- /genai/testdata/poem.txt: -------------------------------------------------------------------------------- 1 | When daisies pied, and violets blue, 2 | And lady-smocks all silver-white, 3 | And cuckoo-buds of yellow hue 4 | Do paint the meadows with delight 5 | 6 | -------------------------------------------------------------------------------- /genai/testdata/sample.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/generative-ai-go/186f0111b75297bbb4edb3d9b93c21fc59c9b7e1/genai/testdata/sample.mp3 -------------------------------------------------------------------------------- /genai/testdata/test.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/generative-ai-go/186f0111b75297bbb4edb3d9b93c21fc59c9b7e1/genai/testdata/test.pdf -------------------------------------------------------------------------------- /go.mod: 
-------------------------------------------------------------------------------- 1 | module github.com/google/generative-ai-go 2 | 3 | go 1.21 4 | 5 | require ( 6 | cloud.google.com/go/ai v0.8.0 7 | github.com/google/go-cmp v0.6.0 8 | github.com/google/uuid v1.6.0 9 | github.com/googleapis/gax-go/v2 v2.12.5 10 | google.golang.org/api v0.186.0 11 | google.golang.org/genproto/googleapis/rpc v0.0.0-20240617180043-68d350f18fd4 12 | google.golang.org/grpc v1.64.1 13 | google.golang.org/protobuf v1.34.2 14 | ) 15 | 16 | require ( 17 | cloud.google.com/go v0.115.0 // indirect 18 | cloud.google.com/go/auth v0.6.0 // indirect 19 | cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect 20 | cloud.google.com/go/compute/metadata v0.3.0 // indirect 21 | cloud.google.com/go/longrunning v0.5.7 // indirect 22 | github.com/felixge/httpsnoop v1.0.4 // indirect 23 | github.com/go-logr/logr v1.4.1 // indirect 24 | github.com/go-logr/stdr v1.2.2 // indirect 25 | github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect 26 | github.com/golang/protobuf v1.5.4 // indirect 27 | github.com/google/s2a-go v0.1.7 // indirect 28 | github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect 29 | go.opencensus.io v0.24.0 // indirect 30 | go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.51.0 // indirect 31 | go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect 32 | go.opentelemetry.io/otel v1.26.0 // indirect 33 | go.opentelemetry.io/otel/metric v1.26.0 // indirect 34 | go.opentelemetry.io/otel/trace v1.26.0 // indirect 35 | golang.org/x/crypto v0.31.0 // indirect 36 | golang.org/x/net v0.26.0 // indirect 37 | golang.org/x/oauth2 v0.21.0 // indirect 38 | golang.org/x/sync v0.10.0 // indirect 39 | golang.org/x/sys v0.28.0 // indirect 40 | golang.org/x/text v0.21.0 // indirect 41 | golang.org/x/time v0.5.0 // indirect 42 | google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 // indirect 43 
| ) 44 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= 2 | cloud.google.com/go v0.115.0 h1:CnFSK6Xo3lDYRoBKEcAtia6VSC837/ZkJuRduSFnr14= 3 | cloud.google.com/go v0.115.0/go.mod h1:8jIM5vVgoAEoiVxQ/O4BFTfHqulPZgs/ufEzMcFMdWU= 4 | cloud.google.com/go/ai v0.8.0 h1:rXUEz8Wp2OlrM8r1bfmpF2+VKqc1VJpafE3HgzRnD/w= 5 | cloud.google.com/go/ai v0.8.0/go.mod h1:t3Dfk4cM61sytiggo2UyGsDVW3RF1qGZaUKDrZFyqkE= 6 | cloud.google.com/go/auth v0.6.0 h1:5x+d6b5zdezZ7gmLWD1m/xNjnaQ2YDhmIz/HH3doy1g= 7 | cloud.google.com/go/auth v0.6.0/go.mod h1:b4acV+jLQDyjwm4OXHYjNvRi4jvGBzHWJRtJcy+2P4g= 8 | cloud.google.com/go/auth/oauth2adapt v0.2.2 h1:+TTV8aXpjeChS9M+aTtN/TjdQnzJvmzKFt//oWu7HX4= 9 | cloud.google.com/go/auth/oauth2adapt v0.2.2/go.mod h1:wcYjgpZI9+Yu7LyYBg4pqSiaRkfEK3GQcpb7C/uyF1Q= 10 | cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= 11 | cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= 12 | cloud.google.com/go/longrunning v0.5.7 h1:WLbHekDbjK1fVFD3ibpFFVoyizlLRl73I7YKuAKilhU= 13 | cloud.google.com/go/longrunning v0.5.7/go.mod h1:8GClkudohy1Fxm3owmBGid8W0pSgodEMwEAztp38Xng= 14 | github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= 15 | github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= 16 | github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= 17 | github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= 18 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 19 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 20 | 
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 21 | github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= 22 | github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= 23 | github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= 24 | github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= 25 | github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= 26 | github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= 27 | github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= 28 | github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ= 29 | github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= 30 | github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= 31 | github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= 32 | github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= 33 | github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= 34 | github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= 35 | github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= 36 | github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= 37 | github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= 38 | github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= 39 | github.com/golang/protobuf v1.4.0-rc.1/go.mod 
h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= 40 | github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= 41 | github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= 42 | github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= 43 | github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= 44 | github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= 45 | github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= 46 | github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= 47 | github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= 48 | github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= 49 | github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= 50 | github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= 51 | github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 52 | github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 53 | github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 54 | github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= 55 | github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 56 | github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o= 57 | github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8F0Qdw= 58 | github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 59 | github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= 60 | github.com/google/uuid v1.6.0/go.mod 
h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 61 | github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfFxPRy3Bf7vr3h0cechB90XaQs= 62 | github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0= 63 | github.com/googleapis/gax-go/v2 v2.12.5 h1:8gw9KZK8TiVKB6q3zHY3SBzLnrGp6HQjyfYBYGmXdxA= 64 | github.com/googleapis/gax-go/v2 v2.12.5/go.mod h1:BUDKcWo+RaKq5SC9vVYL0wLADa3VcfswbOMMRmB9H3E= 65 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 66 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 67 | github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= 68 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 69 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 70 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 71 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 72 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 73 | github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= 74 | github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= 75 | github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 76 | go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= 77 | go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= 78 | go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.51.0 h1:A3SayB3rNyt+1S6qpI9mHPkeHTZbD7XILEqWnYZb2l0= 79 | go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.51.0/go.mod h1:27iA5uvhuRNmalO+iEUdVn5ZMj2qy10Mm+XRIpRmyuU= 80 | 
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 h1:Xs2Ncz0gNihqu9iosIZ5SkBbWo5T8JhhLJFMQL1qmLI= 81 | go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0/go.mod h1:vy+2G/6NvVMpwGX/NyLqcC41fxepnuKHk16E6IZUcJc= 82 | go.opentelemetry.io/otel v1.26.0 h1:LQwgL5s/1W7YiiRwxf03QGnWLb2HW4pLiAhaA5cZXBs= 83 | go.opentelemetry.io/otel v1.26.0/go.mod h1:UmLkJHUAidDval2EICqBMbnAd0/m2vmpf/dAM+fvFs4= 84 | go.opentelemetry.io/otel/metric v1.26.0 h1:7S39CLuY5Jgg9CrnA9HHiEjGMF/X2VHvoXGgSllRz30= 85 | go.opentelemetry.io/otel/metric v1.26.0/go.mod h1:SY+rHOI4cEawI9a7N1A4nIg/nTQXe1ccCNWYOJUrpX4= 86 | go.opentelemetry.io/otel/trace v1.26.0 h1:1ieeAUb4y0TE26jUFrCIXKpTuVK7uJGN9/Z/2LP5sQA= 87 | go.opentelemetry.io/otel/trace v1.26.0/go.mod h1:4iDxvGDQuUkHve82hJJ8UqrwswHYsZuWCBllGV2U2y0= 88 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 89 | golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= 90 | golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= 91 | golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= 92 | golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= 93 | golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= 94 | golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= 95 | golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= 96 | golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= 97 | golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= 98 | golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= 99 
| golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 100 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 101 | golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= 102 | golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= 103 | golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= 104 | golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= 105 | golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs= 106 | golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= 107 | golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 108 | golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 109 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 110 | golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= 111 | golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 112 | golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 113 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 114 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 115 | golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 116 | golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= 117 | golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 118 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 
119 | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 120 | golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= 121 | golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= 122 | golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= 123 | golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= 124 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 125 | golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 126 | golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= 127 | golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= 128 | golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= 129 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 130 | google.golang.org/api v0.186.0 h1:n2OPp+PPXX0Axh4GuSsL5QL8xQCTb2oDwyzPnQvqUug= 131 | google.golang.org/api v0.186.0/go.mod h1:hvRbBmgoje49RV3xqVXrmP6w93n6ehGgIVPYrGtBFFc= 132 | google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= 133 | google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= 134 | google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= 135 | google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= 136 | google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= 137 | google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 h1:MuYw1wJzT+ZkybKfaOXKp5hJiZDn2iHaXRw0mRYdHSc= 138 | 
google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4/go.mod h1:px9SlOOZBg1wM1zdnr8jEL4CNGUBZ+ZKYtNPApNQc4c= 139 | google.golang.org/genproto/googleapis/rpc v0.0.0-20240617180043-68d350f18fd4 h1:Di6ANFilr+S60a4S61ZM00vLdw0IrQOSMS2/6mrnOU0= 140 | google.golang.org/genproto/googleapis/rpc v0.0.0-20240617180043-68d350f18fd4/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY= 141 | google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= 142 | google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= 143 | google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= 144 | google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= 145 | google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= 146 | google.golang.org/grpc v1.64.1 h1:LKtvyfbX3UGVPFcGqJ9ItpVWW6oN/2XqTxfAnwRRXiA= 147 | google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0= 148 | google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= 149 | google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= 150 | google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= 151 | google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= 152 | google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= 153 | google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= 154 | google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= 155 | google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= 156 | google.golang.org/protobuf v1.25.0/go.mod 
h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= 157 | google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= 158 | google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= 159 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 160 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 161 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 162 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 163 | honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= 164 | honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= 165 | -------------------------------------------------------------------------------- /license_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package main 16 | 17 | import ( 18 | "bufio" 19 | "errors" 20 | "io/fs" 21 | "os" 22 | "path/filepath" 23 | "regexp" 24 | "strings" 25 | "testing" 26 | "unicode" 27 | ) 28 | 29 | var wantFile = filepath.Join("genai", "license.txt") 30 | 31 | // TestLicense checks that the top comment of every *.sh and *.go file under 32 | // the current directory matches genai/license.txt, ignoring the copyright year. 33 | func TestLicense(t *testing.T) { 34 | lic, err := os.ReadFile(wantFile) 35 | if err != nil { 36 | t.Fatal(err) 37 | } 38 | want := string(lic) 39 | want = eraseYear(want) 40 | want = removeCommentPrefix(want, "//") 41 | // Remove final blank line(s). 42 | want = strings.TrimRightFunc(want, unicode.IsSpace) 43 | 44 | check := func(t *testing.T, file, prefix, contents string) { 45 | t.Helper() 46 | got := removeCommentPrefix(contents, prefix) 47 | got = eraseYear(got) 48 | if got != want { 49 | t.Errorf("%s: bad license: does not match contents of %s", file, wantFile) 50 | t.Logf("got %q", got) 51 | t.Logf("want %q", want) 52 | } 53 | } 54 | 55 | t.Run("scripts", func(t *testing.T) { 56 | shellScripts, err := globTree(".", "*.sh") 57 | if err != nil { 58 | t.Fatal(err) 59 | } 60 | for _, f := range shellScripts { 61 | got, err := topComment(f, "#") 62 | if err != nil { 63 | t.Fatal(err) 64 | } 65 | // Remove the shebang line so only the license comment is compared.
66 | if strings.HasPrefix(got, "#!") { 67 | if i := strings.IndexByte(got, '\n'); i > 0 { 68 | got = got[i+1:] 69 | } 70 | } 71 | check(t, f, "#", got) 72 | } 73 | }) 74 | t.Run("go source", func(t *testing.T) { 75 | goFiles, err := globTree(".", "*.go") 76 | if err != nil { 77 | t.Fatal(err) 78 | } 79 | for _, f := range goFiles { 80 | got, err := topComment(f, "//") 81 | if err != nil { 82 | t.Fatal(err) 83 | } 84 | check(t, f, "//", got) 85 | } 86 | }) 87 | } 88 | 89 | var yearRegexp = regexp.MustCompile(`[Cc]opyright \d\d\d\d`) 90 | 91 | // eraseYear replaces each "Copyright NNNN" or "copyright NNNN" (NNNN = four digits) with "Copyright YYYY" so headers from different years compare equal. 92 | func eraseYear(s string) string { 93 | return yearRegexp.ReplaceAllLiteralString(s, "Copyright YYYY") 94 | } 95 | 96 | // removeCommentPrefix removes prefix from the start of each line of s, where present. 97 | func removeCommentPrefix(s, prefix string) string { 98 | var lines []string 99 | for _, line := range strings.Split(s, "\n") { 100 | lines = append(lines, strings.TrimPrefix(line, prefix)) 101 | } 102 | return strings.Join(lines, "\n") 103 | } 104 | 105 | // topComment returns the comment at the top of the file, up to the first blank or non-comment line. 106 | // Exception: if the first line contains "generated", everything up to the first blank line is skipped and the comment that follows is returned instead. 107 | func topComment(file, commentPrefix string) (string, error) { 108 | f, err := os.Open(file) 109 | if err != nil { 110 | return "", err 111 | } 112 | defer f.Close() 113 | scan := bufio.NewScanner(f) 114 | var lines []string 115 | gen := false 116 | n := 0 117 | for scan.Scan() { 118 | line := scan.Text() 119 | n++ 120 | if n == 1 && strings.Contains(line, "generated") { 121 | gen = true 122 | continue 123 | } 124 | if gen && line == "" { 125 | gen = false 126 | continue 127 | } 128 | if !gen { 129 | if strings.HasPrefix(line, commentPrefix) { 130 | lines = append(lines, line) 131 | } else { 132 | break 133 | } 134 | } 135 | } 136 | if scan.Err() != nil { 137 | return "", scan.Err() 138 | } 139 | return strings.Join(lines, "\n"), nil 140 | } 141 | 142 | // globTree matches pattern, using filepath.Match, against the base name of dir and of every file and directory beneath it.
143 | // The filenames it returns begin with dir. 144 | // The pattern must not contain path separators. 145 | func globTree(dir, pattern string) ([]string, error) { 146 | if strings.ContainsRune(pattern, filepath.Separator) { 147 | return nil, errors.New("pattern contains path separator") 148 | } 149 | 150 | // Reject a malformed pattern up front so the error from Match can be ignored in the walk below. 151 | if _, err := filepath.Match(pattern, ""); err != nil { 152 | return nil, err 153 | } 154 | var paths []string 155 | err := filepath.WalkDir(dir, func(path string, _ fs.DirEntry, err error) error { 156 | if err != nil { 157 | return err 158 | } 159 | if ok, _ := filepath.Match(pattern, filepath.Base(path)); ok { 160 | paths = append(paths, path) 161 | } 162 | return nil 163 | }) 164 | if err != nil { 165 | return nil, err 166 | } 167 | return paths, nil 168 | } 169 | --------------------------------------------------------------------------------