├── .dockerignore ├── .gitignore ├── CODE_OF_CONDUCT.md ├── Dockerfile ├── LICENSE ├── README.md ├── api ├── generate.go ├── language_model.pb.go ├── language_model.proto └── language_model_grpc.pb.go ├── cmd └── verbaflow │ └── main.go ├── decoder ├── control.go ├── decoder.go └── selection.go ├── downloader ├── downloadmodel.go └── downloadprogress.go ├── encoder └── encoder.go ├── examples └── prompttester │ ├── config.yaml │ ├── go.mod │ ├── go.sum │ ├── main.go │ └── prompts │ ├── classification_1.tmpl │ ├── classification_2.tmpl │ ├── extractive_question_answering_1.tmpl │ ├── extractive_question_answering_2.tmpl │ ├── sentence_to_question.tmpl │ ├── simplification_1.tmpl │ ├── simplification_2.tmpl │ ├── simplification_3.tmpl │ ├── summarization_1.tmpl │ ├── summarization_2.tmpl │ ├── summarization_3.tmpl │ ├── summarization_4.tmpl │ ├── summarization_5.tmpl │ ├── summarization_6.tmpl │ └── summarization_7.tmpl ├── go.mod ├── go.sum ├── go.work ├── go.work.sum ├── prompt.go ├── rwkvlm ├── converter.go ├── embeddings.go ├── gob.go └── rwkvlm.go ├── service └── service.go ├── sliceutils ├── indexedslice.go ├── indexedslice_test.go ├── orderedheap.go ├── orderedheap_test.go ├── reverseheap.go ├── reverseheap_test.go └── slices.go ├── tokenizer ├── internal │ └── bpetokenizer │ │ ├── testdata │ │ └── dummy-roberta-model │ │ │ ├── merges.txt │ │ │ └── vocab.json │ │ ├── tokenizer.go │ │ └── tokenizer_test.go └── tokenizer.go ├── tools └── tools.go └── verbaflow.go /.dockerignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .dockerignore 3 | .env 4 | .git 5 | .gitignore 6 | .idea 7 | .vscode 8 | Dockerfile 9 | data 10 | examples 11 | go.work 12 | go.work.sum 13 | models 14 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea 3 | .env 4 | .vscode 5 | models 6 | data 7 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 
14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | matteogrella@gmail.com. 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. 
Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | https://www.contributor-covenant.org/faq. Translations are available at 128 | https://www.contributor-covenant.org/translations. 129 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM golang:1.20.1-alpine3.16@sha256:020cc6a446af866cea4bba2e5732b01620414f18c2da9a8a91c04920f2da02ce as Builder 2 | 3 | WORKDIR /go/src/verbaflow 4 | COPY . . 5 | 6 | RUN GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -o /go/bin/verbaflow ./cmd/verbaflow 7 | 8 | FROM alpine:3.17.2@sha256:e2e16842c9b54d985bf1ef9242a313f36b856181f188de21313820e177002501 9 | COPY --from=Builder /go/bin/verbaflow /bin/verbaflow 10 | ENTRYPOINT ["/bin/verbaflow"] 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2023, NLP Odyssey 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 19 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 20 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 21 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 22 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 23 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 24 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VerbaFlow 2 | 3 | Welcome to VerbaFlow, a neural architecture written in Go designed specifically for language modeling tasks. 4 | Built on the robust RWKV RNN, this model is optimized for efficient performance on standard CPUs, enabling smooth running of relatively large language models even on consumer hardware. 5 | 6 | With the ability to utilize pretrained models on the [Pile](https://arxiv.org/abs/2101.00027) dataset, VerbaFlow performs comparably to GPT-like Transformer models in predicting the next token, as well as in other tasks such as text summarization, text classification, question answering, and general conversation. 7 | 8 | # Installation 9 | 10 | Requirements: 11 | 12 | * [Go 1.20](https://golang.org/dl/) 13 | 14 | Clone this repo or get the library: 15 | 16 | ```console 17 | go get -u github.com/nlpodyssey/verbaflow 18 | ``` 19 | 20 | # Usage 21 | 22 | To start using VerbaFlow, we recommend using the pre-trained model `RWKV-4-Pile-1B5-Instruct`, available on the [Hugging Face Hub](https://huggingface.co/nlpodyssey/RWKV-4-Pile-1B5-Instruct). 23 | This model has been fine-tuned using the [Pile](https://huggingface.co/datasets/the_pile) dataset and has been specially designed to understand and execute human instructions, as fine-tuned on the [xP3](https://huggingface.co/datasets/bigscience/xP3all) dataset. 24 | The original `RWKV-4-Pile-1B5-Instruct-test2-20230209` model, from which this model is derived, can be accessed [here](https://huggingface.co/BlinkDL/rwkv-4-pile-1b5). 25 | 26 | > The library is optimized to run in x86-64 CPUs. If you want to run it on a different architecture, you can use the `GOARCH=amd64` environment variable. 27 | 28 | The following commands can be used to build and use VerbaFlow: 29 | 30 | ```console 31 | go build ./cmd/verbaflow 32 | ``` 33 | 34 | This command builds the go program and creates an executable named `verbaflow`. 35 | 36 | ```console 37 | ./verbaflow -model-dir models/nlpodyssey/RWKV-4-Pile-1B5-Instruct download 38 | ``` 39 | 40 | This command downloads the model specified (in this case, "nlpodyssey/RWKV-4-Pile-1B5-Instruct" under the "models" directory) 41 | 42 | ```console 43 | ./verbaflow -model-dir models/nlpodyssey/RWKV-4-Pile-1B5-Instruct convert 44 | ``` 45 | 46 | This command converts the downloaded model to the format used by the program. 47 | 48 | ```console 49 | ./verbaflow -log-level trace -model-dir models/nlpodyssey/RWKV-4-Pile-1B5-Instruct inference --address :50051 50 | ``` 51 | 52 | This command runs the gRPC inference endpoint on the specified model. 53 | 54 | Please make sure to have the necessary dependencies installed before running the above commands. 55 | 56 | ## Examples 57 | 58 | One of the most interesting features of the LLM is the ability to react based on the prompt. 
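Besides the bundled `prompttester`, any gRPC client can consume the token stream through the generated stubs in `api/`. The following is a minimal Go sketch (the address, prompt, and decoding parameters are illustrative and assume the inference endpoint started above is listening on `localhost:50051`):

```go
package main

import (
	"context"
	"fmt"
	"io"
	"log"

	"github.com/nlpodyssey/verbaflow/api"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

func main() {
	// Connect to the verbaflow inference endpoint (no TLS in this sketch).
	conn, err := grpc.Dial("localhost:50051", grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	client := api.NewLanguageModelClient(conn)

	// Ask for a stream of generated tokens for a prompt.
	stream, err := client.GenerateTokens(context.Background(), &api.TokenGenerationRequest{
		Prompt: "\nQ: What is the capital of France?\n\nA:",
		DecodingParameters: &api.DecodingParameters{
			MaxLen:      64,
			TopK:        0,
			TopP:        1.0,
			Temperature: 1.0,
			UseSampling: false,
		},
	})
	if err != nil {
		log.Fatal(err)
	}

	// Print tokens as they arrive until the server closes the stream.
	for {
		tok, err := stream.Recv()
		if err == io.EOF {
			break
		}
		if err != nil {
			log.Fatal(err)
		}
		fmt.Print(tok.GetToken())
	}
}
```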
59 | 60 | Run the `verbaflow` gRPC endpoint with the command in inference, then run the `prompttester` example entering the following prompts: 61 | 62 | ### Example 1 63 | 64 | Prompt: 65 | 66 | ```console 67 | echo '\nQ: Briefly: The Universe is expanding, its constituent galaxies flying apart like pieces of cosmic shrapnel in the aftermath of the Big Bang. Which section of a newspaper would this article likely appear in?\n\nA:' | go run ./examples/prompttester --dconfig ./examples/prompttester/config.yaml 68 | ``` 69 | 70 | Expected output: 71 | 72 | ```console 73 | Science and Technology 74 | ``` 75 | 76 | ### Example 2 77 | 78 | Prompt: 79 | 80 | ```console 81 | echo '\nQ:Translate the following text from French to English Je suis le père le plus heureux du monde\n\nA:' | go run ./examples/prompttester --dconfig ./examples/prompttester/config.yaml 82 | ``` 83 | 84 | Expected output: 85 | 86 | ```console 87 | I am the happiest father in the world. 88 | ``` 89 | 90 | ## Dependencies 91 | 92 | A list of the main dependencies follows: 93 | 94 | - [Spago](http://github.com/nlpodyssey/spago) - Machine Learning framework 95 | - [RWKV](http://github.com/nlpodyssey/rwkv) - RWKV RNN implementation 96 | - [GoTokenizers](http://github.com/nlpodyssey/gotokenizers) - Tokenizers library 97 | - [GoPickle](http://github.com/nlpodyssey/gopickle) - Pickle library for Go 98 | 99 | # Roadmap 100 | 101 | - [x] Download pretrained models from the Hugging Face models hub 102 | - [ ] Effective "prompts" catalog 103 | - [x] Better sampling 104 | - [ ] Beam search 105 | - [ ] Better Tokenizer 106 | - [ ] Unit tests 107 | - [ ] Code refactoring 108 | - [ ] Documentation 109 | - [x] gRPC ~~/HTTP~~ API 110 | 111 | # Credits 112 | 113 | Thanks [PENG Bo](https://github.com/BlinkDL) for creating the RWKV RNN and all related resources, including pre-trained models! 114 | 115 | # Trivia about the project's name 116 | 117 | "VerbaFlow" combines "verba", which is the Latin word for *words*, and "flow", which alludes to the characteristics of recurrent neural networks by evoking the idea of a fluent and continuous flow of words, which is made possible by the network's ability to maintain an internal state and "remember" previous words and context when generating new words. -------------------------------------------------------------------------------- /api/generate.go: -------------------------------------------------------------------------------- 1 | //go:generate protoc -I . --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative language_model.proto 2 | 3 | package api 4 | -------------------------------------------------------------------------------- /api/language_model.pb.go: -------------------------------------------------------------------------------- 1 | // Code generated by protoc-gen-go. DO NOT EDIT. 2 | // versions: 3 | // protoc-gen-go v1.28.1 4 | // protoc v3.21.5 5 | // source: language_model.proto 6 | 7 | package api 8 | 9 | import ( 10 | protoreflect "google.golang.org/protobuf/reflect/protoreflect" 11 | protoimpl "google.golang.org/protobuf/runtime/protoimpl" 12 | reflect "reflect" 13 | sync "sync" 14 | ) 15 | 16 | const ( 17 | // Verify that this generated code is sufficiently up-to-date. 18 | _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) 19 | // Verify that runtime/protoimpl is sufficiently up-to-date. 
20 | _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) 21 | ) 22 | 23 | // TokenGenerationRequest contains the prompt and decoding parameters for generating tokens 24 | type TokenGenerationRequest struct { 25 | state protoimpl.MessageState 26 | sizeCache protoimpl.SizeCache 27 | unknownFields protoimpl.UnknownFields 28 | 29 | // Prompt is the input string to use as a starting point for token generation 30 | Prompt string `protobuf:"bytes,1,opt,name=prompt,proto3" json:"prompt,omitempty"` 31 | // DecodingParameters are the parameters to use for token generation 32 | DecodingParameters *DecodingParameters `protobuf:"bytes,2,opt,name=decoding_parameters,json=decodingParameters,proto3" json:"decoding_parameters,omitempty"` 33 | } 34 | 35 | func (x *TokenGenerationRequest) Reset() { 36 | *x = TokenGenerationRequest{} 37 | if protoimpl.UnsafeEnabled { 38 | mi := &file_language_model_proto_msgTypes[0] 39 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 40 | ms.StoreMessageInfo(mi) 41 | } 42 | } 43 | 44 | func (x *TokenGenerationRequest) String() string { 45 | return protoimpl.X.MessageStringOf(x) 46 | } 47 | 48 | func (*TokenGenerationRequest) ProtoMessage() {} 49 | 50 | func (x *TokenGenerationRequest) ProtoReflect() protoreflect.Message { 51 | mi := &file_language_model_proto_msgTypes[0] 52 | if protoimpl.UnsafeEnabled && x != nil { 53 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 54 | if ms.LoadMessageInfo() == nil { 55 | ms.StoreMessageInfo(mi) 56 | } 57 | return ms 58 | } 59 | return mi.MessageOf(x) 60 | } 61 | 62 | // Deprecated: Use TokenGenerationRequest.ProtoReflect.Descriptor instead. 63 | func (*TokenGenerationRequest) Descriptor() ([]byte, []int) { 64 | return file_language_model_proto_rawDescGZIP(), []int{0} 65 | } 66 | 67 | func (x *TokenGenerationRequest) GetPrompt() string { 68 | if x != nil { 69 | return x.Prompt 70 | } 71 | return "" 72 | } 73 | 74 | func (x *TokenGenerationRequest) GetDecodingParameters() *DecodingParameters { 75 | if x != nil { 76 | return x.DecodingParameters 77 | } 78 | return nil 79 | } 80 | 81 | // DecodingParameters contains the parameters to use for token generation 82 | type DecodingParameters struct { 83 | state protoimpl.MessageState 84 | sizeCache protoimpl.SizeCache 85 | unknownFields protoimpl.UnknownFields 86 | 87 | // MaxLen is the maximum number of tokens to generate. 88 | MaxLen int32 `protobuf:"varint,1,opt,name=max_len,json=maxLen,proto3" json:"max_len,omitempty"` 89 | // MinLen is the minimum number of tokens to generate. 90 | MinLen int32 `protobuf:"varint,2,opt,name=min_len,json=minLen,proto3" json:"min_len,omitempty"` 91 | // Temperature controls the randomness of the generated tokens. A higher temperature will result in more diverse generated tokens. 92 | Temperature float32 `protobuf:"fixed32,3,opt,name=temperature,proto3" json:"temperature,omitempty"` 93 | // TopK is the maximum number of tokens to consider when sampling the next token. 94 | TopK int32 `protobuf:"varint,4,opt,name=top_k,json=topK,proto3" json:"top_k,omitempty"` 95 | // TopP is the cumulative probability of the tokens to consider when sampling the next token. 96 | TopP float32 `protobuf:"fixed32,5,opt,name=top_p,json=topP,proto3" json:"top_p,omitempty"` 97 | // UseSampling uses sampling to generate the next token. 98 | UseSampling bool `protobuf:"varint,6,opt,name=use_sampling,json=useSampling,proto3" json:"use_sampling,omitempty"` 99 | // EndTokenID is the end-of-sequence token (default: 0). 
100 | EndTokenId int32 `protobuf:"varint,7,opt,name=end_token_id,json=endTokenId,proto3" json:"end_token_id,omitempty"` 101 | // SkipEndTokenID when true, the end token is not added to the generated sequence. 102 | SkipEndTokenId bool `protobuf:"varint,8,opt,name=skip_end_token_id,json=skipEndTokenId,proto3" json:"skip_end_token_id,omitempty"` 103 | // StopSequences are the sequences of token ids that will cause the generation to stop. 104 | StopSequences []*Sequence `protobuf:"bytes,9,rep,name=stop_sequences,json=stopSequences,proto3" json:"stop_sequences,omitempty"` 105 | } 106 | 107 | func (x *DecodingParameters) Reset() { 108 | *x = DecodingParameters{} 109 | if protoimpl.UnsafeEnabled { 110 | mi := &file_language_model_proto_msgTypes[1] 111 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 112 | ms.StoreMessageInfo(mi) 113 | } 114 | } 115 | 116 | func (x *DecodingParameters) String() string { 117 | return protoimpl.X.MessageStringOf(x) 118 | } 119 | 120 | func (*DecodingParameters) ProtoMessage() {} 121 | 122 | func (x *DecodingParameters) ProtoReflect() protoreflect.Message { 123 | mi := &file_language_model_proto_msgTypes[1] 124 | if protoimpl.UnsafeEnabled && x != nil { 125 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 126 | if ms.LoadMessageInfo() == nil { 127 | ms.StoreMessageInfo(mi) 128 | } 129 | return ms 130 | } 131 | return mi.MessageOf(x) 132 | } 133 | 134 | // Deprecated: Use DecodingParameters.ProtoReflect.Descriptor instead. 135 | func (*DecodingParameters) Descriptor() ([]byte, []int) { 136 | return file_language_model_proto_rawDescGZIP(), []int{1} 137 | } 138 | 139 | func (x *DecodingParameters) GetMaxLen() int32 { 140 | if x != nil { 141 | return x.MaxLen 142 | } 143 | return 0 144 | } 145 | 146 | func (x *DecodingParameters) GetMinLen() int32 { 147 | if x != nil { 148 | return x.MinLen 149 | } 150 | return 0 151 | } 152 | 153 | func (x *DecodingParameters) GetTemperature() float32 { 154 | if x != nil { 155 | return x.Temperature 156 | } 157 | return 0 158 | } 159 | 160 | func (x *DecodingParameters) GetTopK() int32 { 161 | if x != nil { 162 | return x.TopK 163 | } 164 | return 0 165 | } 166 | 167 | func (x *DecodingParameters) GetTopP() float32 { 168 | if x != nil { 169 | return x.TopP 170 | } 171 | return 0 172 | } 173 | 174 | func (x *DecodingParameters) GetUseSampling() bool { 175 | if x != nil { 176 | return x.UseSampling 177 | } 178 | return false 179 | } 180 | 181 | func (x *DecodingParameters) GetEndTokenId() int32 { 182 | if x != nil { 183 | return x.EndTokenId 184 | } 185 | return 0 186 | } 187 | 188 | func (x *DecodingParameters) GetSkipEndTokenId() bool { 189 | if x != nil { 190 | return x.SkipEndTokenId 191 | } 192 | return false 193 | } 194 | 195 | func (x *DecodingParameters) GetStopSequences() []*Sequence { 196 | if x != nil { 197 | return x.StopSequences 198 | } 199 | return nil 200 | } 201 | 202 | type Sequence struct { 203 | state protoimpl.MessageState 204 | sizeCache protoimpl.SizeCache 205 | unknownFields protoimpl.UnknownFields 206 | 207 | Sequence []int32 `protobuf:"varint,1,rep,packed,name=sequence,proto3" json:"sequence,omitempty"` 208 | } 209 | 210 | func (x *Sequence) Reset() { 211 | *x = Sequence{} 212 | if protoimpl.UnsafeEnabled { 213 | mi := &file_language_model_proto_msgTypes[2] 214 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 215 | ms.StoreMessageInfo(mi) 216 | } 217 | } 218 | 219 | func (x *Sequence) String() string { 220 | return protoimpl.X.MessageStringOf(x) 221 | } 222 | 223 | func (*Sequence) 
ProtoMessage() {} 224 | 225 | func (x *Sequence) ProtoReflect() protoreflect.Message { 226 | mi := &file_language_model_proto_msgTypes[2] 227 | if protoimpl.UnsafeEnabled && x != nil { 228 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 229 | if ms.LoadMessageInfo() == nil { 230 | ms.StoreMessageInfo(mi) 231 | } 232 | return ms 233 | } 234 | return mi.MessageOf(x) 235 | } 236 | 237 | // Deprecated: Use Sequence.ProtoReflect.Descriptor instead. 238 | func (*Sequence) Descriptor() ([]byte, []int) { 239 | return file_language_model_proto_rawDescGZIP(), []int{2} 240 | } 241 | 242 | func (x *Sequence) GetSequence() []int32 { 243 | if x != nil { 244 | return x.Sequence 245 | } 246 | return nil 247 | } 248 | 249 | // GeneratedToken contains a generated token, its score, and its encoded representation 250 | type GeneratedToken struct { 251 | state protoimpl.MessageState 252 | sizeCache protoimpl.SizeCache 253 | unknownFields protoimpl.UnknownFields 254 | 255 | // Token is the generated token 256 | Token string `protobuf:"bytes,1,opt,name=token,proto3" json:"token,omitempty"` 257 | // Score is the sum of the negative log probabilities up to the current step. 258 | Score float32 `protobuf:"fixed32,2,opt,name=score,proto3" json:"score,omitempty"` 259 | } 260 | 261 | func (x *GeneratedToken) Reset() { 262 | *x = GeneratedToken{} 263 | if protoimpl.UnsafeEnabled { 264 | mi := &file_language_model_proto_msgTypes[3] 265 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 266 | ms.StoreMessageInfo(mi) 267 | } 268 | } 269 | 270 | func (x *GeneratedToken) String() string { 271 | return protoimpl.X.MessageStringOf(x) 272 | } 273 | 274 | func (*GeneratedToken) ProtoMessage() {} 275 | 276 | func (x *GeneratedToken) ProtoReflect() protoreflect.Message { 277 | mi := &file_language_model_proto_msgTypes[3] 278 | if protoimpl.UnsafeEnabled && x != nil { 279 | ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 280 | if ms.LoadMessageInfo() == nil { 281 | ms.StoreMessageInfo(mi) 282 | } 283 | return ms 284 | } 285 | return mi.MessageOf(x) 286 | } 287 | 288 | // Deprecated: Use GeneratedToken.ProtoReflect.Descriptor instead. 
289 | func (*GeneratedToken) Descriptor() ([]byte, []int) { 290 | return file_language_model_proto_rawDescGZIP(), []int{3} 291 | } 292 | 293 | func (x *GeneratedToken) GetToken() string { 294 | if x != nil { 295 | return x.Token 296 | } 297 | return "" 298 | } 299 | 300 | func (x *GeneratedToken) GetScore() float32 { 301 | if x != nil { 302 | return x.Score 303 | } 304 | return 0 305 | } 306 | 307 | var File_language_model_proto protoreflect.FileDescriptor 308 | 309 | var file_language_model_proto_rawDesc = []byte{ 310 | 0x0a, 0x14, 0x6c, 0x61, 0x6e, 0x67, 0x75, 0x61, 0x67, 0x65, 0x5f, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 311 | 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x03, 0x61, 0x70, 0x69, 0x22, 0x7a, 0x0a, 0x16, 0x54, 312 | 0x6f, 0x6b, 0x65, 0x6e, 0x47, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 313 | 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x18, 314 | 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x12, 0x48, 0x0a, 315 | 0x13, 0x64, 0x65, 0x63, 0x6f, 0x64, 0x69, 0x6e, 0x67, 0x5f, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x65, 316 | 0x74, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x61, 0x70, 0x69, 317 | 0x2e, 0x44, 0x65, 0x63, 0x6f, 0x64, 0x69, 0x6e, 0x67, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 318 | 0x65, 0x72, 0x73, 0x52, 0x12, 0x64, 0x65, 0x63, 0x6f, 0x64, 0x69, 0x6e, 0x67, 0x50, 0x61, 0x72, 319 | 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x22, 0xb8, 0x02, 0x0a, 0x12, 0x44, 0x65, 0x63, 0x6f, 320 | 0x64, 0x69, 0x6e, 0x67, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x12, 0x17, 321 | 0x0a, 0x07, 0x6d, 0x61, 0x78, 0x5f, 0x6c, 0x65, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 322 | 0x06, 0x6d, 0x61, 0x78, 0x4c, 0x65, 0x6e, 0x12, 0x17, 0x0a, 0x07, 0x6d, 0x69, 0x6e, 0x5f, 0x6c, 323 | 0x65, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x06, 0x6d, 0x69, 0x6e, 0x4c, 0x65, 0x6e, 324 | 0x12, 0x20, 0x0a, 0x0b, 0x74, 0x65, 0x6d, 0x70, 0x65, 0x72, 0x61, 0x74, 0x75, 0x72, 0x65, 0x18, 325 | 0x03, 0x20, 0x01, 0x28, 0x02, 0x52, 0x0b, 0x74, 0x65, 0x6d, 0x70, 0x65, 0x72, 0x61, 0x74, 0x75, 326 | 0x72, 0x65, 0x12, 0x13, 0x0a, 0x05, 0x74, 0x6f, 0x70, 0x5f, 0x6b, 0x18, 0x04, 0x20, 0x01, 0x28, 327 | 0x05, 0x52, 0x04, 0x74, 0x6f, 0x70, 0x4b, 0x12, 0x13, 0x0a, 0x05, 0x74, 0x6f, 0x70, 0x5f, 0x70, 328 | 0x18, 0x05, 0x20, 0x01, 0x28, 0x02, 0x52, 0x04, 0x74, 0x6f, 0x70, 0x50, 0x12, 0x21, 0x0a, 0x0c, 329 | 0x75, 0x73, 0x65, 0x5f, 0x73, 0x61, 0x6d, 0x70, 0x6c, 0x69, 0x6e, 0x67, 0x18, 0x06, 0x20, 0x01, 330 | 0x28, 0x08, 0x52, 0x0b, 0x75, 0x73, 0x65, 0x53, 0x61, 0x6d, 0x70, 0x6c, 0x69, 0x6e, 0x67, 0x12, 331 | 0x20, 0x0a, 0x0c, 0x65, 0x6e, 0x64, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x5f, 0x69, 0x64, 0x18, 332 | 0x07, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0a, 0x65, 0x6e, 0x64, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x49, 333 | 0x64, 0x12, 0x29, 0x0a, 0x11, 0x73, 0x6b, 0x69, 0x70, 0x5f, 0x65, 0x6e, 0x64, 0x5f, 0x74, 0x6f, 334 | 0x6b, 0x65, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0e, 0x73, 0x6b, 335 | 0x69, 0x70, 0x45, 0x6e, 0x64, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x49, 0x64, 0x12, 0x34, 0x0a, 0x0e, 336 | 0x73, 0x74, 0x6f, 0x70, 0x5f, 0x73, 0x65, 0x71, 0x75, 0x65, 0x6e, 0x63, 0x65, 0x73, 0x18, 0x09, 337 | 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x53, 0x65, 0x71, 0x75, 0x65, 338 | 0x6e, 0x63, 0x65, 0x52, 0x0d, 0x73, 0x74, 0x6f, 0x70, 0x53, 0x65, 0x71, 0x75, 0x65, 0x6e, 0x63, 339 | 0x65, 0x73, 0x22, 0x26, 0x0a, 0x08, 0x53, 0x65, 0x71, 0x75, 0x65, 
0x6e, 0x63, 0x65, 0x12, 0x1a, 340 | 0x0a, 0x08, 0x73, 0x65, 0x71, 0x75, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x01, 0x20, 0x03, 0x28, 0x05, 341 | 0x52, 0x08, 0x73, 0x65, 0x71, 0x75, 0x65, 0x6e, 0x63, 0x65, 0x22, 0x3c, 0x0a, 0x0e, 0x47, 0x65, 342 | 0x6e, 0x65, 0x72, 0x61, 0x74, 0x65, 0x64, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x14, 0x0a, 0x05, 343 | 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x74, 0x6f, 0x6b, 344 | 0x65, 0x6e, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x63, 0x6f, 0x72, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 345 | 0x02, 0x52, 0x05, 0x73, 0x63, 0x6f, 0x72, 0x65, 0x32, 0x55, 0x0a, 0x0d, 0x4c, 0x61, 0x6e, 0x67, 346 | 0x75, 0x61, 0x67, 0x65, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x12, 0x44, 0x0a, 0x0e, 0x47, 0x65, 0x6e, 347 | 0x65, 0x72, 0x61, 0x74, 0x65, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x12, 0x1b, 0x2e, 0x61, 0x70, 348 | 0x69, 0x2e, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x47, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 349 | 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x13, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x47, 350 | 0x65, 0x6e, 0x65, 0x72, 0x61, 0x74, 0x65, 0x64, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x30, 0x01, 0x42, 351 | 0x25, 0x5a, 0x23, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6e, 0x6c, 352 | 0x70, 0x6f, 0x64, 0x79, 0x73, 0x73, 0x65, 0x79, 0x2f, 0x76, 0x65, 0x72, 0x62, 0x61, 0x66, 0x6c, 353 | 0x6f, 0x77, 0x2f, 0x61, 0x70, 0x69, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, 354 | } 355 | 356 | var ( 357 | file_language_model_proto_rawDescOnce sync.Once 358 | file_language_model_proto_rawDescData = file_language_model_proto_rawDesc 359 | ) 360 | 361 | func file_language_model_proto_rawDescGZIP() []byte { 362 | file_language_model_proto_rawDescOnce.Do(func() { 363 | file_language_model_proto_rawDescData = protoimpl.X.CompressGZIP(file_language_model_proto_rawDescData) 364 | }) 365 | return file_language_model_proto_rawDescData 366 | } 367 | 368 | var file_language_model_proto_msgTypes = make([]protoimpl.MessageInfo, 4) 369 | var file_language_model_proto_goTypes = []interface{}{ 370 | (*TokenGenerationRequest)(nil), // 0: api.TokenGenerationRequest 371 | (*DecodingParameters)(nil), // 1: api.DecodingParameters 372 | (*Sequence)(nil), // 2: api.Sequence 373 | (*GeneratedToken)(nil), // 3: api.GeneratedToken 374 | } 375 | var file_language_model_proto_depIdxs = []int32{ 376 | 1, // 0: api.TokenGenerationRequest.decoding_parameters:type_name -> api.DecodingParameters 377 | 2, // 1: api.DecodingParameters.stop_sequences:type_name -> api.Sequence 378 | 0, // 2: api.LanguageModel.GenerateTokens:input_type -> api.TokenGenerationRequest 379 | 3, // 3: api.LanguageModel.GenerateTokens:output_type -> api.GeneratedToken 380 | 3, // [3:4] is the sub-list for method output_type 381 | 2, // [2:3] is the sub-list for method input_type 382 | 2, // [2:2] is the sub-list for extension type_name 383 | 2, // [2:2] is the sub-list for extension extendee 384 | 0, // [0:2] is the sub-list for field type_name 385 | } 386 | 387 | func init() { file_language_model_proto_init() } 388 | func file_language_model_proto_init() { 389 | if File_language_model_proto != nil { 390 | return 391 | } 392 | if !protoimpl.UnsafeEnabled { 393 | file_language_model_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { 394 | switch v := v.(*TokenGenerationRequest); i { 395 | case 0: 396 | return &v.state 397 | case 1: 398 | return &v.sizeCache 399 | case 2: 400 | return &v.unknownFields 401 | default: 402 | return nil 403 | } 404 | } 405 | 
file_language_model_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { 406 | switch v := v.(*DecodingParameters); i { 407 | case 0: 408 | return &v.state 409 | case 1: 410 | return &v.sizeCache 411 | case 2: 412 | return &v.unknownFields 413 | default: 414 | return nil 415 | } 416 | } 417 | file_language_model_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { 418 | switch v := v.(*Sequence); i { 419 | case 0: 420 | return &v.state 421 | case 1: 422 | return &v.sizeCache 423 | case 2: 424 | return &v.unknownFields 425 | default: 426 | return nil 427 | } 428 | } 429 | file_language_model_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { 430 | switch v := v.(*GeneratedToken); i { 431 | case 0: 432 | return &v.state 433 | case 1: 434 | return &v.sizeCache 435 | case 2: 436 | return &v.unknownFields 437 | default: 438 | return nil 439 | } 440 | } 441 | } 442 | type x struct{} 443 | out := protoimpl.TypeBuilder{ 444 | File: protoimpl.DescBuilder{ 445 | GoPackagePath: reflect.TypeOf(x{}).PkgPath(), 446 | RawDescriptor: file_language_model_proto_rawDesc, 447 | NumEnums: 0, 448 | NumMessages: 4, 449 | NumExtensions: 0, 450 | NumServices: 1, 451 | }, 452 | GoTypes: file_language_model_proto_goTypes, 453 | DependencyIndexes: file_language_model_proto_depIdxs, 454 | MessageInfos: file_language_model_proto_msgTypes, 455 | }.Build() 456 | File_language_model_proto = out.File 457 | file_language_model_proto_rawDesc = nil 458 | file_language_model_proto_goTypes = nil 459 | file_language_model_proto_depIdxs = nil 460 | } 461 | -------------------------------------------------------------------------------- /api/language_model.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package api; 4 | 5 | option go_package = "github.com/nlpodyssey/verbaflow/api"; 6 | 7 | // LanguageModel is a gRPC service for generating tokens from a language model 8 | service LanguageModel { 9 | // GenerateTokens generates tokens for the given prompt using the specified decoding parameters. 10 | // The response is a stream of GeneratedToken messages, each containing a generated token and its score and encoded representation. 11 | rpc GenerateTokens (TokenGenerationRequest) returns (stream GeneratedToken); 12 | } 13 | 14 | // TokenGenerationRequest contains the prompt and decoding parameters for generating tokens 15 | message TokenGenerationRequest { 16 | // Prompt is the input string to use as a starting point for token generation 17 | string prompt = 1; 18 | // DecodingParameters are the parameters to use for token generation 19 | DecodingParameters decoding_parameters = 2; 20 | } 21 | 22 | // DecodingParameters contains the parameters to use for token generation 23 | message DecodingParameters { 24 | // MaxLen is the maximum number of tokens to generate. 25 | int32 max_len = 1; 26 | // MinLen is the minimum number of tokens to generate. 27 | int32 min_len = 2; 28 | // Temperature controls the randomness of the generated tokens. A higher temperature will result in more diverse generated tokens. 29 | float temperature = 3; 30 | // TopK is the maximum number of tokens to consider when sampling the next token. 31 | int32 top_k = 4; 32 | // TopP is the cumulative probability of the tokens to consider when sampling the next token. 33 | float top_p = 5; 34 | // UseSampling uses sampling to generate the next token. 35 | bool use_sampling = 6; 36 | // EndTokenID is the end-of-sequence token (default: 0). 
37 | int32 end_token_id = 7; 38 | // SkipEndTokenID when true, the end token is not added to the generated sequence. 39 | bool skip_end_token_id = 8; 40 | // StopSequences are the sequences of token ids that will cause the generation to stop. 41 | repeated Sequence stop_sequences = 9; 42 | } 43 | 44 | // Sequence is a sequence of token ids 45 | message Sequence { 46 | // Sequence is the sequence of token ids 47 | repeated int32 sequence=1; 48 | } 49 | 50 | // GeneratedToken contains a generated token, its score, and its encoded representation 51 | message GeneratedToken { 52 | // Token is the generated token 53 | string token = 1; 54 | // Score is the sum of the negative log probabilities up to the current step. 55 | float score = 2; 56 | } -------------------------------------------------------------------------------- /api/language_model_grpc.pb.go: -------------------------------------------------------------------------------- 1 | // Code generated by protoc-gen-go-grpc. DO NOT EDIT. 2 | // versions: 3 | // - protoc-gen-go-grpc v1.2.0 4 | // - protoc v3.21.5 5 | // source: language_model.proto 6 | 7 | package api 8 | 9 | import ( 10 | context "context" 11 | grpc "google.golang.org/grpc" 12 | codes "google.golang.org/grpc/codes" 13 | status "google.golang.org/grpc/status" 14 | ) 15 | 16 | // This is a compile-time assertion to ensure that this generated file 17 | // is compatible with the grpc package it is being compiled against. 18 | // Requires gRPC-Go v1.32.0 or later. 19 | const _ = grpc.SupportPackageIsVersion7 20 | 21 | // LanguageModelClient is the client API for LanguageModel service. 22 | // 23 | // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. 24 | type LanguageModelClient interface { 25 | // GenerateTokens generates tokens for the given prompt using the specified decoding parameters. 26 | // The response is a stream of GeneratedToken messages, each containing a generated token and its score and encoded representation. 27 | GenerateTokens(ctx context.Context, in *TokenGenerationRequest, opts ...grpc.CallOption) (LanguageModel_GenerateTokensClient, error) 28 | } 29 | 30 | type languageModelClient struct { 31 | cc grpc.ClientConnInterface 32 | } 33 | 34 | func NewLanguageModelClient(cc grpc.ClientConnInterface) LanguageModelClient { 35 | return &languageModelClient{cc} 36 | } 37 | 38 | func (c *languageModelClient) GenerateTokens(ctx context.Context, in *TokenGenerationRequest, opts ...grpc.CallOption) (LanguageModel_GenerateTokensClient, error) { 39 | stream, err := c.cc.NewStream(ctx, &LanguageModel_ServiceDesc.Streams[0], "/api.LanguageModel/GenerateTokens", opts...) 40 | if err != nil { 41 | return nil, err 42 | } 43 | x := &languageModelGenerateTokensClient{stream} 44 | if err := x.ClientStream.SendMsg(in); err != nil { 45 | return nil, err 46 | } 47 | if err := x.ClientStream.CloseSend(); err != nil { 48 | return nil, err 49 | } 50 | return x, nil 51 | } 52 | 53 | type LanguageModel_GenerateTokensClient interface { 54 | Recv() (*GeneratedToken, error) 55 | grpc.ClientStream 56 | } 57 | 58 | type languageModelGenerateTokensClient struct { 59 | grpc.ClientStream 60 | } 61 | 62 | func (x *languageModelGenerateTokensClient) Recv() (*GeneratedToken, error) { 63 | m := new(GeneratedToken) 64 | if err := x.ClientStream.RecvMsg(m); err != nil { 65 | return nil, err 66 | } 67 | return m, nil 68 | } 69 | 70 | // LanguageModelServer is the server API for LanguageModel service. 
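  // Generation stops once this token is produced; the decoder will not select it before min_len tokens have been generated.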
71 | // All implementations must embed UnimplementedLanguageModelServer 72 | // for forward compatibility 73 | type LanguageModelServer interface { 74 | // GenerateTokens generates tokens for the given prompt using the specified decoding parameters. 75 | // The response is a stream of GeneratedToken messages, each containing a generated token and its score and encoded representation. 76 | GenerateTokens(*TokenGenerationRequest, LanguageModel_GenerateTokensServer) error 77 | mustEmbedUnimplementedLanguageModelServer() 78 | } 79 | 80 | // UnimplementedLanguageModelServer must be embedded to have forward compatible implementations. 81 | type UnimplementedLanguageModelServer struct { 82 | } 83 | 84 | func (UnimplementedLanguageModelServer) GenerateTokens(*TokenGenerationRequest, LanguageModel_GenerateTokensServer) error { 85 | return status.Errorf(codes.Unimplemented, "method GenerateTokens not implemented") 86 | } 87 | func (UnimplementedLanguageModelServer) mustEmbedUnimplementedLanguageModelServer() {} 88 | 89 | // UnsafeLanguageModelServer may be embedded to opt out of forward compatibility for this service. 90 | // Use of this interface is not recommended, as added methods to LanguageModelServer will 91 | // result in compilation errors. 92 | type UnsafeLanguageModelServer interface { 93 | mustEmbedUnimplementedLanguageModelServer() 94 | } 95 | 96 | func RegisterLanguageModelServer(s grpc.ServiceRegistrar, srv LanguageModelServer) { 97 | s.RegisterService(&LanguageModel_ServiceDesc, srv) 98 | } 99 | 100 | func _LanguageModel_GenerateTokens_Handler(srv interface{}, stream grpc.ServerStream) error { 101 | m := new(TokenGenerationRequest) 102 | if err := stream.RecvMsg(m); err != nil { 103 | return err 104 | } 105 | return srv.(LanguageModelServer).GenerateTokens(m, &languageModelGenerateTokensServer{stream}) 106 | } 107 | 108 | type LanguageModel_GenerateTokensServer interface { 109 | Send(*GeneratedToken) error 110 | grpc.ServerStream 111 | } 112 | 113 | type languageModelGenerateTokensServer struct { 114 | grpc.ServerStream 115 | } 116 | 117 | func (x *languageModelGenerateTokensServer) Send(m *GeneratedToken) error { 118 | return x.ServerStream.SendMsg(m) 119 | } 120 | 121 | // LanguageModel_ServiceDesc is the grpc.ServiceDesc for LanguageModel service. 122 | // It's only intended for direct use with grpc.RegisterService, 123 | // and not to be introspected or modified (even as a copy) 124 | var LanguageModel_ServiceDesc = grpc.ServiceDesc{ 125 | ServiceName: "api.LanguageModel", 126 | HandlerType: (*LanguageModelServer)(nil), 127 | Methods: []grpc.MethodDesc{}, 128 | Streams: []grpc.StreamDesc{ 129 | { 130 | StreamName: "GenerateTokens", 131 | Handler: _LanguageModel_GenerateTokens_Handler, 132 | ServerStreams: true, 133 | }, 134 | }, 135 | Metadata: "language_model.proto", 136 | } 137 | -------------------------------------------------------------------------------- /cmd/verbaflow/main.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 NLP Odyssey Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package main 6 | 7 | import ( 8 | "context" 9 | "fmt" 10 | "os" 11 | "os/signal" 12 | "path/filepath" 13 | "strings" 14 | 15 | "github.com/nlpodyssey/spago/ag" 16 | "github.com/nlpodyssey/verbaflow" 17 | "github.com/nlpodyssey/verbaflow/downloader" 18 | "github.com/nlpodyssey/verbaflow/rwkvlm" 19 | "github.com/nlpodyssey/verbaflow/service" 20 | "github.com/rs/zerolog" 21 | "github.com/rs/zerolog/log" 22 | "github.com/urfave/cli/v2" 23 | ) 24 | 25 | func main() { 26 | log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Level(zerolog.InfoLevel) 27 | 28 | app := &cli.App{ 29 | Name: "verbaflow", 30 | Usage: "Perform various operations with a language model", 31 | Flags: []cli.Flag{ 32 | &cli.StringFlag{ 33 | Name: "log-level", 34 | Usage: "set log level (trace, debug, info, warn, error, fatal, panic)", 35 | Action: func(c *cli.Context, s string) error { 36 | return setDebugLevel(s) 37 | }, 38 | Value: "info", 39 | EnvVars: []string{"VERBAFLOW_LOGLEVEL"}, 40 | }, 41 | &cli.StringFlag{ 42 | Name: "model-dir", 43 | Usage: "directory of the model to operate on", 44 | Required: true, 45 | }, 46 | }, 47 | Commands: []*cli.Command{ 48 | { 49 | Name: "download", 50 | Usage: "Download model to directory", 51 | Action: func(c *cli.Context) error { 52 | if err := download(c.String("model-dir")); err != nil { 53 | log.Err(err).Send() 54 | } 55 | return nil 56 | }, 57 | }, 58 | { 59 | Name: "convert", 60 | Usage: "Convert model in directory", 61 | Action: func(c *cli.Context) error { 62 | if err := convert(c.String("model-dir")); err != nil { 63 | log.Fatal().Err(err).Send() 64 | } 65 | return nil 66 | }, 67 | }, 68 | { 69 | Name: "inference", 70 | Usage: "Serve a gRPC inference endpoint", 71 | Action: func(c *cli.Context) error { 72 | modelDir := c.String("model-dir") 73 | address := c.String("address") 74 | 75 | ctx, stop := signal.NotifyContext(c.Context, os.Interrupt, os.Kill) 76 | defer stop() 77 | 78 | if err := inference(ctx, modelDir, address); err != nil { 79 | fmt.Print(err) 80 | log.Err(err).Send() 81 | } 82 | return nil 83 | }, 84 | Flags: []cli.Flag{ 85 | &cli.StringFlag{ 86 | Name: "address", 87 | Usage: "The address to listen on for gRPC connections", 88 | Value: ":50051", 89 | Required: false, 90 | }, 91 | }, 92 | }, 93 | }, 94 | } 95 | 96 | if err := app.Run(os.Args); err != nil { 97 | log.Fatal().Err(err).Send() 98 | } 99 | } 100 | 101 | func setDebugLevel(debugLevel string) error { 102 | level, err := zerolog.ParseLevel(debugLevel) 103 | if err != nil { 104 | return err 105 | } 106 | log.Logger = log.Level(level) 107 | return nil 108 | } 109 | 110 | func download(modelDir string) error { 111 | log.Debug().Msgf("Downloading model in dir: %s", modelDir) 112 | dir, name, err := splitPathAndModelName(modelDir) 113 | if err != nil { 114 | log.Fatal().Err(err).Send() 115 | } 116 | err = downloader.Download(dir, name, false, "") 117 | if err != nil { 118 | log.Fatal().Err(err).Send() 119 | } 120 | log.Debug().Msg("Done.") 121 | return nil 122 | } 123 | 124 | func convert(modelDir string) error { 125 | log.Debug().Msgf("Converting model in dir: %s", modelDir) 126 | err := rwkvlm.ConvertPickledModelToRWKVLM[float32](rwkvlm.ConverterConfig{ 127 | ModelDir: modelDir, 128 | OverwriteIfExist: false, 129 | }) 130 | if err != nil { 131 | log.Fatal().Err(err).Send() 132 | } 133 | log.Debug().Msg("Done.") 134 | return nil 135 | } 136 | 137 | func inference(ctx context.Context, modelDir string, address string) error { 138 | log.Debug().Msgf("Starting inference server for model in dir: 
%s", modelDir) 139 | log.Debug().Msgf("Loading model...") 140 | vf, err := verbaflow.Load(modelDir) 141 | if err != nil { 142 | return err 143 | } 144 | defer vf.Close() 145 | 146 | log.Debug().Msgf("Server listening on %s", address) 147 | server := service.NewServer(vf) 148 | return server.Start(ctx, address) 149 | } 150 | 151 | // splitPathAndModelName separate the models directory from the model name, which format is "organization/model" 152 | func splitPathAndModelName(path string) (string, string, error) { 153 | dirs := strings.Split(strings.TrimSuffix(path, "/"), "/") 154 | if len(dirs) < 3 { 155 | return "", "", fmt.Errorf("path must have at least three levels of directories") 156 | } 157 | lastDir := dirs[len(dirs)-1] 158 | secondLastDir := dirs[len(dirs)-2] 159 | 160 | pathExceptLastTwo := strings.Join(dirs[:len(dirs)-2], "/") 161 | return pathExceptLastTwo, filepath.Join(secondLastDir, lastDir), nil 162 | } 163 | 164 | func init() { 165 | ag.SetDebugMode(false) 166 | } 167 | -------------------------------------------------------------------------------- /decoder/control.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 NLP Odyssey Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package decoder 6 | 7 | import ( 8 | "container/heap" 9 | "fmt" 10 | "math" 11 | "sort" 12 | 13 | "github.com/nlpodyssey/spago/mat" 14 | "github.com/nlpodyssey/spago/mat/float" 15 | "github.com/nlpodyssey/verbaflow/sliceutils" 16 | "github.com/rs/zerolog/log" 17 | ) 18 | 19 | // OutputDiversityControlFunc performs the pre-processing steps that are used to narrow down the set of candidate items 20 | // before using greedy decoding or multinomial sampling to generate the final output. 21 | type OutputDiversityControlFunc func(logits mat.Matrix) (mat.Matrix, error) 22 | 23 | // OutputDiversityControl returns a function used to select the next token. 24 | func OutputDiversityControl(temp float64, topK int, topP float64) (OutputDiversityControlFunc, error) { 25 | if temp < 0 || temp > 1 { 26 | return nil, fmt.Errorf("invalid temperature value: %f. Must be between 0 and 1", temp) 27 | } 28 | if topK < 0 { 29 | return nil, fmt.Errorf("invalid topK value: %d. Must be >= 0", topK) 30 | } 31 | if topP < 0 || topP > 1 { 32 | return nil, fmt.Errorf("invalid topP value: %f. Must be between 0 and 1", topP) 33 | } 34 | 35 | result := make([]OutputDiversityControlFunc, 0, 3) 36 | if temp != 1 { 37 | log.Trace().Float64("temperature", temp).Msg("Applying temperature control") 38 | if temp == 0 { 39 | log.Trace().Msg("Temperature is 0, setting it to 0.01 to avoid division by zero") 40 | temp = 0.01 // avoid division by zero 41 | } 42 | result = append(result, TemperatureFunc(temp)) 43 | } 44 | if topK != 0 { 45 | log.Trace().Int("topK", topK).Msg("Applying topK control") 46 | result = append(result, TopKFunc(topK, math.Inf(-1))) 47 | } 48 | if topP != 1 { 49 | log.Trace().Float64("topP", topP).Msg("Applying topP control") 50 | result = append(result, TopPFunc(topP, math.Inf(-1), 1)) // minSize = 2 if beam search is enabled 51 | } 52 | 53 | return func(logits mat.Matrix) (mat.Matrix, error) { 54 | var err error 55 | for _, p := range result { 56 | logits, err = p(logits) 57 | if err != nil { 58 | return nil, err 59 | } 60 | } 61 | return logits, err 62 | }, nil 63 | } 64 | 65 | // TemperatureFunc applies a temperature to a matrix of scores. 
66 | func TemperatureFunc(temperature float64) OutputDiversityControlFunc { 67 | if temperature == 1 { 68 | return func(scores mat.Matrix) (mat.Matrix, error) { 69 | return scores, nil 70 | } 71 | } 72 | invTemperature := 1 / temperature 73 | return func(scores mat.Matrix) (mat.Matrix, error) { 74 | return scores.ProdScalar(invTemperature), nil 75 | } 76 | } 77 | 78 | // TopKFunc applies a top-k filter to a matrix of scores. 79 | func TopKFunc(topK int, filterValue float64) OutputDiversityControlFunc { 80 | return func(scores mat.Matrix) (mat.Matrix, error) { 81 | topK := topK 82 | if size := scores.Size(); size <= topK { 83 | topK = size 84 | } 85 | 86 | inScores := scores.Data().F64() 87 | 88 | rawTopScores := make(sliceutils.OrderedHeap[float64], len(inScores)) 89 | copy(rawTopScores, inScores) 90 | 91 | topScores := sliceutils.ReverseHeap(&rawTopScores) 92 | heap.Init(topScores) 93 | for i := 1; i < topK; i++ { 94 | heap.Pop(topScores) 95 | } 96 | minScore := heap.Pop(topScores).(float64) 97 | 98 | return scores.Apply(func(_, _ int, v float64) float64 { 99 | if v < minScore { 100 | return filterValue 101 | } 102 | return v 103 | }), nil 104 | } 105 | } 106 | 107 | // TopPFunc applies a top-p filter to a matrix of scores. 108 | // Note that when using beam decoding (with beam > 1) then minSize must be at least 2. 109 | func TopPFunc[T float.DType](topP, filterValue T, minSize int) OutputDiversityControlFunc { 110 | return func(scores mat.Matrix) (mat.Matrix, error) { 111 | dataCopy := make([]T, scores.Size()) 112 | copy(dataCopy, mat.Data[T](scores)) 113 | sortedData := sliceutils.NewIndexedSlice[T](dataCopy) 114 | sort.Stable(sort.Reverse(sortedData)) 115 | 116 | cumulativeProbs := mat.NewVecDense(sortedData.Slice).Softmax().CumSum() 117 | cumProbData := mat.Data[T](cumulativeProbs) 118 | 119 | indicesToRemove := make([]bool, len(cumProbData)) 120 | for i, cp := range cumProbData { 121 | indicesToRemove[i] = cp > topP 122 | } 123 | 124 | if minSize > 1 { 125 | // Keep at least minSize (minSize-1 because we add the first one below) 126 | for i := minSize - 1; i >= 0; i-- { 127 | indicesToRemove[i] = false 128 | } 129 | } 130 | 131 | // Shift the indices to the right to keep also the first token above the threshold 132 | copy(indicesToRemove[1:], indicesToRemove[:len(indicesToRemove)-1]) 133 | indicesToRemove[0] = false 134 | 135 | // Scatter sorted tensors to original indexing 136 | 137 | outData := make([]T, scores.Size()) 138 | copy(outData, mat.Data[T](scores)) 139 | for maskIndex, toRemove := range indicesToRemove { 140 | if !toRemove { 141 | continue 142 | } 143 | index := sortedData.Indices[maskIndex] 144 | outData[index] = filterValue 145 | } 146 | 147 | return mat.NewVecDense[T](outData), nil 148 | } 149 | } 150 | -------------------------------------------------------------------------------- /decoder/decoder.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 NLP Odyssey Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package decoder 6 | 7 | import ( 8 | "context" 9 | "fmt" 10 | "math" 11 | "reflect" 12 | 13 | "github.com/nlpodyssey/rwkv" 14 | "github.com/nlpodyssey/spago/ag" 15 | "github.com/nlpodyssey/spago/mat" 16 | "github.com/nlpodyssey/spago/mat/float" 17 | "github.com/nlpodyssey/verbaflow/encoder" 18 | "github.com/nlpodyssey/verbaflow/rwkvlm" 19 | "github.com/rs/zerolog/log" 20 | ) 21 | 22 | var floatNegInf = float.Interface(math.Inf(-1)) 23 | 24 | type Decoder struct { 25 | model *rwkvlm.Model 26 | applyOutputControl OutputDiversityControlFunc 27 | applySelection OutputSelectionFunc 28 | opts DecodingOptions 29 | } 30 | 31 | // DecodingOptions contains the options for the conditional text generation. 32 | type DecodingOptions struct { 33 | // MaxLen is the maximum number of tokens to generate. 34 | MaxLen int `json:"max_len" yaml:"max_len"` 35 | // MinLen is the minimum number of tokens to generate. 36 | MinLen int `json:"min_len" yaml:"min_len"` 37 | // StopSequencesIDs is a list of token ids that if generated, the generation process will stop. 38 | StopSequencesIDs [][]int `json:"stop_sequences_ids" yaml:"stop_sequences_ids"` 39 | // EndTokenID is the end-of-sequence token (default: 0). 40 | EndTokenID int `json:"end_token_id" yaml:"end_token_id"` 41 | // SkipEndTokenID when true, the end token is not added to the generated sequence. 42 | SkipEndTokenID bool `json:"skip_end_token_id" yaml:"skip_end_token_id"` 43 | // Temperature is the temperature used to control the randomness of the generated text. 44 | Temp float64 `json:"temp" yaml:"temp"` 45 | // TopK is the number of tokens to consider when sampling the next token. 46 | TopK int `json:"top_k" yaml:"top_k"` 47 | // TopP is the cumulative probability of the tokens to consider when sampling the next token. 48 | TopP float64 `json:"top_p" yaml:"top_p"` 49 | // UseSampling uses sampling to generate the next token. 50 | UseSampling bool `json:"use_sampling" yaml:"use_sampling"` 51 | } 52 | 53 | // GeneratedToken is the result of a single step of the decoder. 54 | type GeneratedToken struct { 55 | // TokenID is the ID of the token predicted by the decoder at the current step. 56 | TokenID int 57 | // SumNegLogProbs is the sum of the negative log probabilities up to the current step. 
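	// Lower values therefore correspond to sequences the model considers more likely.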
58 | SumNegLogProbs float64 59 | } 60 | 61 | func New(m *rwkvlm.Model, opts DecodingOptions) (*Decoder, error) { 62 | dc, err := OutputDiversityControl(opts.Temp, opts.TopK, opts.TopP) 63 | if err != nil { 64 | return nil, err 65 | } 66 | return &Decoder{ 67 | model: m, 68 | opts: opts, 69 | applyOutputControl: dc, 70 | applySelection: OutputSelection(opts.UseSampling), 71 | }, nil 72 | } 73 | 74 | func (d *Decoder) Decode(ctx context.Context, nt *ag.NodesTracker, input encoder.Result, chGen chan GeneratedToken) error { 75 | defer close(chGen) 76 | 77 | x, s := input.Encoding, input.State 78 | if x == nil || s == nil { 79 | return fmt.Errorf("invalid input: hidden representation and state are required") 80 | } 81 | 82 | var sequence []int 83 | var sumNegLogProbs float64 84 | 85 | Loop: 86 | for i := 0; ; i++ { 87 | select { 88 | case <-ctx.Done(): 89 | log.Trace().Msgf("Generation cancelled after %d steps due to context cancellation", i) 90 | break Loop 91 | default: 92 | tokenID, tokenScore, err := d.generateToken(ctx, x, i, nt) 93 | if err != nil { 94 | return err 95 | } 96 | sequence = append(sequence, tokenID) 97 | sumNegLogProbs -= math.Log(tokenScore) 98 | 99 | chGen <- GeneratedToken{ 100 | TokenID: tokenID, 101 | SumNegLogProbs: sumNegLogProbs, 102 | } 103 | 104 | if d.checkStopConditions(sequence) { 105 | break Loop 106 | } 107 | 108 | // update the hidden representation `x` with the result of encoding the last generated token, 109 | // which is used as input for the next iteration of the loop. 110 | x, err = d.encode(ctx, nt, tokenID, s) 111 | if err != nil { 112 | return err 113 | } 114 | } 115 | } 116 | 117 | log.Trace().Msgf("[%.2f] Generated token IDs: %v", sumNegLogProbs, sequence) 118 | 119 | return nil 120 | } 121 | 122 | // generateToken performs a single step of the decoding process. 123 | // It returns the selected output token ID and its score. 124 | func (d *Decoder) generateToken(_ context.Context, x ag.Node, seqLen int, nt *ag.NodesTracker) (int, float64, error) { 125 | logits := nt.TrackNode(d.model.Predict(x)) 126 | candidates, err := d.applyOutputControl(d.adjustLogits(logits.Value(), seqLen)) 127 | if err != nil { 128 | return 0, 0, err 129 | } 130 | return d.applySelection(candidates) 131 | } 132 | 133 | // adjustLogits checks if the sequence is too short and if so, set the logits of the end token to a very low value. 
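// This prevents the end-of-sequence token from being selected before MinLen tokens have been generated.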
134 | func (d *Decoder) adjustLogits(logits mat.Matrix, sequenceLength int) mat.Matrix { 135 | if sequenceLength >= d.opts.MinLen { 136 | return logits 137 | } 138 | log.Trace().Msgf("Sequence too short (%d), setting end token (%d) logits to -inf", sequenceLength, d.opts.EndTokenID) 139 | logits.SetVecScalar(d.opts.EndTokenID, floatNegInf) 140 | return logits 141 | } 142 | 143 | func (d *Decoder) checkStopConditions(sequence []int) bool { 144 | if len(sequence) >= d.opts.MaxLen { 145 | log.Trace().Msgf("Reached max length (%d)", d.opts.MaxLen) 146 | return true 147 | } 148 | last := sequence[len(sequence)-1] 149 | if last == d.opts.EndTokenID { 150 | log.Trace().Msgf("Reached end token (%d)", d.opts.EndTokenID) 151 | return true 152 | } 153 | if len(sequence) >= d.opts.MinLen && hasStopSequence(sequence, d.opts.StopSequencesIDs) { 154 | return true 155 | } 156 | return false 157 | } 158 | 159 | func hasStopSequence(sequence []int, stopSequences [][]int) bool { 160 | for _, stopSeq := range stopSequences { 161 | if len(sequence) < len(stopSeq) { 162 | continue 163 | } 164 | 165 | if reflect.DeepEqual(stopSeq, sequence[len(sequence)-len(stopSeq):]) { 166 | log.Trace().Msgf("Reached stop sequence %v", stopSeq) 167 | return true 168 | } 169 | } 170 | return false 171 | } 172 | 173 | func (d *Decoder) encode(ctx context.Context, nt *ag.NodesTracker, tokenID int, state rwkv.State) (ag.Node, error) { 174 | x, s := d.model.Encode(ctx, state, tokenID) 175 | nt.TrackNodes(waitForNodes(extractNodesToRelease(x, s))...) 176 | return x, nil 177 | } 178 | 179 | // waitForNodes waits for the nodes to be computed. 180 | // It is used to ensure that the nodes are computed before releasing them. 181 | func waitForNodes(nodes []ag.Node) []ag.Node { 182 | for _, n := range nodes { 183 | n.Value() 184 | } 185 | return nodes 186 | } 187 | 188 | // extractNodesToRelease extracts the nodes to release from the states. 189 | // It also considers the explicit x node. 190 | func extractNodesToRelease(x ag.Node, s rwkv.State) []ag.Node { 191 | nodes := []ag.Node{x} 192 | for _, layer := range s { 193 | nodes = append(nodes, layer.FfnXX, layer.AttXX, layer.AttAA, layer.AttBB, layer.AttPP) 194 | } 195 | return nodes 196 | } 197 | -------------------------------------------------------------------------------- /decoder/selection.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 NLP Odyssey Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
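// This file provides the output-selection strategies used by the decoder: greedy decoding takes the argmax of the softmax distribution, while multinomial sampling draws the next token at random according to that distribution.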
4 | 5 | package decoder 6 | 7 | import ( 8 | "fmt" 9 | 10 | "github.com/nlpodyssey/spago/mat" 11 | "github.com/nlpodyssey/spago/mat/rand" 12 | "github.com/rs/zerolog/log" 13 | ) 14 | 15 | type OutputSelectionFunc func(logits mat.Matrix) (int, float64, error) 16 | 17 | func OutputSelection(sampling bool) OutputSelectionFunc { 18 | if sampling { 19 | log.Trace().Msg("using multinomial sampling") 20 | return MultinomialSampling() 21 | } 22 | log.Trace().Msg("using greedy decoding") 23 | return GreedyDecoding() 24 | } 25 | 26 | func GreedyDecoding() OutputSelectionFunc { 27 | return func(logits mat.Matrix) (int, float64, error) { 28 | probs := logits.Softmax() 29 | argmax := probs.ArgMax() 30 | return argmax, probs.ScalarAtVec(argmax).F64(), nil 31 | } 32 | } 33 | 34 | func MultinomialSampling() OutputSelectionFunc { 35 | return func(logits mat.Matrix) (int, float64, error) { 36 | probs := logits.Softmax() 37 | samples, err := multinomial(probs, 1) 38 | if err != nil { 39 | return 0, 0, err 40 | } 41 | return samples[0], probs.ScalarAtVec(samples[0]).F64(), nil 42 | } 43 | } 44 | 45 | // multinomial extracts the next indices from a multinomial probability distribution. 46 | func multinomial(input mat.Matrix, numSamples int) ([]int, error) { 47 | if numSamples > input.Size() { 48 | return nil, fmt.Errorf("numSamples (%d) must be less than or equal to the size of the input (%d)", numSamples, input.Size()) 49 | } 50 | 51 | samples := make([]int, 0, numSamples) 52 | samplesMap := make(map[int]struct{}, numSamples) 53 | 54 | data := input.Data().F64() 55 | for len(samples) < numSamples { 56 | p := rand.Float[float64]() 57 | 58 | for i, value := range data { 59 | p -= value 60 | if p < 0 { 61 | if _, alreadySampled := samplesMap[i]; !alreadySampled { 62 | samplesMap[i] = struct{}{} 63 | samples = append(samples, i) 64 | } 65 | break 66 | } 67 | } 68 | } 69 | 70 | return samples, nil 71 | } 72 | -------------------------------------------------------------------------------- /downloader/downloadmodel.go: -------------------------------------------------------------------------------- 1 | // Copyright 2022 NLP Odyssey Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package downloader 6 | 7 | import ( 8 | "fmt" 9 | "io" 10 | "net/http" 11 | "os" 12 | "path/filepath" 13 | 14 | "github.com/rs/zerolog/log" 15 | ) 16 | 17 | const ( 18 | // Hugging Face repository URL, in the format: 19 | // "https://huggingface.co/{model_id}/resolve/{revision}/{filename}" 20 | huggingFaceCoPrefix = "https://huggingface.co/%s/resolve/%s/%s" 21 | // Default revision name for fetching model from Hugging Face repository 22 | defaultRevision = "main" 23 | ) 24 | 25 | // modelsFiles contains the set of files to download. 26 | var modelsFiles = []string{ 27 | "config.json", "pytorch_model.pt", "vocab.json", "merges.txt", 28 | } 29 | 30 | // Download downloads a supported pre-trained model from huggingface.co 31 | // repositories. 32 | // 33 | // If one or more directory levels don't yet exist, they are created 34 | // setting the permissions bits to 0755 (rwxr-xr-x). 35 | // 36 | // By setting the flag overwriteIfExist to false, any file that already 37 | // exists is kept and considered as already successfully downloaded. If 38 | // the flag is otherwise set to true, existing files will be forcefully 39 | // downloaded and overwritten. 
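//
// A minimal usage sketch (the models directory, model name and empty access
// token below are illustrative assumptions, not values prescribed by the project):
//
//	if err := Download("models", "owner/model-name", false, ""); err != nil {
//		log.Fatal().Err(err).Send()
//	}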
40 | func Download(modelsDir, modelName string, overwriteIfExists bool, accessToken string) error { 41 | return downloader{ 42 | modelPath: filepath.Join(modelsDir, modelName), 43 | modelName: modelName, 44 | overwriteIfExist: overwriteIfExists, 45 | accessToken: accessToken, 46 | }.download() 47 | } 48 | 49 | // downloader is a helper struct for downloading a model. 50 | type downloader struct { 51 | modelPath string 52 | modelName string 53 | accessToken string 54 | overwriteIfExist bool 55 | } 56 | 57 | func (d downloader) download() error { 58 | if err := d.ensureModelPath(); err != nil { 59 | return err 60 | } 61 | for _, filename := range modelsFiles { 62 | if err := d.downloadFile(filename); err != nil { 63 | return err 64 | } 65 | } 66 | return nil 67 | 68 | } 69 | 70 | func (d downloader) ensureModelPath() error { 71 | if info, err := os.Stat(d.modelPath); err == nil && info.IsDir() { 72 | return nil 73 | } 74 | if err := os.MkdirAll(d.modelPath, 0755); err != nil { 75 | return fmt.Errorf("error creating model path %#v: %w", d.modelPath, err) 76 | } 77 | return nil 78 | } 79 | 80 | func (d downloader) downloadFile(name string) (err error) { 81 | fPath := filepath.Join(d.modelPath, name) 82 | if info, err := os.Stat(fPath); !d.overwriteIfExist && err == nil && !info.IsDir() { 83 | log.Debug().Str("file", fPath).Msg("model file already exists, skipping download") 84 | return nil 85 | } 86 | 87 | url := d.bucketURL(name) 88 | log.Debug().Str("url", url).Str("destination", fPath).Msg("downloading") 89 | 90 | f, err := os.Create(fPath) 91 | if err != nil { 92 | return fmt.Errorf("error creating file %#v: %w", fPath, err) 93 | } 94 | defer func() { 95 | if e := f.Close(); e != nil && err == nil { 96 | err = fmt.Errorf("error closing file %#v: %w", fPath, e) 97 | } 98 | }() 99 | 100 | resp, err := d.httpGet(url) 101 | if err != nil { 102 | return fmt.Errorf("error getting %#v: %w", url, err) 103 | } 104 | defer func() { 105 | if e := resp.Body.Close(); e != nil && err == nil { 106 | err = fmt.Errorf("error closing %#v response body: %w", url, e) 107 | } 108 | }() 109 | 110 | if resp.StatusCode != http.StatusOK { 111 | return fmt.Errorf("%#v responded with %s", url, resp.Status) 112 | } 113 | 114 | prog := newDownloadProgress(int(resp.ContentLength)) 115 | prog.Start() 116 | defer prog.Stop() 117 | 118 | _, err = io.Copy(f, io.TeeReader(resp.Body, prog)) 119 | if err != nil { 120 | return fmt.Errorf("error downloading %#v to %#v: %w", url, fPath, err) 121 | } 122 | return nil 123 | } 124 | 125 | func (d downloader) httpGet(url string) (*http.Response, error) { 126 | req, err := http.NewRequest("GET", url, nil) 127 | if err != nil { 128 | return nil, err 129 | } 130 | if d.accessToken != "" { 131 | req.Header.Set("Authorization", "Bearer "+d.accessToken) 132 | } 133 | return http.DefaultClient.Do(req) 134 | } 135 | 136 | func (d downloader) bucketURL(fileName string) string { 137 | return fmt.Sprintf(huggingFaceCoPrefix, d.modelName, defaultRevision, fileName) 138 | } 139 | -------------------------------------------------------------------------------- /downloader/downloadprogress.go: -------------------------------------------------------------------------------- 1 | // Copyright 2022 NLP Odyssey Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
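// downloadProgress (defined below) implements io.Writer so it can be placed behind io.TeeReader in downloadFile: each Write advances the byte counter, while a background goroutine periodically logs the percentage completed.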
4 | 5 | package downloader 6 | 7 | import ( 8 | "fmt" 9 | "sync" 10 | "time" 11 | 12 | "github.com/rs/zerolog/log" 13 | ) 14 | 15 | // downloadProgress is a helper struct for reporting download progress. 16 | type downloadProgress struct { 17 | contentLength int 18 | readContentLength int 19 | stopCh chan struct{} 20 | wg sync.WaitGroup 21 | } 22 | 23 | const downloadProgressUpdateFrequency = 3 * time.Second 24 | 25 | func newDownloadProgress(contentLength int) *downloadProgress { 26 | return &downloadProgress{ 27 | contentLength: contentLength, 28 | readContentLength: 0, 29 | stopCh: nil, 30 | } 31 | } 32 | 33 | // Start starts the progress reporting goroutine. 34 | func (dp *downloadProgress) Start() { 35 | dp.stopCh = make(chan struct{}, 1) 36 | dp.wg.Add(1) 37 | go dp.goRoutine() 38 | } 39 | 40 | // Stop stops the progress reporting goroutine. 41 | func (dp *downloadProgress) Stop() { 42 | dp.stopCh <- struct{}{} 43 | dp.wg.Wait() 44 | dp.stopCh = nil 45 | } 46 | 47 | func (dp *downloadProgress) goRoutine() { 48 | stopCh := dp.stopCh 49 | 50 | for { 51 | select { 52 | case <-stopCh: 53 | dp.reportProgress() 54 | close(stopCh) 55 | dp.wg.Done() 56 | return 57 | case <-time.After(downloadProgressUpdateFrequency): 58 | dp.reportProgress() 59 | } 60 | } 61 | } 62 | 63 | func (dp *downloadProgress) reportProgress() { 64 | cl := dp.contentLength 65 | rcl := dp.readContentLength 66 | hrcl := humanizeBytesSize(rcl) 67 | 68 | switch { 69 | case cl < 0: 70 | log.Debug().Msgf("%s downloaded", hrcl) 71 | case cl == rcl: 72 | log.Debug().Msgf("%s (100%%) downloaded", hrcl) 73 | default: 74 | hcl := humanizeBytesSize(cl) 75 | perc := rcl * 100 / cl 76 | log.Debug().Msgf("%s of %s (%d%%) downloaded", hrcl, hcl, perc) 77 | } 78 | } 79 | 80 | // Write satisfies io.Writer interface. 81 | func (dp *downloadProgress) Write(p []byte) (int, error) { 82 | dp.readContentLength += len(p) 83 | return len(p), nil 84 | } 85 | 86 | func humanizeBytesSize(n int) string { 87 | switch { 88 | case n < 1024: 89 | return fmt.Sprintf("%d B", n) 90 | case n < 1_048_576: 91 | return fmt.Sprintf("%.2f KiB", float64(n)/1024) 92 | case n < 1_073_741_824: 93 | return fmt.Sprintf("%.2f MiB", float64(n)/1_048_576) 94 | default: 95 | return fmt.Sprintf("%.2f GiB", float64(n)/1_073_741_824) 96 | } 97 | } 98 | -------------------------------------------------------------------------------- /encoder/encoder.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 NLP Odyssey Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package encoder 6 | 7 | import ( 8 | "context" 9 | 10 | "github.com/nlpodyssey/rwkv" 11 | "github.com/nlpodyssey/spago/ag" 12 | "github.com/nlpodyssey/verbaflow/rwkvlm" 13 | ) 14 | 15 | type Encoder struct { 16 | model *rwkvlm.Model 17 | } 18 | 19 | type Result struct { 20 | Encoding ag.Node 21 | State rwkv.State 22 | } 23 | 24 | func New(model *rwkvlm.Model) *Encoder { 25 | return &Encoder{model: model} 26 | } 27 | 28 | func (e *Encoder) Encode(ctx context.Context, tokens []int) (Result, error) { 29 | x, s := e.model.Encode(ctx, nil, tokens...) 
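// Wait for the encoding to be fully computed before returning it, so the decoder receives a ready-to-use hidden representation together with the RWKV state.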
30 | return Result{ 31 | Encoding: ag.WaitForValue(x), 32 | State: s, 33 | }, nil 34 | } 35 | -------------------------------------------------------------------------------- /examples/prompttester/config.yaml: -------------------------------------------------------------------------------- 1 | # min_len is the minimum number of tokens to generate. 2 | min_len: 0 3 | # max_len is the maximum number of tokens to generate. 4 | max_len: 200 5 | # end_token_id is the end-of-sequence token. 6 | end_token_id: 0 7 | # skip_end_token_id when true, the end token is not added to the generated sequence. 8 | skip_end_token_id: true 9 | # temp is the temperature used to control the randomness of the generated text. 10 | temp: 1.0 11 | # top_p is the cumulative probability of the tokens to consider when sampling the next token. 12 | top_p: 0.8 13 | # top_k is the number of tokens to consider when sampling the next token. 14 | top_k: 0 15 | # use_sampling uses sampling to generate the next token. 16 | use_sampling: true 17 | # stop_sequences_ids is a list of token ids that if generated, the generation process will stop. 18 | stop_sequences_ids: 19 | - # \nQuestion: 20 | - 187 21 | - 23433 22 | - 27 23 | - # \nQ & A: 24 | - 187 25 | - 50 26 | - 708 27 | - 329 28 | - # \nQ: 29 | - 187 30 | - 50 31 | - 27 32 | - # \nA: 33 | - 187 34 | - 34 35 | - 27 -------------------------------------------------------------------------------- /examples/prompttester/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/nlpodyssey/verbaflow/examples/prompttester 2 | 3 | go 1.20 4 | 5 | replace github.com/nlpodyssey/verbaflow => ../.. 6 | 7 | require ( 8 | github.com/nlpodyssey/verbaflow v0.0.0-20230203211617-0a0020374374 9 | github.com/rs/zerolog v1.29.0 10 | github.com/urfave/cli/v2 v2.24.3 11 | google.golang.org/grpc v1.33.2 12 | gopkg.in/yaml.v3 v3.0.1 13 | ) 14 | 15 | require ( 16 | github.com/cespare/xxhash v1.1.0 // indirect 17 | github.com/cespare/xxhash/v2 v2.2.0 // indirect 18 | github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect 19 | github.com/dgraph-io/badger/v3 v3.2103.5 // indirect 20 | github.com/dgraph-io/ristretto v0.1.1 // indirect 21 | github.com/dlclark/regexp2 v1.8.0 // indirect 22 | github.com/dustin/go-humanize v1.0.1 // indirect 23 | github.com/gogo/protobuf v1.3.2 // indirect 24 | github.com/golang/glog v1.0.0 // indirect 25 | github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect 26 | github.com/golang/protobuf v1.5.2 // indirect 27 | github.com/golang/snappy v0.0.4 // indirect 28 | github.com/google/flatbuffers v23.1.21+incompatible // indirect 29 | github.com/klauspost/compress v1.15.15 // indirect 30 | github.com/mattn/go-colorable v0.1.13 // indirect 31 | github.com/mattn/go-isatty v0.0.17 // indirect 32 | github.com/nlpodyssey/gopickle v0.2.0 // indirect 33 | github.com/nlpodyssey/gotokenizers v0.2.0 // indirect 34 | github.com/nlpodyssey/rwkv v0.0.0-20230212203924-6a6eeeabd546 // indirect 35 | github.com/nlpodyssey/spago v1.0.2-0.20230202124145-3cffe41f485c // indirect 36 | github.com/nlpodyssey/spago/embeddings/store/diskstore v0.0.0-20230202124145-3cffe41f485c // indirect 37 | github.com/pkg/errors v0.9.1 // indirect 38 | github.com/russross/blackfriday/v2 v2.1.0 // indirect 39 | github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect 40 | go.opencensus.io v0.24.0 // indirect 41 | golang.org/x/net v0.5.0 // indirect 42 | golang.org/x/sys v0.4.0 // indirect 43 | golang.org/x/text v0.6.0 // 
indirect 44 | google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013 // indirect 45 | google.golang.org/protobuf v1.28.1 // indirect 46 | ) 47 | -------------------------------------------------------------------------------- /examples/prompttester/go.sum: -------------------------------------------------------------------------------- 1 | cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= 2 | github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= 3 | github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE= 4 | github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= 5 | github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= 6 | github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= 7 | github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko= 8 | github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= 9 | github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= 10 | github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= 11 | github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= 12 | github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= 13 | github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= 14 | github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= 15 | github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk= 16 | github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= 17 | github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= 18 | github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE= 19 | github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w= 20 | github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= 21 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 22 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 23 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 24 | github.com/dgraph-io/badger/v3 v3.2103.5 h1:ylPa6qzbjYRQMU6jokoj4wzcaweHylt//CH0AKt0akg= 25 | github.com/dgraph-io/badger/v3 v3.2103.5/go.mod h1:4MPiseMeDQ3FNCYwRbbcBOGJLf5jsE0PPFzRiKjtcdw= 26 | github.com/dgraph-io/ristretto v0.1.1 h1:6CWw5tJNgpegArSHpNHJKldNeq03FQCwYvfMVWajOK8= 27 | github.com/dgraph-io/ristretto v0.1.1/go.mod h1:S1GPSBCYCIhmVNfcth17y2zZtQT6wzkzgwUve0VDWWA= 28 | github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczCTSixgIKmwPv6+wP5DGjqLYw5SUiA= 29 | github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= 30 | github.com/dlclark/regexp2 v1.4.0/go.mod h1:2pZnwuY/m+8K6iRw6wQdMtk+rH5tNGR1i55kozfMjCc= 31 | github.com/dlclark/regexp2 v1.8.0 h1:rJD5HeGIT/2b5CDk63FVCwZA3qgYElfg+oQK7uH5pfE= 32 | github.com/dlclark/regexp2 v1.8.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= 33 | github.com/dustin/go-humanize 
v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= 34 | github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= 35 | github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= 36 | github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= 37 | github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= 38 | github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= 39 | github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= 40 | github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= 41 | github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= 42 | github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= 43 | github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= 44 | github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= 45 | github.com/golang/glog v1.0.0 h1:nfP3RFugxnNRyKgeWd4oI1nYvXpxrx8ck8ZrcizshdQ= 46 | github.com/golang/glog v1.0.0/go.mod h1:EWib/APOK0SL3dFbYqvxE3UYd8E6s1ouQ7iEp/0LWV4= 47 | github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= 48 | github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= 49 | github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= 50 | github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= 51 | github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= 52 | github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= 53 | github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= 54 | github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= 55 | github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= 56 | github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= 57 | github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= 58 | github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= 59 | github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= 60 | github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= 61 | github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= 62 | github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= 63 | github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= 64 | github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= 65 | github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= 66 | github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= 67 | github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= 68 | github.com/google/flatbuffers 
v1.12.1/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= 69 | github.com/google/flatbuffers v23.1.21+incompatible h1:bUqzx/MXCDxuS0hRJL2EfjyZL3uQrPbMocUa8zGqsTA= 70 | github.com/google/flatbuffers v23.1.21+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= 71 | github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= 72 | github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= 73 | github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= 74 | github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 75 | github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 76 | github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 77 | github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 78 | github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 79 | github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= 80 | github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 81 | github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= 82 | github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= 83 | github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= 84 | github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= 85 | github.com/klauspost/compress v1.12.3/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= 86 | github.com/klauspost/compress v1.15.15 h1:EF27CXIuDsYJ6mmvtBRlEuB2UVOqHG1tAXgZ7yIO+lw= 87 | github.com/klauspost/compress v1.15.15/go.mod h1:ZcK2JAFqKOpnBlxcLsJzYfrS9X1akm9fHZNnD9+Vo/4= 88 | github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= 89 | github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= 90 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 91 | github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= 92 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 93 | github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= 94 | github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= 95 | github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= 96 | github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= 97 | github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= 98 | github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= 99 | github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPng= 100 | github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= 101 | github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= 102 | github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= 103 | github.com/nlpodyssey/gopickle v0.2.0 h1:4naD2DVylYJupQLbCQFdwo6yiXEmPyp+0xf5MVlrBDY= 104 | github.com/nlpodyssey/gopickle v0.2.0/go.mod h1:YIUwjJ2O7+vnBsxUN+MHAAI3N+adqEGiw+nDpwW95bY= 105 | github.com/nlpodyssey/gotokenizers v0.2.0 h1:CWx/sp9s35XMO5lT1kNXCshFGDCfPuuWdx/9JiQBsVc= 106 | 
github.com/nlpodyssey/gotokenizers v0.2.0/go.mod h1:SBLbuSQhpni9M7U+Ie6O46TXYN73T2Cuw/4eeYHYJ+s= 107 | github.com/nlpodyssey/rwkv v0.0.0-20230212203924-6a6eeeabd546 h1:i5zuhz9K0+EOzt9dDn65fK/uP7p/5yjdUxIsfT6IN4k= 108 | github.com/nlpodyssey/rwkv v0.0.0-20230212203924-6a6eeeabd546/go.mod h1:YXDwNBfpWqSma80h4D5qocIeR604gwEFZM2M8MDGnlI= 109 | github.com/nlpodyssey/spago v1.0.2-0.20230202124145-3cffe41f485c h1:T1Cn5J40B8zH6D+fnKRt2NmZOm1/2hsv0vkr7kqPW0Q= 110 | github.com/nlpodyssey/spago v1.0.2-0.20230202124145-3cffe41f485c/go.mod h1:myGGtjwdlurAzacv7iXHY7KVfWii6MaHZo/eCDd6iE4= 111 | github.com/nlpodyssey/spago/embeddings/store/diskstore v0.0.0-20230202124145-3cffe41f485c h1:IB7LjP9N7kE50gO3WjlfZDVGwwyDiTopC1YKpWWpwGI= 112 | github.com/nlpodyssey/spago/embeddings/store/diskstore v0.0.0-20230202124145-3cffe41f485c/go.mod h1:ElJ8F8dMxzuInbNMapI7m625LjE8OdIfGfmU/KwxPzM= 113 | github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= 114 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 115 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 116 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 117 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 118 | github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= 119 | github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= 120 | github.com/rs/zerolog v1.29.0 h1:Zes4hju04hjbvkVkOhdl2HpZa+0PmVwigmo8XoORE5w= 121 | github.com/rs/zerolog v1.29.0/go.mod h1:NILgTygv/Uej1ra5XxGf82ZFSLk58MFGAUS2o6usyD0= 122 | github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= 123 | github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= 124 | github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= 125 | github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= 126 | github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI= 127 | github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= 128 | github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= 129 | github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= 130 | github.com/spf13/cobra v0.0.5/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tLCHU= 131 | github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo= 132 | github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= 133 | github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s= 134 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 135 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 136 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 137 | github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= 138 | github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= 139 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 140 | github.com/stretchr/testify v1.8.0/go.mod 
h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 141 | github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= 142 | github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= 143 | github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= 144 | github.com/urfave/cli/v2 v2.24.3 h1:7Q1w8VN8yE0MJEHP06bv89PjYsN4IHWED2s1v/Zlfm0= 145 | github.com/urfave/cli/v2 v2.24.3/go.mod h1:GHupkWPMM0M/sj1a2b4wUrWBPzazNrIjouW6fmdJLxc= 146 | github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= 147 | github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU= 148 | github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsrSj3ccvlPHLoLsHnpR27oXr4ZE984MbSER8= 149 | github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= 150 | github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= 151 | go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= 152 | go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= 153 | go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= 154 | golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= 155 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 156 | golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= 157 | golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= 158 | golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= 159 | golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= 160 | golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= 161 | golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= 162 | golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= 163 | golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= 164 | golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= 165 | golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= 166 | golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= 167 | golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 168 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 169 | golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 170 | golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 171 | golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= 172 | golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= 173 | golang.org/x/net v0.5.0 h1:GyT4nK/YDHSqa1c4753ouYCDajOYKTja9Xb/OHtgvSw= 174 | golang.org/x/net 
v0.5.0/go.mod h1:DivGGAXEgPSlEBzxGzZI+ZLohi+xUj054jfeKui00ws= 175 | golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= 176 | golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 177 | golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 178 | golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 179 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 180 | golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 181 | golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 182 | golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 183 | golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 184 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 185 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 186 | golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 187 | golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 188 | golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 189 | golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 190 | golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 191 | golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 192 | golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= 193 | golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 194 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 195 | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 196 | golang.org/x/text v0.6.0 h1:3XmdazWV+ubf7QgHSTWeykHOci5oeekaGJBLkrkaw4k= 197 | golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= 198 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 199 | golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 200 | golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= 201 | golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= 202 | golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= 203 | golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= 204 | golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= 205 | golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= 206 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 207 | 
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 208 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 209 | golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= 210 | golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 211 | google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= 212 | google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= 213 | google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= 214 | google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= 215 | google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= 216 | google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013 h1:+kGHl1aib/qcwaRi1CbqBZ1rk19r85MNUf8HaBghugY= 217 | google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= 218 | google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= 219 | google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= 220 | google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= 221 | google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= 222 | google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= 223 | google.golang.org/grpc v1.33.2 h1:EQyQC3sa8M+p6Ulc8yy9SWSS2GVwyRc83gAbG8lrl4o= 224 | google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= 225 | google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= 226 | google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= 227 | google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= 228 | google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= 229 | google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= 230 | google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= 231 | google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= 232 | google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= 233 | google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= 234 | google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= 235 | google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= 236 | google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= 237 | google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= 238 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 239 | gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= 240 | gopkg.in/check.v1 
v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 241 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 242 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 243 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 244 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 245 | honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= 246 | honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= 247 | -------------------------------------------------------------------------------- /examples/prompttester/main.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 NLP Odyssey Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package main 6 | 7 | import ( 8 | "context" 9 | "fmt" 10 | "io" 11 | "os" 12 | "os/signal" 13 | "strings" 14 | "text/template" 15 | 16 | "github.com/nlpodyssey/verbaflow" 17 | "github.com/nlpodyssey/verbaflow/api" 18 | "github.com/nlpodyssey/verbaflow/decoder" 19 | "github.com/rs/zerolog" 20 | "github.com/rs/zerolog/log" 21 | "github.com/urfave/cli/v2" 22 | "google.golang.org/grpc" 23 | "gopkg.in/yaml.v3" 24 | ) 25 | 26 | type pTemplate struct { 27 | pt *template.Template 28 | data string // the raw data of the template 29 | } 30 | 31 | var defaultPromptTemplate = pTemplate{ 32 | data: "{{.Text}}", 33 | pt: template.Must(template.New("").Parse("{{.Text}}")), 34 | } 35 | 36 | func main() { 37 | log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Level(zerolog.TraceLevel) 38 | 39 | app := &cli.App{ 40 | Name: "PromptTester", 41 | Usage: "Test how the language model responds to different prompts", 42 | Action: func(c *cli.Context) error { 43 | opts, err := decodingOptionsFromFile(c.String("dconfig")) 44 | if err != nil { 45 | return fmt.Errorf("error reading decoding options: %w", err) 46 | } 47 | log.Info().Msgf("Decoding options:\n %+v\n", opts) 48 | promptt, err := promptTemplateFromFile(c.String("promptt")) 49 | if err != nil { 50 | return fmt.Errorf("error reading prompt template: %w", err) 51 | } 52 | if err := inference(opts, promptt, c.String("endpoint")); err != nil { 53 | log.Err(err).Send() 54 | } 55 | return nil 56 | }, 57 | Flags: []cli.Flag{ 58 | &cli.StringFlag{ 59 | Name: "log-level", 60 | Usage: "set log level (trace, debug, info, warn, error, fatal, panic)", 61 | Action: func(c *cli.Context, s string) error { 62 | return setDebugLevel(s) 63 | }, 64 | Value: "trace", 65 | }, 66 | &cli.StringFlag{ 67 | Name: "dconfig", 68 | Usage: "the path to the YAML configuration file for the decoding options", 69 | Required: true, 70 | }, 71 | &cli.StringFlag{ 72 | Name: "endpoint", 73 | Usage: "The address of the gRPC server", 74 | Value: ":50051", 75 | Required: false, 76 | }, 77 | &cli.StringFlag{ 78 | Name: "promptt", 79 | Usage: `the path to the prompt template file. 
If not specified, the default template {{.Text}} will be used`, 80 | Required: false, 81 | }, 82 | }, 83 | } 84 | 85 | if err := app.Run(os.Args); err != nil { 86 | log.Fatal().Err(err).Send() 87 | } 88 | } 89 | 90 | func inference(opts decoder.DecodingOptions, promptt pTemplate, endpoint string) error { 91 | 92 | text, err := inputTextFromStdin() 93 | if err != nil { 94 | return err 95 | } 96 | 97 | conn, err := grpc.Dial(endpoint, grpc.WithInsecure(), grpc.WithBlock()) 98 | if err != nil { 99 | log.Fatal().Msgf("Failed to connect: %v", err) 100 | } 101 | defer conn.Close() 102 | 103 | client := api.NewLanguageModelClient(conn) 104 | 105 | log.Trace().Msgf("Building prompt from template: %q", promptt.data) 106 | input, err := buildInputPrompt(text, promptt.data) 107 | if err != nil { 108 | return err 109 | } 110 | log.Trace().Msgf("Input fields: %+v", input) 111 | prompt, err := verbaflow.BuildPromptFromTemplate(input, promptt.pt) 112 | if err != nil { 113 | return err 114 | } 115 | log.Trace().Msgf("Final prompt: %q", prompt) 116 | 117 | req := &api.TokenGenerationRequest{ 118 | Prompt: prompt, 119 | DecodingParameters: decodingOptionsToGRPC(opts), 120 | } 121 | 122 | ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill) 123 | defer stop() 124 | 125 | stream, err := client.GenerateTokens(ctx, req) 126 | if err != nil { 127 | return fmt.Errorf("failed to call GenerateTokens: %v", err) 128 | } 129 | 130 | for { 131 | res, err := stream.Recv() 132 | 133 | if err != nil { 134 | if err == io.EOF { 135 | break 136 | } else if strings.Contains(err.Error(), "rpc error: code = Canceled desc = context canceled") { 137 | log.Debug().Msg("Context canceled.") 138 | break 139 | } else { 140 | return fmt.Errorf("failed to receive a message: %v", err) 141 | } 142 | } 143 | 144 | fmt.Print(res.Token) 145 | } 146 | log.Debug().Msg("Done.") 147 | return nil 148 | } 149 | 150 | func inputTextFromStdin() (string, error) { 151 | info, err := os.Stdin.Stat() 152 | if err != nil { 153 | return "", fmt.Errorf("error getting standard input info: %w", err) 154 | } 155 | if info.Size() == 0 { 156 | return "", fmt.Errorf("no input provided") 157 | } 158 | data, err := io.ReadAll(os.Stdin) 159 | if err != nil { 160 | return "", fmt.Errorf("error reading from standard input: %w", err) 161 | } 162 | input := strings.TrimSuffix(string(data), "\n") 163 | log.Trace().Msgf("Input: %q", input) 164 | return input, nil 165 | } 166 | 167 | func buildInputPrompt(text, data string) (verbaflow.InputPrompt, error) { 168 | if strings.Contains(data, "{{.Question}}") { // extractive question answering 169 | log.Trace().Msgf("Splitting text into passage and question parts by \\n\\n") 170 | spl := strings.Split(text, "\n\n") 171 | if len(spl) != 2 { 172 | return verbaflow.InputPrompt{}, fmt.Errorf("required passage and question separated by \\n\\n") 173 | } 174 | return verbaflow.InputPrompt{ 175 | Text: spl[0], 176 | Question: spl[1], 177 | }, nil 178 | } 179 | 180 | return verbaflow.InputPrompt{Text: text}, nil 181 | } 182 | 183 | func setDebugLevel(debugLevel string) error { 184 | level, err := zerolog.ParseLevel(debugLevel) 185 | if err != nil { 186 | return err 187 | } 188 | log.Logger = log.Level(level) 189 | return nil 190 | } 191 | 192 | func decodingOptionsFromFile(filepath string) (decoder.DecodingOptions, error) { 193 | data, err := os.ReadFile(filepath) 194 | if err != nil { 195 | return decoder.DecodingOptions{}, fmt.Errorf("error reading configuration file: %w", err) 196 | } 197 | var opts
decoder.DecodingOptions 198 | if err := yaml.Unmarshal(data, &opts); err != nil { 199 | return decoder.DecodingOptions{}, fmt.Errorf("error unmarshaling configuration file: %w", err) 200 | } 201 | return opts, nil 202 | } 203 | 204 | func promptTemplateFromFile(filepath string) (pTemplate, error) { 205 | if filepath == "" { 206 | return defaultPromptTemplate, nil 207 | } 208 | t, err := template.ParseFiles(filepath) 209 | if err != nil { 210 | return pTemplate{}, fmt.Errorf("error parsing template file: %w", err) 211 | } 212 | b, err := os.ReadFile(filepath) // just pass the file name 213 | if err != nil { 214 | return pTemplate{}, fmt.Errorf("error reading template file: %w", err) 215 | } 216 | return pTemplate{ 217 | pt: t, 218 | data: string(b), 219 | }, nil 220 | } 221 | 222 | func decodingOptionsToGRPC(opts decoder.DecodingOptions) *api.DecodingParameters { 223 | return &api.DecodingParameters{ 224 | MaxLen: int32(opts.MaxLen), 225 | MinLen: int32(opts.MinLen), 226 | Temperature: float32(opts.Temp), 227 | TopK: int32(opts.TopK), 228 | TopP: float32(opts.TopP), 229 | UseSampling: opts.UseSampling, 230 | EndTokenId: int32(opts.EndTokenID), 231 | SkipEndTokenId: opts.SkipEndTokenID, 232 | } 233 | } 234 | -------------------------------------------------------------------------------- /examples/prompttester/prompts/classification_1.tmpl: -------------------------------------------------------------------------------- 1 | 2 | 3 | {{.Text}} 4 | Which section of a newspaper would this article likely appear in? 5 | A: -------------------------------------------------------------------------------- /examples/prompttester/prompts/classification_2.tmpl: -------------------------------------------------------------------------------- 1 | 2 | 3 | {{.Text}} 4 | Can you identify the topic of the paragraph? 5 | A: -------------------------------------------------------------------------------- /examples/prompttester/prompts/extractive_question_answering_1.tmpl: -------------------------------------------------------------------------------- 1 | 2 | 3 | Answer the question using the given context. 4 | Question: {{.Question}} 5 | Context: {{.Text}} 6 | Precise answer: 7 | -------------------------------------------------------------------------------- /examples/prompttester/prompts/extractive_question_answering_2.tmpl: -------------------------------------------------------------------------------- 1 | 2 | 3 | Find the answer from the following context. 4 | {{.Text}} 5 | What's the answer to {{.Question}} 6 | -------------------------------------------------------------------------------- /examples/prompttester/prompts/sentence_to_question.tmpl: -------------------------------------------------------------------------------- 1 | 2 | 3 | Convert statement to a question. 4 | Statement: {{.Text}} 5 | Question: -------------------------------------------------------------------------------- /examples/prompttester/prompts/simplification_1.tmpl: -------------------------------------------------------------------------------- 1 | 2 | 3 | Split and simplify the following sentence while retaining its full meaning: 4 | {{.Text}} 5 | Simplified version: 6 | -------------------------------------------------------------------------------- /examples/prompttester/prompts/simplification_2.tmpl: -------------------------------------------------------------------------------- 1 | 2 | 3 | {{.Text}}. This sentence is hard to understand. 
A simpler version with equivalent meaning is the following: 4 | -------------------------------------------------------------------------------- /examples/prompttester/prompts/simplification_3.tmpl: -------------------------------------------------------------------------------- 1 | 2 | 3 | {{.Text}} 4 | The above sentence is very complicated. Please provide me a simplified synonymous version consisting of multiple sentences: 5 | -------------------------------------------------------------------------------- /examples/prompttester/prompts/summarization_1.tmpl: -------------------------------------------------------------------------------- 1 | 2 | 3 | {{.Text}} 4 | === 5 | Write a summary of the text above in English : 6 | -------------------------------------------------------------------------------- /examples/prompttester/prompts/summarization_2.tmpl: -------------------------------------------------------------------------------- 1 | 2 | 3 | Article in English: {{.Text}} 4 | Summary in English: 5 | -------------------------------------------------------------------------------- /examples/prompttester/prompts/summarization_3.tmpl: -------------------------------------------------------------------------------- 1 | 2 | 3 | {{.Text}} 4 | How would you rephrase that briefly in English? 5 | -------------------------------------------------------------------------------- /examples/prompttester/prompts/summarization_4.tmpl: -------------------------------------------------------------------------------- 1 | 2 | 3 | {{.Text}} 4 | TL;DR in English: 5 | -------------------------------------------------------------------------------- /examples/prompttester/prompts/summarization_5.tmpl: -------------------------------------------------------------------------------- 1 | 2 | 3 | First, read the English article below. 4 | {{.Text}} 5 | Now, please write a short abstract for it in English. 6 | -------------------------------------------------------------------------------- /examples/prompttester/prompts/summarization_6.tmpl: -------------------------------------------------------------------------------- 1 | 2 | 3 | {{.Text}} 4 | Given the above abstract, write an English article for it. 5 | -------------------------------------------------------------------------------- /examples/prompttester/prompts/summarization_7.tmpl: -------------------------------------------------------------------------------- 1 | 2 | {{.Text}} 3 | I'm interested in that, but I only have a few mins. Can you give me the first 500 characters of an article about that? 
4 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/nlpodyssey/verbaflow 2 | 3 | go 1.20 4 | 5 | require ( 6 | github.com/nlpodyssey/gopickle v0.2.0 7 | github.com/nlpodyssey/gotokenizers v0.2.0 8 | github.com/nlpodyssey/rwkv v0.0.0-20230212203924-6a6eeeabd546 9 | github.com/nlpodyssey/spago v1.0.2-0.20230202124145-3cffe41f485c 10 | github.com/nlpodyssey/spago/embeddings/store/diskstore v0.0.0-20230202124145-3cffe41f485c 11 | github.com/rs/zerolog v1.29.0 12 | github.com/stretchr/testify v1.8.1 13 | github.com/urfave/cli/v2 v2.24.3 14 | google.golang.org/grpc v1.33.2 15 | google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.2.0 16 | google.golang.org/protobuf v1.28.1 17 | ) 18 | 19 | require ( 20 | github.com/cespare/xxhash v1.1.0 // indirect 21 | github.com/cespare/xxhash/v2 v2.2.0 // indirect 22 | github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect 23 | github.com/davecgh/go-spew v1.1.1 // indirect 24 | github.com/dgraph-io/badger/v3 v3.2103.5 // indirect 25 | github.com/dgraph-io/ristretto v0.1.1 // indirect 26 | github.com/dlclark/regexp2 v1.8.0 // indirect 27 | github.com/dustin/go-humanize v1.0.1 // indirect 28 | github.com/gogo/protobuf v1.3.2 // indirect 29 | github.com/golang/glog v1.0.0 // indirect 30 | github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect 31 | github.com/golang/protobuf v1.5.2 // indirect 32 | github.com/golang/snappy v0.0.4 // indirect 33 | github.com/google/flatbuffers v23.1.21+incompatible // indirect 34 | github.com/klauspost/compress v1.15.15 // indirect 35 | github.com/mattn/go-colorable v0.1.13 // indirect 36 | github.com/mattn/go-isatty v0.0.17 // indirect 37 | github.com/pkg/errors v0.9.1 // indirect 38 | github.com/pmezard/go-difflib v1.0.0 // indirect 39 | github.com/russross/blackfriday/v2 v2.1.0 // indirect 40 | github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect 41 | go.opencensus.io v0.24.0 // indirect 42 | golang.org/x/net v0.5.0 // indirect 43 | golang.org/x/sys v0.4.0 // indirect 44 | golang.org/x/text v0.6.0 // indirect 45 | google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013 // indirect 46 | gopkg.in/yaml.v3 v3.0.1 // indirect 47 | ) 48 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= 2 | github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= 3 | github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE= 4 | github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= 5 | github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= 6 | github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= 7 | github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko= 8 | github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= 9 | github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= 10 | github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= 11 | github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= 12 | 
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= 13 | github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= 14 | github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= 15 | github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk= 16 | github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= 17 | github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= 18 | github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE= 19 | github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w= 20 | github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= 21 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 22 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 23 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 24 | github.com/dgraph-io/badger/v3 v3.2103.5 h1:ylPa6qzbjYRQMU6jokoj4wzcaweHylt//CH0AKt0akg= 25 | github.com/dgraph-io/badger/v3 v3.2103.5/go.mod h1:4MPiseMeDQ3FNCYwRbbcBOGJLf5jsE0PPFzRiKjtcdw= 26 | github.com/dgraph-io/ristretto v0.1.1 h1:6CWw5tJNgpegArSHpNHJKldNeq03FQCwYvfMVWajOK8= 27 | github.com/dgraph-io/ristretto v0.1.1/go.mod h1:S1GPSBCYCIhmVNfcth17y2zZtQT6wzkzgwUve0VDWWA= 28 | github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczCTSixgIKmwPv6+wP5DGjqLYw5SUiA= 29 | github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= 30 | github.com/dlclark/regexp2 v1.4.0/go.mod h1:2pZnwuY/m+8K6iRw6wQdMtk+rH5tNGR1i55kozfMjCc= 31 | github.com/dlclark/regexp2 v1.8.0 h1:rJD5HeGIT/2b5CDk63FVCwZA3qgYElfg+oQK7uH5pfE= 32 | github.com/dlclark/regexp2 v1.8.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= 33 | github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= 34 | github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= 35 | github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= 36 | github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= 37 | github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= 38 | github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= 39 | github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= 40 | github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= 41 | github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= 42 | github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= 43 | github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= 44 | github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= 45 | github.com/golang/glog v1.0.0 h1:nfP3RFugxnNRyKgeWd4oI1nYvXpxrx8ck8ZrcizshdQ= 46 | github.com/golang/glog v1.0.0/go.mod h1:EWib/APOK0SL3dFbYqvxE3UYd8E6s1ouQ7iEp/0LWV4= 47 | github.com/golang/groupcache 
v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= 48 | github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= 49 | github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= 50 | github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= 51 | github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= 52 | github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= 53 | github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= 54 | github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= 55 | github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= 56 | github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= 57 | github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= 58 | github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= 59 | github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= 60 | github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= 61 | github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= 62 | github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= 63 | github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= 64 | github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= 65 | github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= 66 | github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= 67 | github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= 68 | github.com/google/flatbuffers v1.12.1/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= 69 | github.com/google/flatbuffers v23.1.21+incompatible h1:bUqzx/MXCDxuS0hRJL2EfjyZL3uQrPbMocUa8zGqsTA= 70 | github.com/google/flatbuffers v23.1.21+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= 71 | github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= 72 | github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= 73 | github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= 74 | github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 75 | github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 76 | github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 77 | github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 78 | github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 79 | github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= 80 | github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 81 | github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= 82 | github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= 83 | 
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= 84 | github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= 85 | github.com/klauspost/compress v1.12.3/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= 86 | github.com/klauspost/compress v1.15.15 h1:EF27CXIuDsYJ6mmvtBRlEuB2UVOqHG1tAXgZ7yIO+lw= 87 | github.com/klauspost/compress v1.15.15/go.mod h1:ZcK2JAFqKOpnBlxcLsJzYfrS9X1akm9fHZNnD9+Vo/4= 88 | github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= 89 | github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= 90 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 91 | github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= 92 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 93 | github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= 94 | github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= 95 | github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= 96 | github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= 97 | github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= 98 | github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= 99 | github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPng= 100 | github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= 101 | github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= 102 | github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= 103 | github.com/nlpodyssey/gopickle v0.2.0 h1:4naD2DVylYJupQLbCQFdwo6yiXEmPyp+0xf5MVlrBDY= 104 | github.com/nlpodyssey/gopickle v0.2.0/go.mod h1:YIUwjJ2O7+vnBsxUN+MHAAI3N+adqEGiw+nDpwW95bY= 105 | github.com/nlpodyssey/gotokenizers v0.2.0 h1:CWx/sp9s35XMO5lT1kNXCshFGDCfPuuWdx/9JiQBsVc= 106 | github.com/nlpodyssey/gotokenizers v0.2.0/go.mod h1:SBLbuSQhpni9M7U+Ie6O46TXYN73T2Cuw/4eeYHYJ+s= 107 | github.com/nlpodyssey/rwkv v0.0.0-20230212203924-6a6eeeabd546 h1:i5zuhz9K0+EOzt9dDn65fK/uP7p/5yjdUxIsfT6IN4k= 108 | github.com/nlpodyssey/rwkv v0.0.0-20230212203924-6a6eeeabd546/go.mod h1:YXDwNBfpWqSma80h4D5qocIeR604gwEFZM2M8MDGnlI= 109 | github.com/nlpodyssey/spago v1.0.2-0.20230202124145-3cffe41f485c h1:T1Cn5J40B8zH6D+fnKRt2NmZOm1/2hsv0vkr7kqPW0Q= 110 | github.com/nlpodyssey/spago v1.0.2-0.20230202124145-3cffe41f485c/go.mod h1:myGGtjwdlurAzacv7iXHY7KVfWii6MaHZo/eCDd6iE4= 111 | github.com/nlpodyssey/spago/embeddings/store/diskstore v0.0.0-20230202124145-3cffe41f485c h1:IB7LjP9N7kE50gO3WjlfZDVGwwyDiTopC1YKpWWpwGI= 112 | github.com/nlpodyssey/spago/embeddings/store/diskstore v0.0.0-20230202124145-3cffe41f485c/go.mod h1:ElJ8F8dMxzuInbNMapI7m625LjE8OdIfGfmU/KwxPzM= 113 | github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= 114 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 115 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 116 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 117 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 118 | github.com/prometheus/client_model 
v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= 119 | github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= 120 | github.com/rs/zerolog v1.29.0 h1:Zes4hju04hjbvkVkOhdl2HpZa+0PmVwigmo8XoORE5w= 121 | github.com/rs/zerolog v1.29.0/go.mod h1:NILgTygv/Uej1ra5XxGf82ZFSLk58MFGAUS2o6usyD0= 122 | github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= 123 | github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= 124 | github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= 125 | github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= 126 | github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI= 127 | github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= 128 | github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= 129 | github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= 130 | github.com/spf13/cobra v0.0.5/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tLCHU= 131 | github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo= 132 | github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= 133 | github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s= 134 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 135 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 136 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 137 | github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= 138 | github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= 139 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 140 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 141 | github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= 142 | github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= 143 | github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= 144 | github.com/urfave/cli/v2 v2.24.3 h1:7Q1w8VN8yE0MJEHP06bv89PjYsN4IHWED2s1v/Zlfm0= 145 | github.com/urfave/cli/v2 v2.24.3/go.mod h1:GHupkWPMM0M/sj1a2b4wUrWBPzazNrIjouW6fmdJLxc= 146 | github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= 147 | github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU= 148 | github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsrSj3ccvlPHLoLsHnpR27oXr4ZE984MbSER8= 149 | github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= 150 | github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= 151 | go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= 152 | go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= 153 | go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= 154 | golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod 
h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= 155 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 156 | golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= 157 | golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= 158 | golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= 159 | golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= 160 | golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= 161 | golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= 162 | golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= 163 | golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= 164 | golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= 165 | golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= 166 | golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= 167 | golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 168 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 169 | golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 170 | golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 171 | golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= 172 | golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= 173 | golang.org/x/net v0.5.0 h1:GyT4nK/YDHSqa1c4753ouYCDajOYKTja9Xb/OHtgvSw= 174 | golang.org/x/net v0.5.0/go.mod h1:DivGGAXEgPSlEBzxGzZI+ZLohi+xUj054jfeKui00ws= 175 | golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= 176 | golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 177 | golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 178 | golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 179 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 180 | golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 181 | golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 182 | golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 183 | golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 184 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 185 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 186 | golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod 
h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 187 | golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 188 | golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 189 | golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 190 | golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 191 | golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 192 | golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= 193 | golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 194 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 195 | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 196 | golang.org/x/text v0.6.0 h1:3XmdazWV+ubf7QgHSTWeykHOci5oeekaGJBLkrkaw4k= 197 | golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= 198 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 199 | golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 200 | golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= 201 | golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= 202 | golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= 203 | golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= 204 | golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= 205 | golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= 206 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 207 | golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 208 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 209 | golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= 210 | golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 211 | google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= 212 | google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= 213 | google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= 214 | google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= 215 | google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= 216 | google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013 h1:+kGHl1aib/qcwaRi1CbqBZ1rk19r85MNUf8HaBghugY= 217 | google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= 218 | google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= 219 | google.golang.org/grpc v1.20.1/go.mod 
h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= 220 | google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= 221 | google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= 222 | google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= 223 | google.golang.org/grpc v1.33.2 h1:EQyQC3sa8M+p6Ulc8yy9SWSS2GVwyRc83gAbG8lrl4o= 224 | google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= 225 | google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.2.0 h1:TLkBREm4nIsEcexnCjgQd5GQWaHcqMzwQV0TX9pq8S0= 226 | google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.2.0/go.mod h1:DNq5QpG7LJqD2AamLZ7zvKE0DEpVl2BSEVjFycAAjRY= 227 | google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= 228 | google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= 229 | google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= 230 | google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= 231 | google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= 232 | google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= 233 | google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= 234 | google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= 235 | google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= 236 | google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= 237 | google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= 238 | google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= 239 | google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= 240 | google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= 241 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 242 | gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= 243 | gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 244 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 245 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 246 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 247 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 248 | honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= 249 | honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= 250 | -------------------------------------------------------------------------------- /go.work: -------------------------------------------------------------------------------- 1 | go 1.20 2 | 3 | use ( 4 | ./examples/prompttester 5 | . 
6 | ) -------------------------------------------------------------------------------- /go.work.sum: -------------------------------------------------------------------------------- 1 | cloud.google.com/go v0.26.0 h1:e0WKqKTd5BnrG8aKH3J3h+QvEIQtSUcf2n5UZ5ZgLtQ= 2 | github.com/BurntSushi/toml v1.2.1 h1:9F2/+DoOYIOksmaJFPw1tGFy1eDnIJXg+UHjuD8lTak= 3 | github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6 h1:G1bPvciwNyF7IUmKXNt9Ak3m6u9DE1rF+RmtIkBpVdA= 4 | github.com/census-instrumentation/opencensus-proto v0.2.1 h1:glEXhBS5PSLLv4IXzLA5yPRVX4bilULVyxxbrfOtDAk= 5 | github.com/charmbracelet/harmonica v0.2.0 h1:8NxJWRWg/bzKqqEaaeFNipOu77YR5t8aSwG4pgaUBiQ= 6 | github.com/client9/misspell v0.3.4 h1:ta993UF76GwbvJcIo3Y68y/M3WxlpEHPWIGDkJYwzJI= 7 | github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f h1:WBZRG4aNOuI15bLRrCgN8fCq8E5Xuty6jGbmSNEvSsU= 8 | github.com/coreos/etcd v3.3.10+incompatible h1:jFneRYjIvLMLhDLCzuTuU4rSJUjRplcJQ7pD7MnhC04= 9 | github.com/coreos/go-etcd v2.0.0+incompatible h1:bXhRBIXoTm9BYHS3gE0TtQuyNZyeEMux2sDi4oo5YOo= 10 | github.com/coreos/go-semver v0.2.0 h1:3Jm3tLmsgAYcjC+4Up7hJrFBPr+n7rAqYeSw/SZazuY= 11 | github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534 h1:rtAn27wIbmOGUs7RIbVgPEjb31ehTVniDwPGXyMxm5U= 12 | github.com/cpuguy83/go-md2man v1.0.10 h1:BSKMNlYxDvnunlTymqtgONjNnaRV1sTpcovwwjF22jk= 13 | github.com/envoyproxy/go-control-plane v0.9.4 h1:rEvIZUSZ3fx39WIi3JkQqQBitGwpELBIYWeBVh6wn+E= 14 | github.com/envoyproxy/protoc-gen-validate v0.1.0 h1:EQciDnbrYxy13PgWoY8AqoxGiPrpgBZ1R8UNe3ddc+A= 15 | github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= 16 | github.com/godbus/dbus/v5 v5.0.4 h1:9349emZab16e7zQvpmsbtjc18ykshndd8y2PG3sgJbA= 17 | github.com/golang/mock v1.1.1 h1:G5FRp8JnTd7RQH5kemVNlMeyXQAztQ3mOWV95KxsXH8= 18 | github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y= 19 | github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= 20 | github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM= 21 | github.com/kisielk/errcheck v1.5.0 h1:e8esj/e4R+SAOwFwN+n3zr0nYeCyeweozKfO23MvHzY= 22 | github.com/kisielk/gotool v1.0.0 h1:AV2c/EiW3KqPNT9ZKl07ehoAGi4C5/01Cfbblndcapg= 23 | github.com/kr/pty v1.1.1 h1:VkoXIwSboBpnk99O/KFauAEILuNHv5DVFKZMBN/gUgw= 24 | github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= 25 | github.com/magiconair/properties v1.8.0 h1:LLgXmsheXeRoUOBOjtwPQCWIYqM/LU1ayDtDePerRcY= 26 | github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= 27 | github.com/mitchellh/mapstructure v1.1.2 h1:fmNYVwqnSfB9mZU6OS2O6GsXM+wcskZDuKQzvN1EDeE= 28 | github.com/pelletier/go-toml v1.2.0 h1:T5zMGML61Wp+FlcbWjRDT7yAxhJNAiPPLOFECq181zc= 29 | github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4 h1:gQz4mCbXsO+nc9n1hCxHcGA3Zx3Eo+UHZoInFGUIXNM= 30 | github.com/rs/xid v1.4.0 h1:qd7wPTDkN6KQx2VmMBLrpHkiyQwgFXRnkOLacUiaSNY= 31 | github.com/russross/blackfriday v1.5.2 h1:HyvC0ARfnZBqnXwABFeSZHpKvJHJJfPz81GNueLj0oo= 32 | github.com/sahilm/fuzzy v0.1.0 h1:FzWGaw2Opqyu+794ZQ9SYifWv2EIXpwP4q8dY1kDAwI= 33 | github.com/spf13/afero v1.1.2 h1:m8/z1t7/fwjysjQRYbP0RD+bUIF/8tJwPdEZsI83ACI= 34 | github.com/spf13/cast v1.3.0 h1:oget//CVOEoFewqQxwr0Ej5yjygnqGkvggSE/gB35Q8= 35 | github.com/spf13/cobra v0.0.5 h1:f0B+LkLX6DtmRH1isoNA9VTtNUK9K8xYd28JNNfOv/s= 36 | github.com/spf13/jwalterweatherman v1.0.0 h1:XHEdyB+EcvlqZamSM4ZOMGlc93t6AcsBEu9Gc1vn7yk= 37 | github.com/spf13/pflag v1.0.3 
h1:zPAT6CGy6wXeQ7NtTnaTerfKOsV6V6F8agHXFiazDkg= 38 | github.com/spf13/viper v1.3.2 h1:VUFqw5KcqRf7i70GOzW7N+Q7+gxVBkSSqiXB12+JQ4M= 39 | github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= 40 | github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8 h1:3SVOIvH7Ae1KRYyQWRjXWJEA9sS/c/pjvH++55Gr648= 41 | github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77 h1:ESFSdwYZvkeru3RtdrYueztKhOBCSAAzS4Gf+k0tEow= 42 | github.com/yuin/goldmark v1.2.1 h1:ruQGxdhGHe7FWOJPT0mKs5+pD2Xs1Bm/kdGlHO04FmM= 43 | golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI= 44 | golang.org/x/exp v0.0.0-20190121172915-509febef88a4 h1:c2HOrn5iMezYjSlGPncknSEr/8x5LELb/ilJbXi9DEA= 45 | golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3 h1:XQyxROzUlZH+WIQwySDgnISgOivlhjIEwaQaJEJrrN0= 46 | golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s= 47 | golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be h1:vEDujvNQGv4jgYKudGeI/+DAX4Jffq6hpD55MmoEvKs= 48 | golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= 49 | google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508= 50 | gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= 51 | honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc h1:/hemPrYIhOhy8zYrNj+069zDB68us2sMGsfkFJO0iZs= 52 | -------------------------------------------------------------------------------- /prompt.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 NLP Odyssey Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package verbaflow 6 | 7 | import ( 8 | "bytes" 9 | "fmt" 10 | "text/template" 11 | ) 12 | 13 | // InputPrompt is the input for the prompt generation. 14 | type InputPrompt struct { 15 | Text string `json:"text"` 16 | Question string `json:"question,omitempty"` 17 | TargetLanguage string `json:"target_language,omitempty"` 18 | } 19 | 20 | // BuildPromptFromTemplateFile builds a prompt applying the given input to the template file. 21 | func BuildPromptFromTemplateFile(input InputPrompt, filename string) (string, error) { 22 | pt, err := template.ParseFiles(filename) 23 | if err != nil { 24 | return "", fmt.Errorf("unable to read the template file: %w", err) 25 | } 26 | return BuildPromptFromTemplate(input, pt) 27 | } 28 | 29 | // BuildPromptFromTemplate builds a prompt applying the given input to the template. 30 | func BuildPromptFromTemplate(input InputPrompt, pt *template.Template) (string, error) { 31 | result := new(bytes.Buffer) 32 | err := pt.Execute(result, input) 33 | if err != nil { 34 | return "", fmt.Errorf("unable to execute the template: %w", err) 35 | } 36 | return result.String(), nil 37 | } 38 | -------------------------------------------------------------------------------- /rwkvlm/converter.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 NLP Odyssey Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
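//
// A minimal usage sketch, assuming a hypothetical model directory layout
// (the directory name below is illustrative, not a path shipped with the project):
//
//	err := rwkvlm.ConvertPickledModelToRWKVLM[float32](rwkvlm.ConverterConfig{
//		ModelDir:         "models/RWKV-4-Pile-169M", // hypothetical path
//		OverwriteIfExist: false,
//	})
//	if err != nil {
//		log.Fatal().Err(err).Msg("conversion failed")
//	}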
4 | 5 | package rwkvlm 6 | 7 | import ( 8 | "fmt" 9 | "math" 10 | "os" 11 | "path/filepath" 12 | "strconv" 13 | "strings" 14 | 15 | "github.com/nlpodyssey/gopickle/pytorch" 16 | "github.com/nlpodyssey/gopickle/types" 17 | "github.com/nlpodyssey/rwkv" 18 | "github.com/nlpodyssey/spago/embeddings" 19 | "github.com/nlpodyssey/spago/embeddings/store" 20 | "github.com/nlpodyssey/spago/embeddings/store/diskstore" 21 | "github.com/nlpodyssey/spago/mat" 22 | "github.com/nlpodyssey/spago/mat/float" 23 | "github.com/nlpodyssey/spago/nn" 24 | "github.com/nlpodyssey/spago/nn/normalization/layernorm" 25 | "github.com/rs/zerolog/log" 26 | ) 27 | 28 | const ( 29 | DefaultPyModelFilename = "pytorch_model.pt" 30 | DefaultOutputFilename = "spago_model.bin" 31 | DefaultEmbeddingRepoPath = "embeddings" 32 | 33 | DefaultLayerNormEps = 1e-5 34 | ) 35 | 36 | type ConverterConfig struct { 37 | // The path to the directory where the models will be read from and written to. 38 | ModelDir string 39 | // The path to the input model file (default "pytorch_model.pt") 40 | PyModelFilename string 41 | // The path to the output model file (default "spago_model.bin") 42 | GoModelFilename string 43 | // The path to the embedding repository (default "embeddings") 44 | EmbeddingRepoPath string 45 | // If true, overwrite the model file if it already exists (default "false") 46 | OverwriteIfExist bool 47 | } 48 | 49 | // ConvertPickledModelToRWKVLM converts a PyTorch model to a RWKVLM model. 50 | // It expects a configuration file "config.json" in the same directory as the model file containing the model configuration. 51 | func ConvertPickledModelToRWKVLM[T float.DType](config ConverterConfig) error { 52 | if config.PyModelFilename == "" { 53 | config.PyModelFilename = DefaultPyModelFilename 54 | } 55 | if config.GoModelFilename == "" { 56 | config.GoModelFilename = DefaultOutputFilename 57 | } 58 | if config.EmbeddingRepoPath == "" { 59 | config.EmbeddingRepoPath = DefaultEmbeddingRepoPath 60 | } 61 | 62 | outputFilename := filepath.Join(config.ModelDir, config.GoModelFilename) 63 | 64 | if !config.OverwriteIfExist && fileExists(outputFilename) { 65 | log.Debug().Str("model", outputFilename).Msg("Model file already exists, skipping conversion") 66 | return nil 67 | } 68 | 69 | configFilename := filepath.Join(config.ModelDir, "config.json") 70 | modelConfig, err := LoadConfig(configFilename) 71 | if err != nil { 72 | return fmt.Errorf("failed to load config file %q: %w", configFilename, err) 73 | } 74 | 75 | inFilename := filepath.Join(config.ModelDir, config.PyModelFilename) 76 | embRepoPath := filepath.Join(config.ModelDir, config.EmbeddingRepoPath) 77 | conv := newConverter[T](modelConfig, inFilename, outputFilename, embRepoPath) 78 | err = conv.run() 79 | if err != nil { 80 | return fmt.Errorf("model conversion failed: %w", err) 81 | } 82 | return nil 83 | } 84 | 85 | func fileExists(name string) bool { 86 | info, err := os.Stat(name) 87 | return err == nil && !info.IsDir() 88 | } 89 | 90 | type converter[T float.DType] struct { 91 | model *Model 92 | inFilename string 93 | outFilename string 94 | embRepoPath string 95 | params paramsMap 96 | } 97 | 98 | func newConverter[T float.DType](conf Config, inFilename, outFilename, embRepoPath string) *converter[T] { 99 | return &converter[T]{ 100 | model: &Model{Config: conf}, 101 | inFilename: inFilename, 102 | outFilename: outFilename, 103 | embRepoPath: embRepoPath, 104 | } 105 | } 106 | 107 | func (c *converter[T]) run() error { 108 | funcs := []func() error{ 109 | 
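// The conversion pipeline below runs in a fixed order: the pickled torch
// parameters are loaded first, then embeddings, the output head, the root
// layer-norm and the per-layer blocks are converted, and the resulting
// spago model is finally dumped to disk.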
c.loadTorchModelParams, 110 | c.convEmbeddings, 111 | c.convLinear, 112 | c.convRootLayerNorm, 113 | c.convBlocks, 114 | c.dumpModel, 115 | } 116 | for _, fn := range funcs { 117 | if err := fn(); err != nil { 118 | return err 119 | } 120 | } 121 | return nil 122 | } 123 | 124 | func (c *converter[T]) dumpModel() (err error) { 125 | return Dump(c.model, c.outFilename) 126 | } 127 | 128 | func (c *converter[T]) convRootLayerNorm() (err error) { 129 | c.model.LN, err = c.convLayerNorm("ln_out", c.params) 130 | if err != nil { 131 | err = fmt.Errorf("failed to convert layer-norm: %w", err) 132 | } 133 | return 134 | } 135 | 136 | func (c *converter[T]) convEmbeddings() error { 137 | embWeight, err := c.params.fetch("emb.weight") 138 | if err != nil { 139 | return err 140 | } 141 | 142 | vecs, err := c.tensorToVectors(embWeight) 143 | if err != nil { 144 | return fmt.Errorf("failed to convert embeddings: %w", err) 145 | } 146 | 147 | if vs := c.model.Config.VocabSize; vs == 0 { 148 | c.model.Config.VocabSize = len(vecs) 149 | } else if len(vecs) != vs { 150 | return fmt.Errorf("expected embedding vectors to match vocabulary size %d, actual %d", vs, len(vecs)) 151 | } 152 | 153 | if dm := c.model.Config.DModel; dm == 0 { 154 | c.model.Config.DModel = vecs[0].Size() 155 | } else if dm != vecs[0].Size() { 156 | return fmt.Errorf("expected embedding vectors to match configured size %d, actual %d", dm, vecs[0].Size()) 157 | } 158 | 159 | return c.withEmbRepo(func(repo store.Repository) { 160 | embs := c.newEmbeddings(repo) 161 | for i, vec := range vecs { 162 | embs.Tokens.EmbeddingFast(i).ReplaceValue(vec) 163 | } 164 | c.model.Embeddings = embs 165 | }) 166 | } 167 | 168 | func (c *converter[T]) newEmbeddings(repo store.Repository) *Embeddings { 169 | return NewEmbeddings[T](embeddings.Config{ 170 | Size: c.model.Config.DModel, 171 | StoreName: c.model.Config.EmbeddingsStoreName, 172 | Trainable: false, 173 | }, repo) 174 | } 175 | 176 | func (c *converter[T]) withEmbRepo(fn func(store.Repository)) (err error) { 177 | repo, err := diskstore.NewRepository(c.embRepoPath, diskstore.ReadWriteMode) 178 | if err != nil { 179 | return fmt.Errorf("failed to open embedding repository: %w", err) 180 | } 181 | defer func() { 182 | if e := repo.Close(); e != nil && err == nil { 183 | err = fmt.Errorf("failed to close embedding repository: %w", e) 184 | } 185 | }() 186 | if err = repo.DropAll(); err != nil { 187 | err = fmt.Errorf("failed to drop embedding repository data: %w", err) 188 | } 189 | fn(repo) 190 | return nil 191 | } 192 | 193 | func (c *converter[T]) convLinear() error { 194 | headWeight, err := c.params.fetch("head.weight") 195 | if err != nil { 196 | return err 197 | } 198 | 199 | m, err := c.tensorToMatrix(headWeight) 200 | if err != nil { 201 | return fmt.Errorf("failed to convert head-weight/linear: %w", err) 202 | } 203 | 204 | if vs := c.model.Config.VocabSize; m.Rows() != vs { 205 | return fmt.Errorf("expected head-weight/linear rows to match vocabulary size %d, actual %d", vs, m.Rows()) 206 | } 207 | if dm := c.model.Config.DModel; m.Columns() != dm { 208 | return fmt.Errorf("expected head-weight/linear columns to match DModel %d, actual %d", dm, m.Columns()) 209 | } 210 | 211 | c.model.Linear = nn.NewParam(m) 212 | return nil 213 | } 214 | 215 | func (c *converter[T]) convBlocks() error { 216 | allBlocksParams := c.params.fetchPrefixed("blocks.") 217 | numBlocks, err := countBlocks(allBlocksParams) 218 | if err != nil { 219 | return err 220 | } 221 | if numBlocks == 0 { 222 | return 
fmt.Errorf("no blocks/layers found in parameters") 223 | } 224 | if hl := c.model.Config.NumHiddenLayers; hl == 0 { 225 | c.model.Config.NumHiddenLayers = numBlocks 226 | } else if hl != numBlocks { 227 | return fmt.Errorf("expected %d blocks/layers, actual %d", hl, numBlocks) 228 | } 229 | 230 | conf := rwkv.Config{ 231 | DModel: c.model.Config.DModel, 232 | NumLayers: c.model.Config.NumHiddenLayers, 233 | RescaleLayer: c.model.Config.RescaleLayer, 234 | } 235 | 236 | layers := make([]*rwkv.Layer, numBlocks) 237 | for i := range layers { 238 | blockParams := allBlocksParams.fetchPrefixed(fmt.Sprintf("%d.", i)) 239 | layers[i], err = c.convBlock(i, conf, blockParams) 240 | if err != nil { 241 | return fmt.Errorf("failed to convert block/layer %d: %w", i, err) 242 | } 243 | } 244 | 245 | c.model.Encoder = &rwkv.Model{ 246 | Config: conf, 247 | Layers: layers, 248 | } 249 | return nil 250 | } 251 | 252 | func (c *converter[T]) convBlock(id int, conf rwkv.Config, params paramsMap) (_ *rwkv.Layer, err error) { 253 | layer := &rwkv.Layer{ 254 | ID: id, 255 | } 256 | 257 | layer.ChanMix, err = c.convChanMix(id, params.fetchPrefixed("ffn.")) 258 | if err != nil { 259 | return nil, fmt.Errorf("failed to convert ffn/channel-mix: %w", err) 260 | } 261 | 262 | layer.TimeMix, err = c.convTimeMix(id, conf, params.fetchPrefixed("att.")) 263 | if err != nil { 264 | return nil, fmt.Errorf("failed to convert att/time-mix: %w", err) 265 | } 266 | 267 | if id == 0 { 268 | layer.LN0, err = c.convLayerNorm("ln0", params) 269 | if err != nil { 270 | return nil, fmt.Errorf("failed to convert layer-norm 0: %w", err) 271 | } 272 | } 273 | 274 | layer.LN1, err = c.convLayerNorm("ln1", params) 275 | if err != nil { 276 | return nil, fmt.Errorf("failed to convert layer-norm 1: %w", err) 277 | } 278 | 279 | layer.LN2, err = c.convLayerNorm("ln2", params) 280 | if err != nil { 281 | return nil, fmt.Errorf("failed to convert layer-norm 2: %w", err) 282 | } 283 | 284 | return layer, nil 285 | } 286 | 287 | func (c *converter[T]) convChanMix(id int, params paramsMap) (*rwkv.ChannelMix, error) { 288 | dm := c.model.Config.DModel 289 | outScale := math.Pow(2, float64(id/c.model.Config.RescaleLayer)) 290 | 291 | key, err := c.fetchParamToMatrix(params, "key.weight", [2]int{dm * 4, dm}) 292 | if err != nil { 293 | return nil, fmt.Errorf("failed to convert key weight: %w", err) 294 | } 295 | 296 | receptance, err := c.fetchParamToMatrix(params, "receptance.weight", [2]int{dm, dm}) 297 | if err != nil { 298 | return nil, fmt.Errorf("failed to convert receptance weight: %w", err) 299 | } 300 | 301 | value, err := c.fetchParamToMatrix(params, "value.weight", [2]int{dm, dm * 4}) 302 | if err != nil { 303 | return nil, fmt.Errorf("failed to convert value weight: %w", err) 304 | } 305 | if outScale != 1 { 306 | value.ProdScalarInPlace(1 / outScale) 307 | } 308 | 309 | tmk, err := c.fetchParamToSqueezedVector(params, "time_mix_k", dm) 310 | if err != nil { 311 | return nil, fmt.Errorf("failed to convert time-mix-k: %w", err) 312 | } 313 | 314 | tmr, err := c.fetchParamToSqueezedVector(params, "time_mix_r", dm) 315 | if err != nil { 316 | return nil, fmt.Errorf("failed to convert time-mix-r: %w", err) 317 | } 318 | 319 | return &rwkv.ChannelMix{ 320 | Key: nn.NewParam(key), 321 | Value: nn.NewParam(value), 322 | Receptance: nn.NewParam(receptance), 323 | TimeMixK: nn.NewParam(tmk), 324 | TimeMixR: nn.NewParam(tmr), 325 | }, nil 326 | } 327 | 328 | func (c *converter[T]) convTimeMix(id int, conf rwkv.Config, params paramsMap) 
(*rwkv.TimeMix, error) { 329 | dm := c.model.Config.DModel 330 | outScale := math.Pow(2, float64(id/c.model.Config.RescaleLayer)) 331 | 332 | key, err := c.fetchParamToMatrix(params, "key.weight", [2]int{dm, dm}) 333 | if err != nil { 334 | return nil, fmt.Errorf("failed to convert key weight: %w", err) 335 | } 336 | 337 | receptance, err := c.fetchParamToMatrix(params, "receptance.weight", [2]int{dm, dm}) 338 | if err != nil { 339 | return nil, fmt.Errorf("failed to convert receptance weight: %w", err) 340 | } 341 | 342 | output, err := c.fetchParamToMatrix(params, "output.weight", [2]int{dm, dm}) 343 | if err != nil { 344 | return nil, fmt.Errorf("failed to convert output weight: %w", err) 345 | } 346 | if outScale != 1 { 347 | output.ProdScalarInPlace(1 / outScale) 348 | } 349 | 350 | value, err := c.fetchParamToMatrix(params, "value.weight", [2]int{dm, dm}) 351 | if err != nil { 352 | return nil, fmt.Errorf("failed to convert value weight: %w", err) 353 | } 354 | 355 | tDecay, err := c.fetchParamToSqueezedVector(params, "time_decay", dm) 356 | if err != nil { 357 | return nil, fmt.Errorf("failed to convert time-decay: %w", err) 358 | } 359 | tDecay = tDecay.Exp().ProdScalarInPlace(-1) 360 | 361 | tFirst, err := c.fetchParamToSqueezedVector(params, "time_first", dm) 362 | if err != nil { 363 | return nil, fmt.Errorf("failed to convert time-first: %w", err) 364 | } 365 | 366 | tmk, err := c.fetchParamToSqueezedVector(params, "time_mix_k", dm) 367 | if err != nil { 368 | return nil, fmt.Errorf("failed to convert time-mix-k: %w", err) 369 | } 370 | 371 | tmr, err := c.fetchParamToSqueezedVector(params, "time_mix_r", dm) 372 | if err != nil { 373 | return nil, fmt.Errorf("failed to convert time-mix-r: %w", err) 374 | } 375 | 376 | tmv, err := c.fetchParamToSqueezedVector(params, "time_mix_v", dm) 377 | if err != nil { 378 | return nil, fmt.Errorf("failed to convert time-mix-v: %w", err) 379 | } 380 | 381 | return &rwkv.TimeMix{ 382 | Config: conf, 383 | Key: nn.NewParam(key), 384 | Value: nn.NewParam(value), 385 | Receptance: nn.NewParam(receptance), 386 | Output: nn.NewParam(output), 387 | TimeDecay: nn.NewParam(tDecay), 388 | TimeFirst: nn.NewParam(tFirst), 389 | TimeMixK: nn.NewParam(tmk), 390 | TimeMixV: nn.NewParam(tmv), 391 | TimeMixR: nn.NewParam(tmr), 392 | }, nil 393 | } 394 | 395 | func (c *converter[T]) convLayerNorm(name string, params paramsMap) (*layernorm.Model, error) { 396 | dm := c.model.Config.DModel 397 | 398 | w, err := c.fetchParamToVector(params, name+".weight", dm) 399 | if err != nil { 400 | return nil, fmt.Errorf("failed to convert layer-norm weight: %w", err) 401 | } 402 | 403 | b, err := c.fetchParamToVector(params, name+".bias", dm) 404 | if err != nil { 405 | return nil, fmt.Errorf("failed to convert layer-norm bias: %w", err) 406 | } 407 | 408 | return &layernorm.Model{ 409 | W: nn.NewParam(w), 410 | B: nn.NewParam(b), 411 | Eps: nn.Const[T](DefaultLayerNormEps), 412 | }, nil 413 | } 414 | 415 | func (c *converter[T]) loadTorchModelParams() error { 416 | torchModel, err := pytorch.Load(c.inFilename) 417 | if err != nil { 418 | return fmt.Errorf("failed to load torch model %q: %w", c.inFilename, err) 419 | } 420 | c.params, err = makeParamsMap(torchModel) 421 | if err != nil { 422 | return fmt.Errorf("failed to read model params: %w", err) 423 | } 424 | return nil 425 | } 426 | 427 | func (c *converter[T]) tensorToVectors(t *pytorch.Tensor) ([]mat.Matrix, error) { 428 | if len(t.Size) != 2 { 429 | return nil, fmt.Errorf("expected 2 dimensions, actual %d", 
len(t.Size)) 430 | } 431 | 432 | data, err := c.tensorData(t) 433 | if err != nil { 434 | return nil, err 435 | } 436 | 437 | rows := t.Size[0] 438 | cols := t.Size[1] 439 | 440 | vecs := make([]mat.Matrix, rows) 441 | for i := range vecs { 442 | d := data[i*cols : (i*cols)+cols] 443 | vecs[i] = mat.NewVecDense[T](c.castMatrixData(d)) 444 | } 445 | 446 | return vecs, nil 447 | } 448 | 449 | func (c *converter[T]) tensorToMatrix(t *pytorch.Tensor) (mat.Matrix, error) { 450 | if len(t.Size) != 2 { 451 | return nil, fmt.Errorf("expected 2 dimensions, actual %d", len(t.Size)) 452 | } 453 | 454 | data, err := c.tensorData(t) 455 | if err != nil { 456 | return nil, err 457 | } 458 | 459 | return mat.NewDense[T](t.Size[0], t.Size[1], c.castMatrixData(data)), nil 460 | } 461 | 462 | func (c *converter[T]) tensorToVector(t *pytorch.Tensor) (mat.Matrix, error) { 463 | if len(t.Size) != 1 { 464 | return nil, fmt.Errorf("expected 1 dimension, actual %d", len(t.Size)) 465 | } 466 | 467 | data, err := c.tensorData(t) 468 | if err != nil { 469 | return nil, err 470 | } 471 | 472 | return mat.NewVecDense[T](c.castMatrixData(data)), nil 473 | } 474 | 475 | func (c *converter[T]) tensorToSqueezedVector(t *pytorch.Tensor) (mat.Matrix, error) { 476 | data, err := c.tensorData(t) 477 | if err != nil { 478 | return nil, err 479 | } 480 | return mat.NewVecDense[T](c.castMatrixData(data)), nil 481 | } 482 | 483 | func (c *converter[T]) castMatrixData(d []float32) []T { 484 | return float.SliceValueOf[T](float.SliceInterface(d)) 485 | } 486 | 487 | func (c *converter[T]) tensorData(t *pytorch.Tensor) ([]float32, error) { 488 | st, ok := t.Source.(*pytorch.BFloat16Storage) 489 | if !ok { 490 | return nil, fmt.Errorf("only BFloat16Storage is supported, actual %T", t.Source) 491 | } 492 | size := tensorDataSize(t) 493 | return st.Data[t.StorageOffset : t.StorageOffset+size], nil 494 | } 495 | 496 | func (c *converter[T]) fetchParamToVector(params paramsMap, name string, expectedSize int) (mat.Matrix, error) { 497 | t, err := params.fetch(name) 498 | if err != nil { 499 | return nil, err 500 | } 501 | v, err := c.tensorToVector(t) 502 | if err != nil { 503 | return nil, err 504 | } 505 | if v.Size() != expectedSize { 506 | return nil, fmt.Errorf("expected vector size %d, actual %d", expectedSize, v.Size()) 507 | } 508 | return v, nil 509 | } 510 | 511 | func (c *converter[T]) fetchParamToSqueezedVector(params paramsMap, name string, expectedSize int) (mat.Matrix, error) { 512 | t, err := params.fetch(name) 513 | if err != nil { 514 | return nil, err 515 | } 516 | v, err := c.tensorToSqueezedVector(t) 517 | if err != nil { 518 | return nil, err 519 | } 520 | if v.Size() != expectedSize { 521 | return nil, fmt.Errorf("expected squeezed vector size %d, actual %d", expectedSize, v.Size()) 522 | } 523 | return v, nil 524 | } 525 | 526 | func (c *converter[T]) fetchParamToMatrix(params paramsMap, name string, expectedSize [2]int) (mat.Matrix, error) { 527 | t, err := params.fetch(name) 528 | if err != nil { 529 | return nil, err 530 | } 531 | m, err := c.tensorToMatrix(t) 532 | if err != nil { 533 | return nil, err 534 | } 535 | if m.Rows() != expectedSize[0] || m.Columns() != expectedSize[1] { 536 | return nil, fmt.Errorf("expected matrix size %dx%d, actual %dx%d", 537 | expectedSize[0], expectedSize[1], m.Rows(), m.Columns()) 538 | } 539 | return m, nil 540 | } 541 | 542 | func countBlocks(params paramsMap) (int, error) { 543 | max := 0 544 | for k := range params { 545 | before, _, ok := strings.Cut(k, ".") 546 | if !ok 
{ 547 | return 0, fmt.Errorf("block/layer parameter names expected to start with number, actual name %q", k) 548 | } 549 | num, err := strconv.Atoi(before) 550 | if err != nil { 551 | return 0, fmt.Errorf("block/layer parameter names expected to start with number, actual name %q: %w", k, err) 552 | } 553 | if num > max { 554 | max = num 555 | } 556 | } 557 | return max + 1, nil 558 | } 559 | 560 | func tensorDataSize(t *pytorch.Tensor) int { 561 | size := t.Size[0] 562 | for _, v := range t.Size[1:] { 563 | size *= v 564 | } 565 | return size 566 | } 567 | 568 | func cast[T any](v any) (t T, _ error) { 569 | t, ok := v.(T) 570 | if !ok { 571 | return t, fmt.Errorf("type assertion failed: expected %T, actual %T", t, v) 572 | } 573 | return 574 | } 575 | 576 | type paramsMap map[string]*pytorch.Tensor 577 | 578 | func makeParamsMap(torchModel any) (paramsMap, error) { 579 | od, err := cast[*types.OrderedDict](torchModel) 580 | if err != nil { 581 | return nil, err 582 | } 583 | 584 | params := make(paramsMap, od.Len()) 585 | 586 | for k, item := range od.Map { 587 | name, err := cast[string](k) 588 | if err != nil { 589 | return nil, fmt.Errorf("wrong param name type: %w", err) 590 | } 591 | tensor, err := cast[*pytorch.Tensor](item.Value) 592 | if err != nil { 593 | return nil, fmt.Errorf("wrong value type for param %q: %w", name, err) 594 | } 595 | params[name] = tensor 596 | } 597 | 598 | return params, nil 599 | } 600 | 601 | // fetchParam gets a value from params by its name, removing the entry 602 | // from the map. 603 | func (p paramsMap) fetch(name string) (*pytorch.Tensor, error) { 604 | t, ok := p[name] 605 | if !ok { 606 | return nil, fmt.Errorf("parameter %q not found", name) 607 | } 608 | delete(p, name) 609 | return t, nil 610 | } 611 | 612 | func (p paramsMap) fetchPrefixed(prefix string) paramsMap { 613 | out := make(paramsMap, len(p)) 614 | for k, v := range p { 615 | if after, ok := strings.CutPrefix(k, prefix); ok { 616 | out[after] = v 617 | delete(p, k) 618 | } 619 | } 620 | return out 621 | } 622 | -------------------------------------------------------------------------------- /rwkvlm/embeddings.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 NLP Odyssey Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package rwkvlm 6 | 7 | import ( 8 | "encoding/gob" 9 | 10 | "github.com/nlpodyssey/spago/ag" 11 | emb "github.com/nlpodyssey/spago/embeddings" 12 | "github.com/nlpodyssey/spago/embeddings/store" 13 | "github.com/nlpodyssey/spago/mat/float" 14 | "github.com/nlpodyssey/spago/nn" 15 | ) 16 | 17 | // Embeddings embeds the token embeddings. 18 | type Embeddings struct { 19 | nn.Module 20 | Tokens *emb.Model[int] 21 | Config Config 22 | } 23 | 24 | func init() { 25 | gob.Register(&Embeddings{}) 26 | } 27 | 28 | // NewEmbeddings returns a new embedding module. 29 | func NewEmbeddings[T float.DType](c emb.Config, repo store.Repository) *Embeddings { 30 | return &Embeddings{ 31 | Tokens: emb.New[T, int](c, repo), 32 | } 33 | } 34 | 35 | // Encode performs the input encoding. 36 | func (m *Embeddings) Encode(tokens []int) []ag.Node { 37 | return m.Tokens.Encode(tokens) 38 | } 39 | -------------------------------------------------------------------------------- /rwkvlm/gob.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 NLP Odyssey Authors. All rights reserved. 
2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package rwkvlm 6 | 7 | import ( 8 | "bufio" 9 | "encoding/gob" 10 | "io" 11 | "os" 12 | 13 | "github.com/nlpodyssey/rwkv" 14 | "github.com/nlpodyssey/spago/nn" 15 | "github.com/nlpodyssey/spago/nn/normalization/layernorm" 16 | ) 17 | 18 | func gobEncode(obj *Model, w io.Writer) error { 19 | bw := bufio.NewWriter(w) 20 | encoder := gob.NewEncoder(bw) 21 | 22 | for _, chunk := range getChunksForGobEncoding(obj) { 23 | if err := encoder.Encode(chunk); err != nil { 24 | return err 25 | } 26 | if err := bw.Flush(); err != nil { 27 | return err 28 | } 29 | } 30 | return nil 31 | } 32 | 33 | func getChunksForGobEncoding(obj *Model) []interface{} { 34 | chunks := []interface{}{ 35 | obj.Config, 36 | obj.Embeddings, 37 | obj.LN, 38 | obj.Linear.(*nn.BaseParam), 39 | obj.Encoder.Config, 40 | } 41 | for _, layer := range obj.Encoder.Layers { 42 | chunks = append(chunks, layer) 43 | } 44 | return chunks 45 | } 46 | 47 | // loadFromFile uses Gob to deserialize objects files to memory. 48 | // See gobDecoding for further details. 49 | func loadFromFile(filename string) (*Model, error) { 50 | f, err := os.Open(filename) 51 | if err != nil { 52 | return nil, err 53 | } 54 | defer func() { 55 | if e := f.Close(); e != nil && err == nil { 56 | err = e 57 | } 58 | }() 59 | return gobDecoding(f) 60 | } 61 | 62 | func gobDecoding(r io.Reader) (*Model, error) { 63 | obj := &Model{ 64 | LN: &layernorm.Model{}, 65 | Linear: &nn.BaseParam{}, 66 | Encoder: &rwkv.Model{}, 67 | } 68 | 69 | br := bufio.NewReader(r) 70 | decoder := gob.NewDecoder(br) 71 | 72 | w := nn.BaseParam{} 73 | 74 | if err := decoder.Decode(&obj.Config); err != nil { 75 | return nil, err 76 | } 77 | if err := decoder.Decode(&obj.Embeddings); err != nil { 78 | return nil, err 79 | } 80 | if err := decoder.Decode(&obj.LN); err != nil { 81 | return nil, err 82 | } 83 | if err := decoder.Decode(&w); err != nil { 84 | return nil, err 85 | } 86 | obj.Linear = &w 87 | if err := decoder.Decode(&obj.Encoder.Config); err != nil { 88 | return nil, err 89 | } 90 | 91 | obj.Encoder.Layers = make([]*rwkv.Layer, obj.Config.NumHiddenLayers) 92 | for i := range obj.Encoder.Layers { 93 | if err := decoder.Decode(&obj.Encoder.Layers[i]); err != nil { 94 | return nil, err 95 | } 96 | } 97 | 98 | return obj, nil 99 | } 100 | -------------------------------------------------------------------------------- /rwkvlm/rwkvlm.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 NLP Odyssey Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
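//
// A minimal usage sketch for loading a converted model, assuming a hypothetical
// model directory (ReadWriteMode mirrors the mode used by the converter):
//
//	m, err := rwkvlm.Load("models/RWKV-4-Pile-169M") // reads spago_model.bin from the directory
//	if err != nil {
//		log.Fatal().Err(err).Msg("loading failed")
//	}
//	repo, err := diskstore.NewRepository("models/RWKV-4-Pile-169M/embeddings", diskstore.ReadWriteMode)
//	if err != nil {
//		log.Fatal().Err(err).Msg("opening the embeddings store failed")
//	}
//	if err := m.ApplyEmbeddings(repo); err != nil {
//		log.Fatal().Err(err).Msg("binding embeddings failed")
//	}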
4 | 5 | package rwkvlm 6 | 7 | import ( 8 | "context" 9 | "encoding/gob" 10 | "encoding/json" 11 | "fmt" 12 | "os" 13 | "path/filepath" 14 | 15 | "github.com/nlpodyssey/rwkv" 16 | "github.com/nlpodyssey/spago/ag" 17 | "github.com/nlpodyssey/spago/embeddings" 18 | "github.com/nlpodyssey/spago/embeddings/store" 19 | "github.com/nlpodyssey/spago/embeddings/store/diskstore" 20 | "github.com/nlpodyssey/spago/mat" 21 | "github.com/nlpodyssey/spago/mat/float" 22 | "github.com/nlpodyssey/spago/nn" 23 | "github.com/nlpodyssey/spago/nn/normalization/layernorm" 24 | "github.com/rs/zerolog/log" 25 | ) 26 | 27 | type Model struct { 28 | nn.Module 29 | Embeddings *Embeddings 30 | Encoder *rwkv.Model 31 | LN *layernorm.Model 32 | Linear nn.Param `spago:"type:weights"` 33 | Config Config 34 | } 35 | 36 | type Config struct { 37 | // DModel primarily corresponds to the embedding size. 38 | // 39 | // When converting a torch model, it can be left zero, letting the 40 | // process deduce the value automatically. 41 | DModel int `json:"d_model"` 42 | // NumHiddenLayers is the number of hidden layers. 43 | // 44 | // When converting a torch model, it can be left zero, letting the 45 | // process deduce the value automatically. 46 | NumHiddenLayers int `json:"num_hidden_layers"` 47 | // VocabSize is the vocabulary size. 48 | // 49 | // When converting a torch model, it can be left zero, letting the 50 | // process deduce the value automatically. 51 | VocabSize int `json:"vocab_size"` 52 | RescaleLayer int `json:"rescale_layer"` 53 | EmbeddingsStoreName string `json:"embeddings_store_name"` 54 | } 55 | 56 | func LoadConfig(filePath string) (Config, error) { 57 | file, err := os.Open(filePath) 58 | if err != nil { 59 | return Config{}, err 60 | } 61 | defer file.Close() 62 | 63 | var config Config 64 | jsonDecoder := json.NewDecoder(file) 65 | if err := jsonDecoder.Decode(&config); err != nil { 66 | return Config{}, err 67 | } 68 | return config, nil 69 | } 70 | 71 | func init() { 72 | gob.Register(&Model{}) 73 | } 74 | 75 | func New[T float.DType](c Config, repo store.Repository) *Model { 76 | return &Model{ 77 | Config: c, 78 | Encoder: rwkv.New[T](rwkv.Config{ 79 | DModel: c.DModel, 80 | NumLayers: c.NumHiddenLayers, 81 | RescaleLayer: c.RescaleLayer, 82 | }), 83 | LN: layernorm.New[T](c.DModel, 1e-6), 84 | Linear: nn.NewParam(mat.NewEmptyDense[T](c.VocabSize, c.DModel)), 85 | Embeddings: NewEmbeddings[T](embeddings.Config{ 86 | Size: c.DModel, 87 | StoreName: c.EmbeddingsStoreName, 88 | Trainable: false, 89 | }, repo), 90 | } 91 | } 92 | 93 | // Load loads a pre-trained model from the given path. 94 | func Load(dir string) (*Model, error) { 95 | m, err := loadFromFile(filepath.Join(dir, DefaultOutputFilename)) 96 | if err != nil { 97 | return nil, err 98 | } 99 | return m, nil 100 | } 101 | 102 | // Dump saves the Model to a file. 103 | // See gobEncode for further details. 104 | func Dump(obj *Model, filename string) error { 105 | f, err := os.Create(filename) 106 | if err != nil { 107 | return fmt.Errorf("failed to open model dump file %q for writing: %w", filename, err) 108 | } 109 | defer func() { 110 | if e := f.Close(); e != nil && err == nil { 111 | err = fmt.Errorf("failed to close model dump file %q: %w", filename, e) 112 | } 113 | }() 114 | if err = gobEncode(obj, f); err != nil { 115 | return fmt.Errorf("failed to encode model dump: %w", err) 116 | } 117 | return nil 118 | } 119 | 120 | // ApplyEmbeddings sets the embeddings of the model. 
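// It walks the module tree with nn.Apply and re-binds every embeddings.Model it
// finds to the given disk-backed repository, so that token vectors are read from
// that store at inference time.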
121 | func (m *Model) ApplyEmbeddings(repo *diskstore.Repository) (err error) { 122 | nn.Apply(m, func(model nn.Model, name string) { 123 | switch em := model.(type) { 124 | case *embeddings.Model[[]byte], *embeddings.Model[int], *embeddings.Model[string]: 125 | if e := em.(interface { 126 | UseRepository(repo store.Repository) error 127 | }).UseRepository(repo); e != nil && err == nil { 128 | err = e 129 | } 130 | } 131 | }) 132 | return err 133 | } 134 | 135 | // Encode performs EncodeTokens and EncodeEmbeddings. 136 | func (m *Model) Encode(ctx context.Context, s rwkv.State, tokens ...int) (ag.Node, rwkv.State) { 137 | return m.EncodeEmbeddings(ctx, s, m.Embeddings.Encode(tokens)) 138 | } 139 | 140 | // EncodeTokens returns the embeddings of the given tokens. 141 | func (m *Model) EncodeTokens(_ context.Context, tokens ...int) []ag.Node { 142 | return m.Embeddings.Encode(tokens) 143 | } 144 | 145 | // EncodeEmbeddings returns the encoding of the given input considering the last state. 146 | // At least one token is required, otherwise can panic. 147 | // If the input is a sequence, the last state is returned. 148 | func (m *Model) EncodeEmbeddings(_ context.Context, s rwkv.State, xs []ag.Node) (ag.Node, rwkv.State) { 149 | if len(xs) == 1 { 150 | return m.Encoder.ForwardSingle(xs[0], s) 151 | } 152 | 153 | log.Trace().Msgf("Encoding sequence of %d tokens...", len(xs)) 154 | var h []ag.Node 155 | h, s = m.Encoder.ForwardSequence(xs, s) 156 | return h[len(h)-1], s 157 | } 158 | 159 | // Predict returns the prediction logits of the next token. 160 | func (m *Model) Predict(x ag.Node) ag.Node { 161 | return ag.Mul(m.Linear, m.LN.Forward(x)[0]) 162 | } 163 | -------------------------------------------------------------------------------- /service/service.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 NLP Odyssey Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package service 6 | 7 | import ( 8 | "context" 9 | "fmt" 10 | "net" 11 | "time" 12 | 13 | "github.com/nlpodyssey/spago/ag" 14 | "github.com/nlpodyssey/verbaflow" 15 | "github.com/nlpodyssey/verbaflow/api" 16 | "github.com/nlpodyssey/verbaflow/decoder" 17 | "github.com/rs/zerolog/log" 18 | "google.golang.org/grpc" 19 | "google.golang.org/grpc/health" 20 | "google.golang.org/grpc/health/grpc_health_v1" 21 | ) 22 | 23 | type Server struct { 24 | api.UnimplementedLanguageModelServer 25 | vf *verbaflow.VerbaFlow 26 | health *health.Server 27 | grpcServer *grpc.Server 28 | } 29 | 30 | func NewServer(vf *verbaflow.VerbaFlow) *Server { 31 | return &Server{ 32 | vf: vf, 33 | health: health.NewServer(), 34 | grpcServer: grpc.NewServer(), 35 | } 36 | } 37 | 38 | func (s *Server) Start(ctx context.Context, address string) error { 39 | lis, err := net.Listen("tcp", address) 40 | if err != nil { 41 | return fmt.Errorf("failed to listen: %w", err) 42 | } 43 | 44 | grpc_health_v1.RegisterHealthServer(s.grpcServer, s.health) 45 | api.RegisterLanguageModelServer(s.grpcServer, s) 46 | 47 | s.health.SetServingStatus(api.LanguageModel_ServiceDesc.ServiceName, grpc_health_v1.HealthCheckResponse_SERVING) 48 | 49 | go s.shutDownServerWhenContextIsDone(ctx) 50 | return s.grpcServer.Serve(lis) 51 | } 52 | 53 | // shutDownServerWhenContextIsDone shuts down the server when the context is done. 
54 | func (s *Server) shutDownServerWhenContextIsDone(ctx context.Context) {
55 | 	<-ctx.Done()
56 | 	log.Info().Msg("context done, shutting down server")
57 | 	s.health.Shutdown()
58 | 	s.grpcServer.GracefulStop()
59 | 	log.Info().Msg("server shut down successfully")
60 | }
61 | 
62 | // GenerateTokens implements the GenerateTokens method of the LanguageModel service.
63 | func (s *Server) GenerateTokens(req *api.TokenGenerationRequest, stream api.LanguageModel_GenerateTokensServer) error {
64 | 	ctx := stream.Context()
65 | 	log.Debug().Msgf("Received request from %v", ctx.Value("client"))
66 | 
67 | 	opts := grpcToDecodingOptions(req.GetDecodingParameters())
68 | 
69 | 	// chGen is a channel that will receive the generated tokens
70 | 	chGen := make(chan decoder.GeneratedToken, opts.MaxLen)
71 | 	errCh := make(chan error)
72 | 	go func() {
73 | 		// free the computational graph after the generation is finished
74 | 		nt := &ag.NodesTracker{}
75 | 		defer nt.ReleaseNodes()
76 | 
77 | 		log.Trace().Msgf("Decoding...")
78 | 		start := time.Now()
79 | 		errCh <- s.vf.Generate(ctx, nt, req.GetPrompt(), chGen, opts)
80 | 		log.Trace().Msgf("Inference time: %.2f seconds", time.Since(start).Seconds())
81 | 	}()
82 | 
83 | 	checkWriteConditions := func(tokenID int) bool {
84 | 		return !(tokenID == opts.EndTokenID && opts.SkipEndTokenID)
85 | 	}
86 | 
87 | 	for gen := range chGen {
88 | 		if !checkWriteConditions(gen.TokenID) {
89 | 			continue
90 | 		}
91 | 		token, err := s.vf.TokenByID(gen.TokenID)
92 | 		if err != nil {
93 | 			return fmt.Errorf("failed to reconstruct text for token ID %d: %w", gen.TokenID, err)
94 | 		}
95 | 		if err = stream.Send(&api.GeneratedToken{
96 | 			Token: token,
97 | 			Score: float32(gen.SumNegLogProbs),
98 | 		}); err != nil {
99 | 			return err
100 | 		}
101 | 	}
102 | 
103 | 	err := <-errCh
104 | 	if err != nil {
105 | 		return err
106 | 	}
107 | 
108 | 	log.Debug().Msg("Done.")
109 | 	return nil
110 | }
111 | 
112 | func grpcToDecodingOptions(dp *api.DecodingParameters) decoder.DecodingOptions {
113 | 	return decoder.DecodingOptions{
114 | 		MaxLen: int(dp.MaxLen),
115 | 		MinLen: int(dp.MinLen),
116 | 		StopSequencesIDs: nil,
117 | 		EndTokenID: int(dp.EndTokenId),
118 | 		SkipEndTokenID: dp.SkipEndTokenId,
119 | 		Temp: float64(dp.Temperature),
120 | 		TopK: int(dp.TopK),
121 | 		TopP: float64(dp.TopP),
122 | 		UseSampling: dp.UseSampling,
123 | 	}
124 | }
125 | 
--------------------------------------------------------------------------------
/sliceutils/indexedslice.go:
--------------------------------------------------------------------------------
1 | // Copyright 2022 The NLP Odyssey Authors. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 | 
5 | package sliceutils
6 | 
7 | // IndexedSlice allows sorting a slice of Ordered values without losing
8 | // track of the initial (pre-sorting) index of each element.
9 | type IndexedSlice[T Ordered] struct {
10 | 	// Slice of values exposed to the sorting operation.
11 | 	Slice []T
12 | 	// Indices is initialized with the index of each element
13 | 	// in the original slice, and is sorted in parallel with Slice.
14 | 	Indices []int
15 | }
16 | 
17 | // NewIndexedSlice creates a new IndexedSlice.
18 | func NewIndexedSlice[T Ordered](slice []T) IndexedSlice[T] {
19 | 	indices := make([]int, len(slice))
20 | 	for i := range indices {
21 | 		indices[i] = i
22 | 	}
23 | 	s := IndexedSlice[T]{
24 | 		Slice: slice,
25 | 		Indices: indices,
26 | 	}
27 | 	return s
28 | }
29 | 
30 | // Len returns the length of the slice.
31 | func (s IndexedSlice[_]) Len() int { 32 | return len(s.Indices) 33 | } 34 | 35 | // Less reports whether the value at index i is less than the value at index j. 36 | func (s IndexedSlice[T]) Less(i, j int) bool { 37 | return s.Slice[i] < s.Slice[j] 38 | } 39 | 40 | // Swap swaps the elements at indices i and j, on both Indices and Slice. 41 | func (s IndexedSlice[T]) Swap(i, j int) { 42 | in := s.Indices 43 | sl := s.Slice 44 | in[i], in[j] = in[j], in[i] 45 | sl[i], sl[j] = sl[j], sl[i] 46 | } 47 | -------------------------------------------------------------------------------- /sliceutils/indexedslice_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2022 The NLP Odyssey Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package sliceutils 6 | 7 | import ( 8 | "sort" 9 | "testing" 10 | 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | var _ sort.Interface = IndexedSlice[int]{} 16 | 17 | func TestNewIndexedSlice(t *testing.T) { 18 | values := []int{7, 8, 9} 19 | s := NewIndexedSlice(values) 20 | assert.Equal(t, values, s.Slice) 21 | assert.Equal(t, []int{0, 1, 2}, s.Indices) 22 | } 23 | 24 | func TestIndexedSlice_Len(t *testing.T) { 25 | tests := []struct { 26 | s IndexedSlice[int] 27 | l int 28 | }{ 29 | {NewIndexedSlice([]int{}), 0}, 30 | {NewIndexedSlice([]int{42}), 1}, 31 | {NewIndexedSlice([]int{8, 9}), 2}, 32 | {NewIndexedSlice([]int{1, 3, 5}), 3}, 33 | } 34 | for _, tt := range tests { 35 | assert.Equalf(t, tt.l, tt.s.Len(), "len of %v", tt.s) 36 | } 37 | } 38 | 39 | func TestIndexedSlice_Less(t *testing.T) { 40 | s := NewIndexedSlice([]int{5, 1, 9, 1}) 41 | tests := [4][4]bool{ 42 | {false, false, true, false}, 43 | {true, false, true, false}, 44 | {false, false, false, false}, 45 | {true, false, true, false}, 46 | } 47 | for i, iv := range tests { 48 | for j, want := range iv { 49 | assert.Equalf(t, want, s.Less(i, j), "Less(%d, %d)", i, j) 50 | } 51 | } 52 | } 53 | 54 | func TestIndexedSlice_Swap(t *testing.T) { 55 | s := NewIndexedSlice([]int{7, 8, 9}) 56 | swaps := []struct { 57 | i, j int 58 | wantSlice []int 59 | wantIndies []int 60 | }{ 61 | {0, 1, []int{8, 7, 9}, []int{1, 0, 2}}, 62 | {0, 2, []int{9, 7, 8}, []int{2, 0, 1}}, 63 | {1, 2, []int{9, 8, 7}, []int{2, 1, 0}}, 64 | {2, 0, []int{7, 8, 9}, []int{0, 1, 2}}, 65 | {1, 1, []int{7, 8, 9}, []int{0, 1, 2}}, 66 | } 67 | for _, sw := range swaps { 68 | s.Swap(sw.i, sw.j) 69 | require.EqualValuesf(t, sw.wantSlice, s.Slice, "slice after Swap(%d, %d)", sw.i, sw.j) 70 | require.EqualValuesf(t, sw.wantIndies, s.Indices, "indices after Swap(%d, %d)", sw.i, sw.j) 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /sliceutils/orderedheap.go: -------------------------------------------------------------------------------- 1 | // Copyright 2022 The NLP Odyssey Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package sliceutils 6 | 7 | // OrderedHeap is a min-heap of Ordered values. 8 | type OrderedHeap[T Ordered] []T 9 | 10 | // Len returns the length of the heap. 11 | func (t OrderedHeap[_]) Len() int { 12 | return len(t) 13 | } 14 | 15 | // Less reports whether the value at index i is less than the value at index j. 
16 | func (t OrderedHeap[_]) Less(i, j int) bool { 17 | return t[i] < t[j] 18 | } 19 | 20 | // Swap swaps the elements at indices i and j. 21 | func (t OrderedHeap[_]) Swap(i, j int) { 22 | t[i], t[j] = t[j], t[i] 23 | } 24 | 25 | // Push appends the value x to the heap. 26 | func (t *OrderedHeap[T]) Push(x any) { 27 | *t = append(*t, x.(T)) 28 | } 29 | 30 | // Pop removes the last element from the heap and returns its value. 31 | func (t *OrderedHeap[T]) Pop() any { 32 | lastIndex := len(*t) - 1 33 | x := (*t)[lastIndex] 34 | *t = (*t)[:lastIndex] 35 | return x 36 | } 37 | -------------------------------------------------------------------------------- /sliceutils/orderedheap_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2022 The NLP Odyssey Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package sliceutils 6 | 7 | import ( 8 | "container/heap" 9 | "testing" 10 | 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | var _ heap.Interface = new(OrderedHeap[int]) 16 | 17 | func TestOrderedHeap_Len(t *testing.T) { 18 | tests := []struct { 19 | h *OrderedHeap[int] 20 | l int 21 | }{ 22 | {&OrderedHeap[int]{}, 0}, 23 | {&OrderedHeap[int]{42}, 1}, 24 | {&OrderedHeap[int]{8, 9}, 2}, 25 | {&OrderedHeap[int]{1, 3, 5}, 3}, 26 | } 27 | for _, tt := range tests { 28 | assert.Equalf(t, tt.l, tt.h.Len(), "len of %v", *tt.h) 29 | } 30 | } 31 | 32 | func TestOrderedHeap_Less(t *testing.T) { 33 | h := &OrderedHeap[int]{5, 1, 9, 1} 34 | tests := [4][4]bool{ 35 | {false, false, true, false}, 36 | {true, false, true, false}, 37 | {false, false, false, false}, 38 | {true, false, true, false}, 39 | } 40 | for i, iv := range tests { 41 | for j, want := range iv { 42 | assert.Equalf(t, want, h.Less(i, j), "Less(%d, %d)", i, j) 43 | } 44 | } 45 | } 46 | 47 | func TestOrderedHeap_Swap(t *testing.T) { 48 | h := &OrderedHeap[int]{0, 1, 2} 49 | swaps := []struct { 50 | i, j int 51 | want []int 52 | }{ 53 | {0, 1, []int{1, 0, 2}}, 54 | {0, 2, []int{2, 0, 1}}, 55 | {1, 2, []int{2, 1, 0}}, 56 | {2, 0, []int{0, 1, 2}}, 57 | {1, 1, []int{0, 1, 2}}, 58 | } 59 | for _, s := range swaps { 60 | h.Swap(s.i, s.j) 61 | require.EqualValuesf(t, s.want, *h, "after Swap(%d, %d)", s.i, s.j) 62 | } 63 | } 64 | 65 | func TestOrderedHeap_Push(t *testing.T) { 66 | h := &OrderedHeap[int]{} 67 | pushes := []struct { 68 | x int 69 | want []int 70 | }{ 71 | {1, []int{1}}, 72 | {3, []int{1, 3}}, 73 | {5, []int{1, 3, 5}}, 74 | {7, []int{1, 3, 5, 7}}, 75 | } 76 | for _, p := range pushes { 77 | h.Push(p.x) 78 | require.EqualValuesf(t, p.want, *h, "after Push(%d)", p.x) 79 | } 80 | } 81 | 82 | func TestOrderedHeap_Pop(t *testing.T) { 83 | h := &OrderedHeap[int]{1, 3, 5} 84 | 85 | pops := []struct { 86 | x int 87 | rest []int 88 | }{ 89 | {5, []int{1, 3}}, 90 | {3, []int{1}}, 91 | {1, []int{}}, 92 | } 93 | for i, p := range pops { 94 | x := h.Pop() 95 | assert.Equalf(t, p.x, x, "value after pop #%d", i+1) 96 | require.EqualValuesf(t, p.rest, *h, "remaining after pop #%d", i+1) 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /sliceutils/reverseheap.go: -------------------------------------------------------------------------------- 1 | // Copyright 2022 The NLP Odyssey Authors. All rights reserved. 
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 | 
5 | package sliceutils
6 | 
7 | import "container/heap"
8 | 
9 | // ReverseHeap returns the reverse order for data.
10 | func ReverseHeap(data heap.Interface) heap.Interface {
11 | 	return &reverseHeap{data}
12 | }
13 | 
14 | type reverseHeap struct {
15 | 	// This embedded heap.Interface permits ReverseHeap to use the methods of
16 | 	// another heap.Interface implementation.
17 | 	heap.Interface
18 | }
19 | 
20 | // Less returns the opposite of the embedded implementation's Less method.
21 | func (r reverseHeap) Less(i, j int) bool {
22 | 	return r.Interface.Less(j, i)
23 | }
24 | 
--------------------------------------------------------------------------------
/sliceutils/reverseheap_test.go:
--------------------------------------------------------------------------------
1 | // Copyright 2022 The NLP Odyssey Authors. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 | 
5 | package sliceutils
6 | 
7 | import (
8 | 	"testing"
9 | 
10 | 	"github.com/stretchr/testify/assert"
11 | )
12 | 
13 | func TestReverseHeap_Less(t *testing.T) {
14 | 	h := ReverseHeap(&dummyHeap{5, 1, 9, 1})
15 | 	tests := [4][4]bool{
16 | 		{false, true, false, true},
17 | 		{false, false, false, false},
18 | 		{true, true, false, true},
19 | 		{false, false, false, false},
20 | 	}
21 | 	for i, iv := range tests {
22 | 		for j, want := range iv {
23 | 			assert.Equalf(t, want, h.Less(i, j), "Less(%d, %d)", i, j)
24 | 		}
25 | 	}
26 | }
27 | 
28 | type dummyHeap []int
29 | 
30 | func (h dummyHeap) Less(i, j int) bool { return h[i] < h[j] }
31 | func (h dummyHeap) Len() int { panic("unexpected call to Len") }
32 | func (h dummyHeap) Swap(i, j int) { panic("unexpected call to Swap") }
33 | func (h dummyHeap) Push(x any) { panic("unexpected call to Push") }
34 | func (h dummyHeap) Pop() any { panic("unexpected call to Pop") }
35 | 
--------------------------------------------------------------------------------
/sliceutils/slices.go:
--------------------------------------------------------------------------------
1 | // Copyright 2022 The NLP Odyssey Authors. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 | 
5 | // Package sliceutils provides types and functions for various operations over
6 | // slices of different types.
7 | package sliceutils
8 | 
9 | // Ordered is a type constraint that permits any ordered type, that is, any
10 | // type supporting the operators < <= >= >.
11 | type Ordered interface { 12 | ~int | ~int8 | ~int16 | ~int32 | ~int64 | 13 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr | 14 | ~float32 | ~float64 | 15 | ~string 16 | } 17 | -------------------------------------------------------------------------------- /tokenizer/internal/bpetokenizer/testdata/dummy-roberta-model/merges.txt: -------------------------------------------------------------------------------- 1 | #version: 0.2 2 | r e 3 | a t 4 | e d 5 | u n 6 | at ed 7 | re l 8 | rel ated 9 | un related 10 | -------------------------------------------------------------------------------- /tokenizer/internal/bpetokenizer/testdata/dummy-roberta-model/vocab.json: -------------------------------------------------------------------------------- 1 | { 2 | "u": 0, 3 | "n": 1, 4 | "r": 2, 5 | "e": 3, 6 | "l": 4, 7 | "a": 5, 8 | "t": 6, 9 | "d": 7, 10 | "re": 8, 11 | "at": 9, 12 | "ed": 10, 13 | "un": 11, 14 | "ated": 12, 15 | "rel": 13, 16 | "related": 14, 17 | "unrelated": 15 18 | } 19 | -------------------------------------------------------------------------------- /tokenizer/internal/bpetokenizer/tokenizer.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 spaGO Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package bpetokenizer 6 | 7 | import ( 8 | "fmt" 9 | "path/filepath" 10 | "strings" 11 | 12 | "github.com/nlpodyssey/gotokenizers/encodings" 13 | "github.com/nlpodyssey/gotokenizers/models" 14 | "github.com/nlpodyssey/gotokenizers/models/bpemodel" 15 | "github.com/nlpodyssey/gotokenizers/normalizedstring" 16 | "github.com/nlpodyssey/gotokenizers/pretokenizedstring" 17 | "github.com/nlpodyssey/gotokenizers/pretokenizers/bytelevelpretokenizer" 18 | "github.com/nlpodyssey/gotokenizers/vocabulary" 19 | ) 20 | 21 | const ( 22 | defaultCacheCapacity = 0 23 | defaultDropout = 0.0 24 | defaultUnknownToken = "" 25 | defaultContinuingSubwordPrefix = "" 26 | defaultEndOfWordSuffix = "" 27 | defaultPrefixSpaceEnabled = false 28 | defaultOffsetsTrimmingEnabled = true 29 | defaultUnknownFusionEnabled = false 30 | ) 31 | 32 | // BPETokenizer is a higher-level tokenizer, which includes byte-level pre-tokenization. 33 | type BPETokenizer struct { 34 | preTokenizer *bytelevelpretokenizer.ByteLevelPreTokenizer 35 | model *bpemodel.BPEModel 36 | vocab *vocabulary.Vocabulary 37 | extraSpecialTokenIDs map[int]string 38 | ControlTokenIDs ControlTokensIDs 39 | 40 | StripPaddingTokensDuringTextReconstruction bool 41 | } 42 | 43 | type ControlTokensIDs struct { 44 | EosTokenID int 45 | BosTokenID int 46 | PadTokenID int 47 | DecoderStartTokenID int 48 | ExtraSpecialTokenIDs map[int]string 49 | } 50 | 51 | // Load returns a BPETokenizer from a file. 
52 | func Load(path string, controlTokensIDs ControlTokensIDs) (*BPETokenizer, error) { 53 | vocabularyFilename := filepath.Join(path, "vocab.json") 54 | vocab, err := vocabulary.FromJSONFile(vocabularyFilename) 55 | if err != nil { 56 | return nil, fmt.Errorf("loading vocabulary from file %s: %w", vocabularyFilename, err) 57 | } 58 | 59 | mergesFilename := filepath.Join(path, "merges.txt") 60 | merges, err := bpemodel.MergeMapFromFile( 61 | mergesFilename, 62 | vocab, 63 | len(defaultContinuingSubwordPrefix), 64 | ) 65 | if err != nil { 66 | return nil, fmt.Errorf("loading merges from file %s: %w", mergesFilename, err) 67 | } 68 | 69 | preTokenizer := bytelevelpretokenizer.New( 70 | bytelevelpretokenizer.DefaultSplittingRegexp, 71 | defaultPrefixSpaceEnabled, 72 | defaultOffsetsTrimmingEnabled, 73 | ) 74 | 75 | model := bpemodel.New( 76 | vocab, 77 | merges, 78 | defaultCacheCapacity, 79 | defaultDropout, 80 | defaultUnknownToken, 81 | defaultContinuingSubwordPrefix, 82 | defaultEndOfWordSuffix, 83 | defaultUnknownFusionEnabled, 84 | ) 85 | 86 | t := &BPETokenizer{ 87 | preTokenizer: preTokenizer, 88 | model: model, 89 | vocab: vocab, 90 | ControlTokenIDs: controlTokensIDs, 91 | StripPaddingTokensDuringTextReconstruction: false, 92 | } 93 | if controlTokensIDs.ExtraSpecialTokenIDs != nil { 94 | t.SetExtraSpecialTokens(controlTokensIDs.ExtraSpecialTokenIDs) 95 | } 96 | 97 | return t, nil 98 | } 99 | 100 | func (t *BPETokenizer) SetExtraSpecialTokens(extra map[int]string) { 101 | t.extraSpecialTokenIDs = extra 102 | } 103 | 104 | // Encode converts a text into an encoded tokens representation useful for Transformer architectures. 105 | // It tokenizes using byte-level pre-tokenization and BPE tokenization. 106 | func (t *BPETokenizer) Encode(text string) (*encodings.Encoding, error) { 107 | pts := pretokenizedstring.FromString(text) 108 | 109 | err := t.preTokenizer.PreTokenize(pts) 110 | if err != nil { 111 | return nil, fmt.Errorf("BPETokenizer PreTokenize for %s: %w", text, err) 112 | } 113 | 114 | err = pts.Tokenize( 115 | func(ns *normalizedstring.NormalizedString) ([]models.Token, error) { 116 | return t.model.Tokenize(ns.Get()) 117 | }, 118 | ) 119 | if err != nil { 120 | return nil, fmt.Errorf("BPETokenizer Tokenize for %s: %w", text, err) 121 | } 122 | 123 | encoding, err := pts.IntoEncoding(0, 0) 124 | if err != nil { 125 | return nil, fmt.Errorf("BPETokenizer Encoding for %s: %w", text, err) 126 | } 127 | return encoding, nil 128 | } 129 | 130 | // Tokenize returns the token IDs of the input text applying the EOS pad token. 131 | func (t *BPETokenizer) Tokenize(text string) ([]int, error) { 132 | encoded, err := t.Encode(text) 133 | if err != nil { 134 | return nil, err 135 | } 136 | return encoded.IDs, nil 137 | } 138 | 139 | // ReconstructText returns the text of the input token IDs removing the padding token. 
140 | func (t *BPETokenizer) ReconstructText(tokenIds []int) (string, error) {
141 | 	if !t.StripPaddingTokensDuringTextReconstruction {
142 | 		return t.internalDetokenize(tokenIds), nil
143 | 	}
144 | 
145 | 	stripPaddingTokensFn := func(tokenIds []int) []int {
146 | 		result := make([]int, 0, len(tokenIds))
147 | 		for _, id := range tokenIds {
148 | 			if id == t.ControlTokenIDs.EosTokenID || id == t.ControlTokenIDs.PadTokenID || id == t.ControlTokenIDs.BosTokenID || id == t.ControlTokenIDs.DecoderStartTokenID {
149 | 				continue
150 | 			}
151 | 			result = append(result, id)
152 | 		}
153 | 		return result
154 | 	}
155 | 
156 | 	return t.internalDetokenize(stripPaddingTokensFn(tokenIds)), nil
157 | }
158 | 
159 | // internalDetokenize flattens and merges a list of token IDs into a single string.
160 | // TODO: handle proper detokenization
161 | func (t *BPETokenizer) internalDetokenize(ids []int) string {
162 | 	var sb strings.Builder
163 | 	for _, id := range ids {
164 | 		if s, ok := t.extraSpecialTokenIDs[id]; ok {
165 | 			sb.WriteString(s)
166 | 			continue
167 | 		}
168 | 
169 | 		if s, ok := t.vocab.GetString(id); ok {
170 | 			sb.WriteString(s)
171 | 		}
172 | 	}
173 | 	out := sb.String()
174 | 	out = strings.Replace(out, "Ġ", " ", -1)
175 | 	out = strings.Replace(out, "Ċ", "\n", -1)
176 | 	return out
177 | }
178 | 
--------------------------------------------------------------------------------
/tokenizer/internal/bpetokenizer/tokenizer_test.go:
--------------------------------------------------------------------------------
1 | // Copyright 2020 spaGO Authors. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 | 
5 | package bpetokenizer
6 | 
7 | import (
8 | 	"testing"
9 | )
10 | 
11 | func TestNew(t *testing.T) {
12 | 	tokenizer, err := Load("testdata/dummy-roberta-model", ControlTokensIDs{})
13 | 	if err != nil {
14 | 		t.Fatal(err)
15 | 	}
16 | 	if tokenizer == nil {
17 | 		t.Fatal("expected *BPETokenizer, actual nil")
18 | 	}
19 | }
20 | 
--------------------------------------------------------------------------------
/tokenizer/tokenizer.go:
--------------------------------------------------------------------------------
1 | // Copyright 2023 NLP Odyssey Authors. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 | 
5 | package tokenizer
6 | 
7 | import "github.com/nlpodyssey/verbaflow/tokenizer/internal/bpetokenizer"
8 | 
9 | // Tokenizer is the interface that wraps the basic tokenizer methods.
10 | type Tokenizer interface {
11 | 	// Tokenize returns the sequence of token IDs for the given text.
12 | 	Tokenize(text string) ([]int, error)
13 | 	// ReconstructText returns the text corresponding to the given sequence of token IDs.
14 | 	ReconstructText(ids []int) (string, error)
15 | }
16 | 
17 | // Load loads a tokenizer from the given path.
18 | func Load(path string) (Tokenizer, error) {
19 | 	tk, err := bpetokenizer.Load(path, bpetokenizer.ControlTokensIDs{})
20 | 	if err != nil {
21 | 		return nil, err
22 | 	}
23 | 	return tk, nil
24 | }
25 | 
--------------------------------------------------------------------------------
/tools/tools.go:
--------------------------------------------------------------------------------
1 | //go:build tools
2 | // +build tools
3 | 
4 | package tools
5 | 
6 | import (
7 | 	_ "google.golang.org/grpc"
8 | 	_ "google.golang.org/grpc/cmd/protoc-gen-go-grpc"
9 | 	_ "google.golang.org/protobuf/cmd/protoc-gen-go"
10 | )
11 | 
--------------------------------------------------------------------------------
/verbaflow.go:
--------------------------------------------------------------------------------
1 | // Copyright 2023 NLP Odyssey Authors. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 | 
5 | package verbaflow
6 | 
7 | import (
8 | 	"context"
9 | 	"fmt"
10 | 	"os"
11 | 	"path/filepath"
12 | 	"time"
13 | 
14 | 	"github.com/nlpodyssey/spago/ag"
15 | 	"github.com/nlpodyssey/spago/embeddings/store/diskstore"
16 | 	"github.com/nlpodyssey/verbaflow/decoder"
17 | 	"github.com/nlpodyssey/verbaflow/encoder"
18 | 	"github.com/nlpodyssey/verbaflow/rwkvlm"
19 | 	"github.com/nlpodyssey/verbaflow/tokenizer"
20 | 	"github.com/rs/zerolog/log"
21 | )
22 | 
23 | // VerbaFlow is the core struct of the library.
24 | type VerbaFlow struct {
25 | 	Model *rwkvlm.Model
26 | 	Tokenizer tokenizer.Tokenizer
27 | 	embeddingsRepo *diskstore.Repository
28 | }
29 | 
30 | // Load loads a VerbaFlow model from the given directory.
31 | func Load(modelDir string) (*VerbaFlow, error) {
32 | 	tk, err := tokenizer.Load(modelDir)
33 | 	if err != nil {
34 | 		return nil, err
35 | 	}
36 | 	model, err := rwkvlm.Load(modelDir)
37 | 	if err != nil {
38 | 		if os.IsNotExist(err) {
39 | 			return nil, fmt.Errorf("error: unable to find the model file or directory '%s'. Please ensure that the model has been successfully downloaded and converted before trying again", modelDir)
40 | 		}
41 | 		return nil, err
42 | 	}
43 | 	embeddingsRepo, err := diskstore.NewRepository(filepath.Join(modelDir, rwkvlm.DefaultEmbeddingRepoPath), diskstore.ReadOnlyMode)
44 | 	if err != nil {
45 | 		return nil, fmt.Errorf("failed to load embeddings repository: %w", err)
46 | 	}
47 | 	err = model.ApplyEmbeddings(embeddingsRepo)
48 | 	if err != nil {
49 | 		return nil, fmt.Errorf("failed to apply embeddings: %w", err)
50 | 	}
51 | 	return &VerbaFlow{
52 | 		Model: model,
53 | 		Tokenizer: tk,
54 | 		embeddingsRepo: embeddingsRepo,
55 | 	}, nil
56 | }
57 | 
58 | // Close closes the model resources.
59 | func (vf *VerbaFlow) Close() error {
60 | 	return vf.embeddingsRepo.Close()
61 | }
62 | 
63 | // Generate generates text from the given prompt.
64 | // The chGen channel is used to stream the generated tokens.
65 | // The generated text will be at most opts.MaxLen tokens long (in addition to the prompt).
66 | func (vf *VerbaFlow) Generate(ctx context.Context, nt *ag.NodesTracker, prompt string, chGen chan decoder.GeneratedToken, opts decoder.DecodingOptions) error { 67 | log.Trace().Msgf("Tokenizing prompt: %q", prompt) 68 | tokenized, err := vf.Tokenizer.Tokenize(prompt) 69 | if err != nil { 70 | return err 71 | } 72 | 73 | log.Trace().Msgf("Preprocessing %d token IDs: %v", len(tokenized), tokenized) 74 | start := time.Now() 75 | encoderOutput, err := encoder.New(vf.Model).Encode(ctx, tokenized) 76 | if err != nil { 77 | return err 78 | } 79 | log.Trace().Msgf("Preprocessing took %s", time.Since(start)) 80 | 81 | log.Trace().Msg("Generating...") 82 | d, err := decoder.New(vf.Model, opts) 83 | if err != nil { 84 | return err 85 | } 86 | 87 | return d.Decode(ctx, nt, encoderOutput, chGen) 88 | } 89 | 90 | // TokenByID returns the token string for the given token ID. 91 | func (vf *VerbaFlow) TokenByID(id int) (string, error) { 92 | return vf.Tokenizer.ReconstructText([]int{id}) 93 | } 94 | --------------------------------------------------------------------------------
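A minimal usage sketch, not one of the repository files above: it shows how the pieces fit together when embedding VerbaFlow directly in a Go program instead of going through the gRPC service. It mirrors the channel pattern used in service/service.go and assumes, as that code does, that the decoder closes the token channel when generation finishes; the model directory, prompt, and decoding parameters are illustrative placeholders.

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/nlpodyssey/spago/ag"
	"github.com/nlpodyssey/verbaflow"
	"github.com/nlpodyssey/verbaflow/decoder"
)

func main() {
	// Placeholder path: point this at a downloaded and converted model directory.
	vf, err := verbaflow.Load("models/rwkv-example")
	if err != nil {
		log.Fatal(err)
	}
	defer vf.Close()

	// Illustrative decoding parameters: greedy decoding, up to 64 tokens.
	opts := decoder.DecodingOptions{
		MaxLen:      64,
		Temp:        1.0,
		TopP:        1.0,
		UseSampling: false,
	}

	// Track graph nodes so they can be released once generation is done,
	// as the gRPC service does with its NodesTracker.
	nt := &ag.NodesTracker{}
	defer nt.ReleaseNodes()

	chGen := make(chan decoder.GeneratedToken, opts.MaxLen)
	errCh := make(chan error, 1)
	go func() {
		errCh <- vf.Generate(context.Background(), nt, "The capital of France is", chGen, opts)
	}()

	// Stream tokens as they are produced; chGen is expected to be closed
	// by the decoder when generation completes.
	for gen := range chGen {
		token, err := vf.TokenByID(gen.TokenID)
		if err != nil {
			log.Fatal(err)
		}
		fmt.Print(token)
	}
	fmt.Println()

	if err := <-errCh; err != nil {
		log.Fatal(err)
	}
}

Buffering chGen to opts.MaxLen, as service/service.go does, lets the decoding goroutine run ahead of the consumer without blocking on every token.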