├── .github
│   └── workflows
│       └── build.yml
├── .gitignore
├── CHANGELOG.md
├── Gemfile
├── LICENSE.txt
├── README.md
├── Rakefile
├── lib
│   ├── transformers-rb.rb
│   ├── transformers.rb
│   └── transformers
│       ├── activations.rb
│       ├── configuration_utils.rb
│       ├── convert_slow_tokenizer.rb
│       ├── data
│       │   └── processors
│       │       └── squad.rb
│       ├── dynamic_module_utils.rb
│       ├── feature_extraction_utils.rb
│       ├── hf_hub
│       │   ├── constants.rb
│       │   ├── errors.rb
│       │   ├── file_download.rb
│       │   └── utils
│       │       ├── _errors.rb
│       │       └── _headers.rb
│       ├── image_processing_base.rb
│       ├── image_processing_utils.rb
│       ├── image_transforms.rb
│       ├── image_utils.rb
│       ├── modeling_outputs.rb
│       ├── modeling_utils.rb
│       ├── models
│       │   ├── auto
│       │   │   ├── auto_factory.rb
│       │   │   ├── configuration_auto.rb
│       │   │   ├── feature_extraction_auto.rb
│       │   │   ├── image_processing_auto.rb
│       │   │   ├── modeling_auto.rb
│       │   │   └── tokenization_auto.rb
│       │   ├── bert
│       │   │   ├── configuration_bert.rb
│       │   │   ├── modeling_bert.rb
│       │   │   ├── tokenization_bert.rb
│       │   │   └── tokenization_bert_fast.rb
│       │   ├── deberta_v2
│       │   │   ├── configuration_deberta_v2.rb
│       │   │   ├── modeling_deberta_v2.rb
│       │   │   └── tokenization_deberta_v2_fast.rb
│       │   ├── distilbert
│       │   │   ├── configuration_distilbert.rb
│       │   │   ├── modeling_distilbert.rb
│       │   │   ├── tokenization_distilbert.rb
│       │   │   └── tokenization_distilbert_fast.rb
│       │   ├── mpnet
│       │   │   ├── configuration_mpnet.rb
│       │   │   ├── modeling_mpnet.rb
│       │   │   └── tokenization_mpnet_fast.rb
│       │   ├── vit
│       │   │   ├── configuration_vit.rb
│       │   │   ├── image_processing_vit.rb
│       │   │   └── modeling_vit.rb
│       │   └── xlm_roberta
│       │       ├── configuration_xlm_roberta.rb
│       │       ├── modeling_xlm_roberta.rb
│       │       └── tokenization_xlm_roberta_fast.rb
│       ├── pipelines
│       │   ├── _init.rb
│       │   ├── base.rb
│       │   ├── embedding.rb
│       │   ├── feature_extraction.rb
│       │   ├── image_classification.rb
│       │   ├── image_feature_extraction.rb
│       │   ├── pt_utils.rb
│       │   ├── question_answering.rb
│       │   ├── reranking.rb
│       │   ├── text_classification.rb
│       │   └── token_classification.rb
│       ├── ruby_utils.rb
│       ├── sentence_transformer.rb
│       ├── tokenization_utils.rb
│       ├── tokenization_utils_base.rb
│       ├── tokenization_utils_fast.rb
│       ├── torch_utils.rb
│       ├── utils
│       │   ├── _init.rb
│       │   ├── generic.rb
│       │   ├── hub.rb
│       │   ├── import_utils.rb
│       │   └── logging.rb
│       └── version.rb
├── licenses
│   ├── LICENSE-huggingface-hub.txt
│   ├── LICENSE-sentence-transformers.txt
│   └── NOTICE-sentence-transformers.txt
├── test
│   ├── model_test.rb
│   ├── pipeline_test.rb
│   ├── test_helper.rb
│   └── tokenizer_test.rb
└── transformers-rb.gemspec
/.github/workflows/build.yml:
--------------------------------------------------------------------------------
1 | name: build
2 | on: [push, pull_request]
3 | jobs:
4 | build:
5 | runs-on: ubuntu-latest
6 | env:
7 | BUNDLE_BUILD__TORCH___RB: "--with-torch-dir=/home/runner/libtorch"
8 | LIBTORCH_VERSION: 2.5.0
9 | steps:
10 | - uses: actions/checkout@v4
11 | - uses: actions/cache@v4
12 | with:
13 | path: ~/libtorch
14 | key: libtorch-${{ env.LIBTORCH_VERSION }}
15 | id: cache-libtorch
16 | - name: Download LibTorch
17 | if: steps.cache-libtorch.outputs.cache-hit != 'true'
18 | run: |
19 | cd ~
20 | wget -q -O libtorch.zip https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-$LIBTORCH_VERSION%2Bcpu.zip
21 | unzip -q libtorch.zip
22 | - uses: ruby/setup-ruby@v1
23 | with:
24 | ruby-version: 3.4
25 | bundler-cache: true
26 | - uses: actions/cache@v4
27 | with:
28 | path: ~/.cache/huggingface
29 | key: huggingface
30 | - run: sudo apt-get update && sudo apt-get install libvips
31 | - run: bundle exec rake download:files
32 | - run: bundle exec rake test
33 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | /.bundle/
2 | /.yardoc
3 | /_yardoc/
4 | /coverage/
5 | /doc/
6 | /pkg/
7 | /spec/reports/
8 | /test/support/
9 | /tmp/
10 | *.lock
11 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | ## 0.1.6 (2024-12-29)
2 |
3 | - Fixed error with failed HTTP requests
4 | - Fixed warning with Ruby 3.4
5 |
6 | ## 0.1.5 (2024-11-01)
7 |
8 | - Fixed error with pipelines when called more than 10 times
9 | - Fixed `device` option for pipelines
10 | - Fixed error with `reranking` pipeline
11 |
12 | ## 0.1.4 (2024-10-22)
13 |
14 | - Added `BertForSequenceClassification`
15 |
16 | ## 0.1.3 (2024-09-17)
17 |
18 | - Added `reranking` pipeline
19 | - Added DeBERTa-v2
20 | - Added MPNet
21 | - Added XLM-RoBERTa
22 |
23 | ## 0.1.2 (2024-09-10)
24 |
25 | - Fixed default revision for pipelines
26 |
27 | ## 0.1.1 (2024-08-29)
28 |
29 | - Added `embedding` pipeline
30 | - Added experimental `fast_init` option
31 | - Improved performance of loading models
32 | - Fixed error with `aggregation_strategy` option
33 |
34 | ## 0.1.0 (2024-08-19)
35 |
36 | - First release
37 |
--------------------------------------------------------------------------------
/Gemfile:
--------------------------------------------------------------------------------
1 | source "https://rubygems.org"
2 |
3 | gemspec
4 |
5 | gem "rake"
6 | gem "minitest"
7 | gem "ruby-vips"
8 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Transformers.rb
2 |
3 | :slightly_smiling_face: State-of-the-art [transformers](https://github.com/huggingface/transformers) for Ruby
4 |
5 | For fast inference, check out [Informers](https://github.com/ankane/informers) :fire:
6 |
7 | [![Build Status](https://github.com/ankane/transformers-ruby/actions/workflows/build.yml/badge.svg)](https://github.com/ankane/transformers-ruby/actions)
8 |
9 | ## Installation
10 |
11 | First, [install Torch.rb](https://github.com/ankane/torch.rb#installation).
12 |
13 | Then add this line to your application’s Gemfile:
14 |
15 | ```ruby
16 | gem "transformers-rb"
17 | ```
18 |
19 | ## Getting Started
20 |
21 | - [Models](#models)
22 | - [Pipelines](#pipelines)
23 |
24 | ## Models
25 |
26 | Embedding
27 |
28 | - [sentence-transformers/all-MiniLM-L6-v2](#sentence-transformersall-MiniLM-L6-v2)
29 | - [sentence-transformers/multi-qa-MiniLM-L6-cos-v1](#sentence-transformersmulti-qa-MiniLM-L6-cos-v1)
30 | - [sentence-transformers/all-mpnet-base-v2](#sentence-transformersall-mpnet-base-v2)
31 | - [sentence-transformers/paraphrase-MiniLM-L6-v2](#sentence-transformersparaphrase-minilm-l6-v2)
32 | - [mixedbread-ai/mxbai-embed-large-v1](#mixedbread-aimxbai-embed-large-v1)
33 | - [thenlper/gte-small](#thenlpergte-small)
34 | - [intfloat/e5-base-v2](#intfloate5-base-v2)
35 | - [BAAI/bge-base-en-v1.5](#baaibge-base-en-v15)
36 | - [Snowflake/snowflake-arctic-embed-m-v1.5](#snowflakesnowflake-arctic-embed-m-v15)
37 |
38 | Sparse embedding
39 |
40 | - [opensearch-project/opensearch-neural-sparse-encoding-v1](#opensearch-projectopensearch-neural-sparse-encoding-v1)
41 |
42 | Reranking
43 |
44 | - [mixedbread-ai/mxbai-rerank-base-v1](#mixedbread-aimxbai-rerank-base-v1)
45 | - [BAAI/bge-reranker-base](#baaibge-reranker-base)
46 |
47 | ### sentence-transformers/all-MiniLM-L6-v2
48 |
49 | [Docs](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)
50 |
51 | ```ruby
52 | sentences = ["This is an example sentence", "Each sentence is converted"]
53 |
54 | model = Transformers.pipeline("embedding", "sentence-transformers/all-MiniLM-L6-v2")
55 | embeddings = model.(sentences)
56 | ```
57 |
58 | ### sentence-transformers/multi-qa-MiniLM-L6-cos-v1
59 |
60 | [Docs](https://huggingface.co/sentence-transformers/multi-qa-MiniLM-L6-cos-v1)
61 |
62 | ```ruby
63 | query = "How many people live in London?"
64 | docs = ["Around 9 Million people live in London", "London is known for its financial district"]
65 |
66 | model = Transformers.pipeline("embedding", "sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
67 | query_embedding = model.(query)
68 | doc_embeddings = model.(docs)
69 | scores = doc_embeddings.map { |e| e.zip(query_embedding).sum { |d, q| d * q } }
70 | doc_score_pairs = docs.zip(scores).sort_by { |d, s| -s }
71 | ```
72 |
73 | ### sentence-transformers/all-mpnet-base-v2
74 |
75 | [Docs](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)
76 |
77 | ```ruby
78 | sentences = ["This is an example sentence", "Each sentence is converted"]
79 |
80 | model = Transformers.pipeline("embedding", "sentence-transformers/all-mpnet-base-v2")
81 | embeddings = model.(sentences)
82 | ```
83 |
84 | ### sentence-transformers/paraphrase-MiniLM-L6-v2
85 |
86 | [Docs](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L6-v2)
87 |
88 | ```ruby
89 | sentences = ["This is an example sentence", "Each sentence is converted"]
90 |
91 | model = Transformers.pipeline("embedding", "sentence-transformers/paraphrase-MiniLM-L6-v2")
92 | embeddings = model.(sentences)
93 | ```
94 |
95 | ### mixedbread-ai/mxbai-embed-large-v1
96 |
97 | [Docs](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1)
98 |
99 | ```ruby
100 | query_prefix = "Represent this sentence for searching relevant passages: "
101 |
102 | input = [
103 | "The dog is barking",
104 | "The cat is purring",
105 | query_prefix + "puppy"
106 | ]
107 |
108 | model = Transformers.pipeline("embedding", "mixedbread-ai/mxbai-embed-large-v1")
109 | embeddings = model.(input)
110 | ```
111 |
112 | ### thenlper/gte-small
113 |
114 | [Docs](https://huggingface.co/thenlper/gte-small)
115 |
116 | ```ruby
117 | sentences = ["That is a happy person", "That is a very happy person"]
118 |
119 | model = Transformers.pipeline("embedding", "thenlper/gte-small")
120 | embeddings = model.(sentences)
121 | ```
122 |
123 | ### intfloat/e5-base-v2
124 |
125 | [Docs](https://huggingface.co/intfloat/e5-base-v2)
126 |
127 | ```ruby
128 | doc_prefix = "passage: "
129 | query_prefix = "query: "
130 |
131 | input = [
132 | doc_prefix + "Ruby is a programming language created by Matz",
133 | query_prefix + "Ruby creator"
134 | ]
135 |
136 | model = Transformers.pipeline("embedding", "intfloat/e5-base-v2")
137 | embeddings = model.(input)
138 | ```
139 |
140 | ### BAAI/bge-base-en-v1.5
141 |
142 | [Docs](https://huggingface.co/BAAI/bge-base-en-v1.5)
143 |
144 | ```ruby
145 | query_prefix = "Represent this sentence for searching relevant passages: "
146 |
147 | input = [
148 | "The dog is barking",
149 | "The cat is purring",
150 | query_prefix + "puppy"
151 | ]
152 |
153 | model = Transformers.pipeline("embedding", "BAAI/bge-base-en-v1.5")
154 | embeddings = model.(input)
155 | ```
156 |
157 | ### Snowflake/snowflake-arctic-embed-m-v1.5
158 |
159 | [Docs](https://huggingface.co/Snowflake/snowflake-arctic-embed-m-v1.5)
160 |
161 | ```ruby
162 | query_prefix = "Represent this sentence for searching relevant passages: "
163 |
164 | input = [
165 | "The dog is barking",
166 | "The cat is purring",
167 | query_prefix + "puppy"
168 | ]
169 |
170 | model = Transformers.pipeline("embedding", "Snowflake/snowflake-arctic-embed-m-v1.5")
171 | embeddings = model.(input, pooling: "cls")
172 | ```
173 |
174 | ### opensearch-project/opensearch-neural-sparse-encoding-v1
175 |
176 | [Docs](https://huggingface.co/opensearch-project/opensearch-neural-sparse-encoding-v1)
177 |
178 | ```ruby
179 | docs = ["The dog is barking", "The cat is purring", "The bear is growling"]
180 |
181 | model_id = "opensearch-project/opensearch-neural-sparse-encoding-v1"
182 | model = Transformers::AutoModelForMaskedLM.from_pretrained(model_id)
183 | tokenizer = Transformers::AutoTokenizer.from_pretrained(model_id)
184 | special_token_ids = tokenizer.special_tokens_map.map { |_, token| tokenizer.vocab[token] }
185 |
186 | feature = tokenizer.(docs, padding: true, truncation: true, return_tensors: "pt", return_token_type_ids: false)
187 | output = model.(**feature)[0]
188 |
189 | values, _ = Torch.max(output * feature[:attention_mask].unsqueeze(-1), dim: 1)
190 | values = Torch.log(1 + Torch.relu(values))
191 | values[0.., special_token_ids] = 0
192 | embeddings = values.to_a
193 | ```
194 |
195 | ### mixedbread-ai/mxbai-rerank-base-v1
196 |
197 | [Docs](https://huggingface.co/mixedbread-ai/mxbai-rerank-base-v1)
198 |
199 | ```ruby
200 | query = "How many people live in London?"
201 | docs = ["Around 9 Million people live in London", "London is known for its financial district"]
202 |
203 | model = Transformers.pipeline("reranking", "mixedbread-ai/mxbai-rerank-base-v1")
204 | result = model.(query, docs)
205 | ```
206 |
207 | ### BAAI/bge-reranker-base
208 |
209 | [Docs](https://huggingface.co/BAAI/bge-reranker-base)
210 |
211 | ```ruby
212 | query = "How many people live in London?"
213 | docs = ["Around 9 Million people live in London", "London is known for its financial district"]
214 |
215 | model = Transformers.pipeline("reranking", "BAAI/bge-reranker-base")
216 | result = model.(query, docs)
217 | ```
218 |
219 | ## Pipelines
220 |
221 | - [Text](#text)
222 | - [Vision](#vision)
223 |
224 | ### Text
225 |
226 | Embedding
227 |
228 | ```ruby
229 | embed = Transformers.pipeline("embedding")
230 | embed.("We are very happy to show you the 🤗 Transformers library.")
231 | ```
232 |
233 | Reranking
234 |
235 | ```ruby
236 | rerank = Transformers.pipeline("reranking")
237 | rerank.("Who created Ruby?", ["Matz created Ruby", "Another doc"])
238 | ```
239 |
240 | Named-entity recognition
241 |
242 | ```ruby
243 | ner = Transformers.pipeline("ner")
244 | ner.("Ruby is a programming language created by Matz")
245 | ```
246 |
247 | Sentiment analysis
248 |
249 | ```ruby
250 | classifier = Transformers.pipeline("sentiment-analysis")
251 | classifier.("We are very happy to show you the 🤗 Transformers library.")
252 | ```
253 |
254 | Question answering
255 |
256 | ```ruby
257 | qa = Transformers.pipeline("question-answering")
258 | qa.(question: "Who invented Ruby?", context: "Ruby is a programming language created by Matz")
259 | ```
260 |
261 | Feature extraction
262 |
263 | ```ruby
264 | extractor = Transformers.pipeline("feature-extraction")
265 | extractor.("We are very happy to show you the 🤗 Transformers library.")
266 | ```
267 |
268 | ### Vision
269 |
270 | Image classification
271 |
272 | ```ruby
273 | classifier = Transformers.pipeline("image-classification")
274 | classifier.("image.jpg")
275 | ```
276 |
277 | Image feature extraction
278 |
279 | ```ruby
280 | extractor = Transformers.pipeline("image-feature-extraction")
281 | extractor.("image.jpg")
282 | ```
283 |
284 | ## API
285 |
286 | This library follows the [Transformers Python API](https://huggingface.co/docs/transformers/index). The following model architectures are currently supported:
287 |
288 | - BERT
289 | - DeBERTa-v2
290 | - DistilBERT
291 | - MPNet
292 | - ViT
293 | - XLM-RoBERTa
294 |
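Since the Ruby API mirrors the Python one, the `Auto` classes can also be used directly when a pipeline is not a good fit. A minimal sketch (the checkpoint is just an example; it assumes the matching Auto classes are wired up for that architecture):

```ruby
model_id = "distilbert-base-uncased-finetuned-sst-2-english"

tokenizer = Transformers::AutoTokenizer.from_pretrained(model_id)
model = Transformers::AutoModelForSequenceClassification.from_pretrained(model_id)

# tokenize, run the model, and read the raw logits from the first output
inputs = tokenizer.("We are very happy to show you the 🤗 Transformers library.", return_tensors: "pt")
logits = model.(**inputs)[0]
```
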
295 | ## History
296 |
297 | View the [changelog](https://github.com/ankane/transformers-ruby/blob/master/CHANGELOG.md)
298 |
299 | ## Contributing
300 |
301 | Everyone is encouraged to help improve this project. Here are a few ways you can help:
302 |
303 | - [Report bugs](https://github.com/ankane/transformers-ruby/issues)
304 | - Fix bugs and [submit pull requests](https://github.com/ankane/transformers-ruby/pulls)
305 | - Write, clarify, or fix documentation
306 | - Suggest or add new features
307 |
308 | To get started with development:
309 |
310 | ```sh
311 | git clone https://github.com/ankane/transformers-ruby.git
312 | cd transformers-ruby
313 | bundle install
314 | bundle exec rake download:files
315 | bundle exec rake test
316 | ```
317 |
--------------------------------------------------------------------------------
/Rakefile:
--------------------------------------------------------------------------------
1 | require "bundler/gem_tasks"
2 | require "rake/testtask"
3 |
4 | task default: :test
5 | Rake::TestTask.new do |t|
6 | t.libs << "test"
7 | t.pattern = FileList["test/**/*_test.rb"].exclude("test/model_test.rb")
8 | end
9 |
10 | def download_file(url)
11 | require "open-uri"
12 |
13 | file = File.basename(url)
14 | puts "Downloading #{file}..."
15 | dest = "test/support/#{file}"
16 | File.binwrite(dest, URI.parse(url).read)
17 | puts "Saved #{dest}"
18 | end
19 |
20 | namespace :download do
21 | task :files do
22 | Dir.mkdir("test/support") unless Dir.exist?("test/support")
23 |
24 | download_file("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg")
25 | end
26 | end
27 |
--------------------------------------------------------------------------------
/lib/transformers-rb.rb:
--------------------------------------------------------------------------------
1 | require_relative "transformers"
2 |
--------------------------------------------------------------------------------
/lib/transformers.rb:
--------------------------------------------------------------------------------
1 | # dependencies
2 | require "numo/narray"
3 | require "safetensors"
4 | require "tokenizers"
5 | require "torch-rb"
6 |
7 | # stdlib
8 | require "cgi"
9 | require "fileutils"
10 | require "io/console"
11 | require "json"
12 | require "logger"
13 | require "net/http"
14 | require "pathname"
15 | require "securerandom"
16 | require "set"
17 | require "uri"
18 |
19 | # modules
20 | require_relative "transformers/ruby_utils"
21 | require_relative "transformers/utils/generic"
22 | require_relative "transformers/activations"
23 | require_relative "transformers/dynamic_module_utils"
24 | require_relative "transformers/configuration_utils"
25 | require_relative "transformers/convert_slow_tokenizer"
26 | require_relative "transformers/feature_extraction_utils"
27 | require_relative "transformers/image_utils"
28 | require_relative "transformers/image_processing_base"
29 | require_relative "transformers/image_processing_utils"
30 | require_relative "transformers/image_transforms"
31 | require_relative "transformers/modeling_outputs"
32 | require_relative "transformers/modeling_utils"
33 | require_relative "transformers/sentence_transformer"
34 | require_relative "transformers/tokenization_utils_base"
35 | require_relative "transformers/tokenization_utils"
36 | require_relative "transformers/tokenization_utils_fast"
37 | require_relative "transformers/torch_utils"
38 | require_relative "transformers/version"
39 |
40 | # data
41 | require_relative "transformers/data/processors/squad"
42 |
43 | # hub
44 | require_relative "transformers/hf_hub/constants"
45 | require_relative "transformers/hf_hub/errors"
46 | require_relative "transformers/hf_hub/file_download"
47 | require_relative "transformers/hf_hub/utils/_errors"
48 | require_relative "transformers/hf_hub/utils/_headers"
49 |
50 | # models auto
51 | require_relative "transformers/models/auto/auto_factory"
52 | require_relative "transformers/models/auto/configuration_auto"
53 | require_relative "transformers/models/auto/feature_extraction_auto"
54 | require_relative "transformers/models/auto/image_processing_auto"
55 | require_relative "transformers/models/auto/modeling_auto"
56 | require_relative "transformers/models/auto/tokenization_auto"
57 |
58 | # models bert
59 | require_relative "transformers/models/bert/configuration_bert"
60 | require_relative "transformers/models/bert/modeling_bert"
61 | require_relative "transformers/models/bert/tokenization_bert"
62 | require_relative "transformers/models/bert/tokenization_bert_fast"
63 |
64 | # models deberta-v2
65 | require_relative "transformers/models/deberta_v2/configuration_deberta_v2"
66 | require_relative "transformers/models/deberta_v2/modeling_deberta_v2"
67 | require_relative "transformers/models/deberta_v2/tokenization_deberta_v2_fast"
68 |
69 | # models distilbert
70 | require_relative "transformers/models/distilbert/configuration_distilbert"
71 | require_relative "transformers/models/distilbert/modeling_distilbert"
72 | require_relative "transformers/models/distilbert/tokenization_distilbert"
73 | require_relative "transformers/models/distilbert/tokenization_distilbert_fast"
74 |
75 | # models mpnet
76 | require_relative "transformers/models/mpnet/configuration_mpnet"
77 | require_relative "transformers/models/mpnet/modeling_mpnet"
78 | require_relative "transformers/models/mpnet/tokenization_mpnet_fast"
79 |
80 | # models vit
81 | require_relative "transformers/models/vit/configuration_vit"
82 | require_relative "transformers/models/vit/image_processing_vit"
83 | require_relative "transformers/models/vit/modeling_vit"
84 |
85 | # models xlm-roberta
86 | require_relative "transformers/models/xlm_roberta/configuration_xlm_roberta"
87 | require_relative "transformers/models/xlm_roberta/modeling_xlm_roberta"
88 | require_relative "transformers/models/xlm_roberta/tokenization_xlm_roberta_fast"
89 |
90 | # pipelines
91 | require_relative "transformers/pipelines/base"
92 | require_relative "transformers/pipelines/feature_extraction"
93 | require_relative "transformers/pipelines/embedding"
94 | require_relative "transformers/pipelines/image_classification"
95 | require_relative "transformers/pipelines/image_feature_extraction"
96 | require_relative "transformers/pipelines/pt_utils"
97 | require_relative "transformers/pipelines/question_answering"
98 | require_relative "transformers/pipelines/reranking"
99 | require_relative "transformers/pipelines/text_classification"
100 | require_relative "transformers/pipelines/token_classification"
101 | require_relative "transformers/pipelines/_init"
102 |
103 | # utils
104 | require_relative "transformers/utils/_init"
105 | require_relative "transformers/utils/import_utils"
106 | require_relative "transformers/utils/hub"
107 | require_relative "transformers/utils/logging"
108 |
109 | module Transformers
110 | class Error < StandardError; end
111 |
112 | class Todo < Error
113 | def message
114 | "not implemented yet"
115 | end
116 | end
117 |
118 | class << self
119 | # experimental
120 | attr_accessor :fast_init
121 | end
122 | self.fast_init = false
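# e.g. Transformers.fast_init = true # opt in to the experimental fast initialization when loading models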
123 | end
124 |
--------------------------------------------------------------------------------
/lib/transformers/activations.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | module Transformers
16 | class GELUActivation < Torch::NN::Module
17 | def initialize(use_gelu_python: false)
18 | super()
19 | if use_gelu_python
20 | @act = method(:_gelu_python)
21 | else
22 | @act = Torch::NN::Functional.method(:gelu)
23 | end
24 | end
25 |
26 | def _gelu_python(input)
27 | input * 0.5 * (1.0 + Torch.erf(input / Math.sqrt(2.0)))
28 | end
29 |
30 | def forward(input)
31 | @act.(input)
32 | end
33 | end
34 |
35 | class ClassInstantier
36 | def initialize(data)
37 | @data = data
38 | end
39 |
40 | def [](key)
41 | content = @data.fetch(key)
42 | cls, kwargs = content.is_a?(Array) ? content : [content, {}]
43 | cls.new(**kwargs)
44 | end
45 | end
46 |
47 | ACT2CLS = {
48 | "gelu" => GELUActivation
49 | }
50 | ACT2FN = ClassInstantier.new(ACT2CLS)
51 |
52 | module Activations
53 | def self.get_activation(activation_string)
54 | ACT2FN[activation_string]
55 | end
56 | end
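# Example (sketch): look up an activation module by its config name
#   act = Transformers::Activations.get_activation("gelu")
#   act.(Torch.tensor([-1.0, 0.0, 1.0])) # GELUActivation#forward -> Torch::NN::Functional.gelu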
57 | end
58 |
--------------------------------------------------------------------------------
/lib/transformers/convert_slow_tokenizer.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The HuggingFace Inc. team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | module Transformers
16 | module ConvertSlowTokenizer
17 | class Converter
18 | def initialize(original_tokenizer)
19 | @original_tokenizer = original_tokenizer
20 | end
21 |
22 | def converted
23 | raise NotImplementedError
24 | end
25 | end
26 |
27 | class BertConverter < Converter
28 | def converted
29 | vocab = @original_tokenizer.vocab
30 | tokenizer = Tokenizers::Tokenizer.new(Tokenizers::Models::WordPiece.new(vocab: vocab, unk_token: @original_tokenizer.unk_token.to_s))
31 |
32 | tokenize_chinese_chars = false
33 | strip_accents = false
34 | do_lower_case = false
35 | if @original_tokenizer.basic_tokenizer
36 | tokenize_chinese_chars = @original_tokenizer.basic_tokenizer.tokenize_chinese_chars
37 | strip_accents = @original_tokenizer.basic_tokenizer.strip_accents
38 | do_lower_case = @original_tokenizer.basic_tokenizer.do_lower_case
39 | end
40 |
41 | tokenizer.normalizer =
42 | Tokenizers::Normalizers::BertNormalizer.new(
43 | clean_text: true,
44 | handle_chinese_chars: tokenize_chinese_chars,
45 | strip_accents: strip_accents,
46 | lowercase: do_lower_case,
47 | )
48 | tokenizer.pre_tokenizer = Tokenizers::PreTokenizers::BertPreTokenizer.new
49 |
50 | cls = @original_tokenizer.cls_token.to_s
51 | sep = @original_tokenizer.sep_token.to_s
52 | cls_token_id = @original_tokenizer.cls_token_id
53 | sep_token_id = @original_tokenizer.sep_token_id
54 |
55 | tokenizer.post_processor =
56 | Tokenizers::Processors::TemplateProcessing.new(
57 | single: "#{cls}:0 $A:0 #{sep}:0",
58 | pair: "#{cls}:0 $A:0 #{sep}:0 $B:1 #{sep}:1",
59 | special_tokens: [
60 | [cls, cls_token_id],
61 | [sep, sep_token_id]
62 | ]
63 | )
64 | tokenizer.decoder = Tokenizers::Decoders::WordPiece.new(prefix: "##")
65 |
66 | tokenizer
67 | end
68 | end
69 |
70 | SLOW_TO_FAST_CONVERTERS = {
71 | "BertTokenizer" => BertConverter,
72 | "DistilBertTokenizer" => BertConverter
73 | }
74 |
75 | def self.convert_slow_tokenizer(transformer_tokenizer)
76 | tokenizer_class_name = transformer_tokenizer.class.name.split("::").last
77 |
78 | if !SLOW_TO_FAST_CONVERTERS.include?(tokenizer_class_name)
79 | raise ArgumentError,
80 | "An instance of tokenizer class #{tokenizer_class_name} cannot be converted in a Fast tokenizer instance." +
81 | " No converter was found. Currently available slow->fast convertors:" +
82 | " #{SLOW_TO_FAST_CONVERTERS.keys}"
83 | end
84 |
85 | converter_class = SLOW_TO_FAST_CONVERTERS.fetch(tokenizer_class_name)
86 |
87 | converter_class.new(transformer_tokenizer).converted
88 | end
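# Example (sketch, assuming a slow BertTokenizer instance):
#   fast = ConvertSlowTokenizer.convert_slow_tokenizer(slow_tokenizer)
#   # => a Tokenizers::Tokenizer with a WordPiece model, BertNormalizer,
#   #    and [CLS]/[SEP] template post-processing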
89 | end
90 | end
91 |
--------------------------------------------------------------------------------
/lib/transformers/data/processors/squad.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | module Transformers
16 | class SquadExample
17 | attr_reader :question_text, :context_text
18 |
19 | def initialize(
20 | qas_id,
21 | question_text,
22 | context_text,
23 | answer_text,
24 | start_position_character,
25 | title,
26 | answers: [],
27 | is_impossible: false
28 | )
29 | @qas_id = qas_id
30 | @question_text = question_text
31 | @context_text = context_text
32 | @answer_text = answer_text
33 | @title = title
34 | @is_impossible = is_impossible
35 | @answers = answers
36 |
37 | @start_position, @end_position = 0, 0
38 |
39 | doc_tokens = []
40 | char_to_word_offset = []
41 | prev_is_whitespace = true
42 |
43 | # Split on whitespace so that different tokens may be attributed to their original position.
44 | @context_text.each_char do |c|
45 | if _is_whitespace(c)
46 | prev_is_whitespace = true
47 | else
48 | if prev_is_whitespace
49 | doc_tokens << c
50 | else
51 | doc_tokens[-1] += c
52 | end
53 | prev_is_whitespace = false
54 | end
55 | char_to_word_offset << (doc_tokens.length - 1)
56 | end
57 |
58 | @doc_tokens = doc_tokens
59 | @char_to_word_offset = char_to_word_offset
60 |
61 | # Start and end positions only have a value during evaluation.
62 | if !start_position_character.nil? && !is_impossible
63 | @start_position = char_to_word_offset[start_position_character]
64 | @end_position = char_to_word_offset[
65 | [start_position_character + answer_text.length - 1, char_to_word_offset.length - 1].min
66 | ]
67 | end
68 | end
69 |
70 | def _is_whitespace(c)
71 | c == " " || c == "\t" || c == "\r" || c == "\n" || c.ord == 0x202F
72 | end
73 | end
74 |
75 | class SquadFeatures
76 | def initialize(
77 | input_ids:,
78 | attention_mask:,
79 | token_type_ids:,
80 | cls_index:,
81 | p_mask:,
82 | example_index:,
83 | unique_id:,
84 | paragraph_len:,
85 | token_is_max_context:,
86 | tokens:,
87 | token_to_orig_map:,
88 | start_position:,
89 | end_position:,
90 | is_impossible:,
91 | qas_id: nil,
92 | encoding: nil
93 | )
94 | @input_ids = input_ids
95 | @attention_mask = attention_mask
96 | @token_type_ids = token_type_ids
97 | @cls_index = cls_index
98 | @p_mask = p_mask
99 |
100 | @example_index = example_index
101 | @unique_id = unique_id
102 | @paragraph_len = paragraph_len
103 | @token_is_max_context = token_is_max_context
104 | @tokens = tokens
105 | @token_to_orig_map = token_to_orig_map
106 |
107 | @start_position = start_position
108 | @end_position = end_position
109 | @is_impossible = is_impossible
110 | @qas_id = qas_id
111 |
112 | @encoding = encoding
113 | end
114 | end
115 | end
116 |
--------------------------------------------------------------------------------
/lib/transformers/dynamic_module_utils.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The HuggingFace Inc. team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | module Transformers
16 | module DynamicModuleUtils
17 | # TODO improve
18 | def self.resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code)
19 | if trust_remote_code
20 | raise Error, "trust_remote_code not supported"
21 | end
22 | trust_remote_code
23 | end
24 | end
25 | end
26 |
--------------------------------------------------------------------------------
/lib/transformers/feature_extraction_utils.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The HuggingFace Inc. team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | module Transformers
16 | class BatchFeature
17 | def initialize(data:, tensor_type:)
18 | @data = data
19 | convert_to_tensors(tensor_type: tensor_type)
20 | end
21 |
22 | def to_h
23 | @data
24 | end
25 | alias_method :to_hash, :to_h
26 |
27 | def [](item)
28 | @data[item]
29 | end
30 |
31 | def keys
32 | @data.keys
33 | end
34 |
35 | def values
36 | @data.values
37 | end
38 |
39 | def items
40 | @data
41 | end
42 |
43 | def _get_is_as_tensor_fns(tensor_type: nil)
44 | if tensor_type.nil?
45 | return [nil, nil]
46 | end
47 |
48 | as_tensor = lambda do |value|
49 | if value.is_a?(Array) && value.length > 0 && value[0].is_a?(Numo::NArray)
50 | value = Numo::NArray.cast(value)
51 | end
52 | Torch.tensor(value)
53 | end
54 |
55 | is_tensor = Torch.method(:tensor?)
56 |
57 | [is_tensor, as_tensor]
58 | end
59 |
60 | def convert_to_tensors(tensor_type: nil)
61 | if tensor_type.nil?
62 | return self
63 | end
64 |
65 | is_tensor, as_tensor = _get_is_as_tensor_fns(tensor_type: tensor_type)
66 |
67 | # Do the tensor conversion in batch
68 | items.each do |key, value|
69 | begin
70 | if !is_tensor.(value)
71 | tensor = as_tensor.(value)
72 |
73 | @data[key] = tensor
74 | end
75 | rescue
76 | if key == :overflowing_values
77 | raise ArgumentError, "Unable to create tensor returning overflowing values of different lengths."
78 | end
79 | raise ArgumentError,
80 | "Unable to create tensor, you should probably activate padding " +
81 | "with 'padding: true' to have batched tensors with the same length."
82 | end
83 | end
84 |
85 | self
86 | end
87 |
88 | def to(*args, **kwargs)
89 | new_data = {}
90 | device = kwargs[:device]
91 | # Check if the args are a device or a dtype
92 | if device.nil? && args.length > 0
93 | raise Todo
94 | end
95 | # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
96 | items.each do |k, v|
97 | # check if v is a floating point
98 | if Torch.floating_point?(v)
99 | # cast and send to device
100 | new_data[k] = v.to(*args, **kwargs)
101 | elsif !device.nil?
102 | new_data[k] = v.to(device)
103 | else
104 | new_data[k] = v
105 | end
106 | end
107 | @data = new_data
108 | self
109 | end
110 | end
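# Example (sketch): wrap tokenizer output and convert the arrays to tensors
#   batch = BatchFeature.new(data: {input_ids: [[101, 2023, 102]]}, tensor_type: "pt")
#   batch[:input_ids] # => Torch::Tensor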
111 | end
112 |
--------------------------------------------------------------------------------
/lib/transformers/hf_hub/constants.rb:
--------------------------------------------------------------------------------
1 | module Transformers
2 | module HfHub
3 | # Possible values for env variables
4 |
5 | ENV_VARS_TRUE_VALUES = ["1", "ON", "YES", "TRUE"]
6 | ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES + ["AUTO"]
7 |
8 | def self._is_true(value)
9 | if value.nil?
10 | return false
11 | end
12 | ENV_VARS_TRUE_VALUES.include?(value.upcase)
13 | end
14 |
15 | def self._as_int(value)
16 | if value.nil?
17 | return nil
18 | end
19 | value.to_i
20 | end
21 |
22 | # Constants for file downloads
23 |
24 | DEFAULT_ETAG_TIMEOUT = 10
25 |
26 | # Git-related constants
27 |
28 | DEFAULT_REVISION = "main"
29 |
30 | ENDPOINT = ENV["HF_ENDPOINT"] || "https://huggingface.co"
31 |
32 | HUGGINGFACE_CO_URL_TEMPLATE = ENDPOINT + "/%{repo_id}/resolve/%{revision}/%{filename}"
33 | HUGGINGFACE_HEADER_X_REPO_COMMIT = "x-repo-commit"
34 | HUGGINGFACE_HEADER_X_LINKED_ETAG = "x-linked-etag"
35 | HUGGINGFACE_HEADER_X_LINKED_SIZE = "x-linked-size"
36 |
37 | REPO_ID_SEPARATOR = "--"
38 | # ^ this substring is not allowed in repo_ids on hf.co
39 | # and is the canonical one we use for serialization of repo ids elsewhere.
40 |
41 | REPO_TYPE_DATASET = "dataset"
42 | REPO_TYPE_SPACE = "space"
43 | REPO_TYPE_MODEL = "model"
44 | REPO_TYPES = [nil, REPO_TYPE_MODEL, REPO_TYPE_DATASET, REPO_TYPE_SPACE]
45 |
46 | REPO_TYPES_URL_PREFIXES = {
47 | REPO_TYPE_DATASET => "datasets/",
48 | REPO_TYPE_SPACE => "spaces/"
49 | }
50 |
51 | # default cache
52 | DEFAULT_HOME = File.join(ENV.fetch("HOME"), ".cache")
53 | HF_HOME =
54 | File.expand_path(
55 | ENV.fetch(
56 | "HF_HOME",
57 | File.join(ENV.fetch("XDG_CACHE_HOME", DEFAULT_HOME), "huggingface")
58 | )
59 | )
60 |
61 | # New env variables
62 | HF_HUB_CACHE = ENV["HF_HUB_CACHE"] || File.join(HF_HOME, "hub")
63 |
64 | HF_HUB_OFFLINE = _is_true(ENV["HF_HUB_OFFLINE"] || ENV["TRANSFORMERS_OFFLINE"])
65 |
66 | # Disable sending the cached token by default in all HTTP requests to the Hub
67 | HF_HUB_DISABLE_IMPLICIT_TOKEN = _is_true(ENV["HF_HUB_DISABLE_IMPLICIT_TOKEN"])
68 |
69 | HF_HUB_ENABLE_HF_TRANSFER = _is_true(ENV["HF_HUB_ENABLE_HF_TRANSFER"])
70 | end
71 | end
72 |
--------------------------------------------------------------------------------
/lib/transformers/hf_hub/errors.rb:
--------------------------------------------------------------------------------
1 | module Transformers
2 | module HfHub
3 | class Error < StandardError; end
4 |
5 | # Raised if local token is required but not found.
6 | class LocalTokenNotFoundError < Error; end
7 |
8 | # Raised when a request is made but `HF_HUB_OFFLINE=1` is set as environment variable.
9 | class OfflineModeIsEnabled < Error; end
10 | end
11 | end
12 |
--------------------------------------------------------------------------------
/lib/transformers/hf_hub/utils/_errors.rb:
--------------------------------------------------------------------------------
1 | module Transformers
2 | module HfHub
3 | class HfHubHTTPError < Error
4 | def initialize(message, response = nil)
5 | super(message)
6 | end
7 | end
8 |
9 | class RepositoryNotFoundError < HfHubHTTPError; end
10 |
11 | class GatedRepoError < RepositoryNotFoundError; end
12 |
13 | class DisabledRepoError < HfHubHTTPError; end
14 |
15 | class RevisionNotFoundError < HfHubHTTPError; end
16 |
17 | class EntryNotFoundError < HfHubHTTPError; end
18 |
19 | class LocalEntryNotFoundError < EntryNotFoundError; end
20 |
21 | class BadRequestError < HfHubHTTPError; end
22 |
23 | class << self
24 | def hf_raise_for_status(response, endpoint_name: nil)
25 | begin
26 | response.value unless response.is_a?(Net::HTTPRedirection)
27 | rescue => e
28 | error_code = response["X-Error-Code"]
29 | error_message = response["X-Error-Message"]
30 |
31 | if error_code == "RevisionNotFound"
32 | message = "#{response.code} Client Error." + "\n\n" + "Revision Not Found for url: #{response.uri}."
33 | raise RevisionNotFoundError.new(message, response)
34 |
35 | elsif error_code == "EntryNotFound"
36 | message = "#{response.code} Client Error." + "\n\n" + "Entry Not Found for url: #{response.uri}."
37 | raise EntryNotFoundError.new(message, response)
38 |
39 | elsif error_code == "GatedRepo"
40 | message = (
41 | "#{response.code} Client Error." + "\n\n" + "Cannot access gated repo for url #{response.uri}."
42 | )
43 | raise GatedRepoError.new(message, response)
44 |
45 | elsif error_message == "Access to this resource is disabled."
46 | message = (
47 | "#{response.code} Client Error." +
48 | "\n\n" +
49 | "Cannot access repository for url #{response.uri}." +
50 | "\n" +
51 | "Access to this resource is disabled."
52 | )
53 | raise DisabledRepoError.new(message, response)
54 |
55 | elsif error_code == "RepoNotFound"
56 | # 401 is misleading as it is returned for:
57 | # - private and gated repos if user is not authenticated
58 | # - missing repos
59 | # => for now, we process them as `RepoNotFound` anyway.
60 | # See https://gist.github.com/Wauplin/46c27ad266b15998ce56a6603796f0b9
61 | message = (
62 | "#{response.code} Client Error." +
63 | "\n\n" +
64 | "Repository Not Found for url: #{response.uri}." +
65 | "\nPlease make sure you specified the correct `repo_id` and" +
66 | " `repo_type`.\nIf you are trying to access a private or gated repo," +
67 | " make sure you are authenticated."
68 | )
69 | raise RepositoryNotFoundError.new(message, response)
70 |
71 | elsif response.code.to_i == 400
72 | message = (
73 | !endpoint_name.nil? ? "\n\nBad request for #{endpoint_name} endpoint:" : "\n\nBad request:"
74 | )
75 | raise BadRequestError.new(message, response)
76 |
77 | elsif response.code.to_i == 403
78 | message = (
79 | "\n\n{response.code} Forbidden: #{error_message}." +
80 | "\nCannot access content at: #{response.uri}." +
81 | "\nIf you are trying to create or update content, " +
82 | "make sure you have a token with the `write` role."
83 | )
84 | raise HfHubHTTPError.new(message, response)
85 | end
86 |
87 | # Convert `HTTPError` into a `HfHubHTTPError` to display request information
88 | # as well (request id and/or server error message)
89 | raise HfHubHTTPError.new(e.to_s, response)
90 | end
91 | end
92 | end
93 | end
94 | end
95 |
--------------------------------------------------------------------------------
/lib/transformers/hf_hub/utils/_headers.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2022-present, the HuggingFace Inc. team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | module Transformers
16 | module HfHub
17 | class << self
18 | def build_hf_headers(
19 | token: nil,
20 | is_write_action: false,
21 | library_name: nil,
22 | library_version: nil,
23 | user_agent: nil,
24 | headers: nil
25 | )
26 | # Get auth token to send
27 | token_to_send = get_token_to_send(token)
28 | _validate_token_to_send(token_to_send, is_write_action)
29 |
30 | # Combine headers
31 | hf_headers = {
32 | "user-agent" => _http_user_agent(
33 | library_name: library_name,
34 | library_version: library_version,
35 | user_agent: user_agent
36 | )
37 | }
38 | if !token_to_send.nil?
39 | hf_headers["authorization"] = "Bearer #{token_to_send}"
40 | end
41 | if headers
42 | hf_headers.merge!(headers)
43 | end
44 | hf_headers
45 | end
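# Example (sketch): headers sent with Hub requests
#   HfHub.build_hf_headers(library_name: "transformers-rb", library_version: Transformers::VERSION)
#   # => a hash with a "user-agent" entry; "authorization" is added only when a token is available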
46 |
47 | def get_token_to_send(token)
48 | # Case token is explicitly provided
49 | if token.is_a?(String)
50 | return token
51 | end
52 |
53 | # Case token is explicitly forbidden
54 | if token == false
55 | return nil
56 | end
57 |
58 | # Token is not provided: we get it from local cache
59 | cached_token = nil # get_token
60 |
61 | # Case token is explicitly required
62 | if token == true
63 | if cached_token.nil?
64 | raise LocalTokenNotFoundError,
65 | "Token is required (`token: true`), but no token found. You" +
66 | " need to provide a token or be logged in to Hugging Face with" +
67 | " `huggingface-cli login` or `huggingface_hub.login`. See" +
68 | " https://huggingface.co/settings/tokens."
69 | end
70 | return cached_token
71 | end
72 |
73 | # Case implicit use of the token is forbidden by env variable
74 | if HF_HUB_DISABLE_IMPLICIT_TOKEN
75 | return nil
76 | end
77 |
78 | # Otherwise: we use the cached token as the user has not explicitly forbidden it
79 | cached_token
80 | end
81 |
82 | def _validate_token_to_send(token, is_write_action)
83 | if is_write_action
84 | if token.nil?
85 | raise ArgumentError,
86 | "Token is required (write-access action) but no token found. You need" +
87 | " to provide a token or be logged in to Hugging Face with" +
88 | " `huggingface-cli login` or `huggingface_hub.login`. See" +
89 | " https://huggingface.co/settings/tokens."
90 | end
91 | end
92 | end
93 |
94 | def _http_user_agent(
95 | library_name: nil,
96 | library_version: nil,
97 | user_agent: nil
98 | )
99 | if !library_name.nil?
100 | ua = "#{library_name}/#{library_version}"
101 | else
102 | ua = "unknown/None"
103 | end
104 | ua += "; ruby/#{RUBY_VERSION.to_f}"
105 | ua
106 | end
107 | end
108 | end
109 | end
110 |
--------------------------------------------------------------------------------
/lib/transformers/image_processing_base.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The HuggingFace Inc. team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | module Transformers
16 | class ImageProcessingMixin
17 | def self.from_pretrained(
18 | pretrained_model_name_or_path,
19 | cache_dir: nil,
20 | force_download: false,
21 | local_files_only: false,
22 | token: nil,
23 | revision: "main",
24 | **kwargs
25 | )
26 | kwargs[:cache_dir] = cache_dir
27 | kwargs[:force_download] = force_download
28 | kwargs[:local_files_only] = local_files_only
29 | kwargs[:revision] = revision
30 |
31 | if !token.nil?
32 | kwargs[:token] = token
33 | end
34 |
35 | image_processor_dict, kwargs = get_image_processor_dict(pretrained_model_name_or_path, **kwargs)
36 |
37 | from_dict(image_processor_dict, **kwargs)
38 | end
39 |
40 | def self.get_image_processor_dict(
41 | pretrained_model_name_or_path, **kwargs
42 | )
43 | cache_dir = kwargs.delete(:cache_dir)
44 | force_download = kwargs.delete(:force_download) { false }
45 | resume_download = kwargs.delete(:resume_download)
46 | proxies = kwargs.delete(:proxies)
47 | token = kwargs.delete(:token)
48 | _use_auth_token = kwargs.delete(:use_auth_token)
49 | local_files_only = kwargs.delete(:local_files_only) { false }
50 | revision = kwargs.delete(:revision)
51 | subfolder = kwargs.delete(:subfolder) { "" }
52 |
53 | from_pipeline = kwargs.delete(:_from_pipeline)
54 | from_auto_class = kwargs.delete(:_from_auto) { false }
55 |
56 | user_agent = {file_type: "image processor", from_auto_class: from_auto_class}
57 | if !from_pipeline.nil?
58 | user_agent[:using_pipeline] = from_pipeline
59 | end
60 |
61 | if Utils::Hub.is_offline_mode && !local_files_only
62 | Transformers.logger.info("Offline mode: forcing local_files_only: true")
63 | local_files_only = true
64 | end
65 |
66 | pretrained_model_name_or_path = pretrained_model_name_or_path.to_s
67 | is_local = Dir.exist?(pretrained_model_name_or_path)
68 | if Dir.exist?(pretrained_model_name_or_path)
69 | image_processor_file = File.join(pretrained_model_name_or_path, IMAGE_PROCESSOR_NAME)
70 | end
71 | if File.exist?(pretrained_model_name_or_path)
72 | resolved_image_processor_file = pretrained_model_name_or_path
73 | is_local = true
74 | elsif Utils::Hub.is_remote_url(pretrained_model_name_or_path)
75 | raise Todo
76 | else
77 | image_processor_file = IMAGE_PROCESSOR_NAME
78 | begin
79 | # Load from local folder or from cache or download from model Hub and cache
80 | resolved_image_processor_file = Utils::Hub.cached_file(
81 | pretrained_model_name_or_path,
82 | image_processor_file,
83 | cache_dir: cache_dir,
84 | force_download: force_download,
85 | proxies: proxies,
86 | resume_download: resume_download,
87 | local_files_only: local_files_only,
88 | token: token,
89 | user_agent: user_agent,
90 | revision: revision,
91 | subfolder: subfolder
92 | )
93 | rescue EnvironmentError
94 | # Raise any environment error raised by `cached_file`. It will have a helpful error message adapted to
95 | # the original exception.
96 | raise
97 | rescue
98 | # For any other exception, we throw a generic error.
99 | raise EnvironmentError,
100 | "Can't load image processor for '#{pretrained_model_name_or_path}'. If you were trying to load" +
101 | " it from 'https://huggingface.co/models', make sure you don't have a local directory with the" +
102 | " same name. Otherwise, make sure '#{pretrained_model_name_or_path}' is the correct path to a" +
103 | " directory containing a #{IMAGE_PROCESSOR_NAME} file"
104 | end
105 | end
106 |
107 | begin
108 | image_processor_dict = JSON.load_file(resolved_image_processor_file).transform_keys(&:to_sym)
109 | rescue JSON::ParserError
110 | raise EnvironmentError,
111 | "It looks like the config file at '#{resolved_image_processor_file}' is not a valid JSON file."
112 | end
113 |
114 | if is_local
115 | Transformers.logger.info("loading configuration file #{resolved_image_processor_file}")
116 | else
117 | Transformers.logger.info(
118 | "loading configuration file #{image_processor_file} from cache at #{resolved_image_processor_file}"
119 | )
120 | end
121 |
122 | if !is_local
123 | if image_processor_dict.include?(:auto_map)
124 | raise Todo
125 | end
126 | if image_processor_dict.include?(:custom_pipelines)
127 | raise Todo
128 | end
129 | end
130 | [image_processor_dict, kwargs]
131 | end
132 |
133 | def self.from_dict(image_processor_dict, **kwargs)
134 | image_processor_dict = image_processor_dict.dup
135 | return_unused_kwargs = kwargs.delete(:return_unused_kwargs) { false }
136 |
137 | # The `size` parameter is a dict and was previously an int or tuple in feature extractors.
138 | # We set `size` here directly to the `image_processor_dict` so that it is converted to the appropriate
139 | # dict within the image processor and isn't overwritten if `size` is passed in as a kwarg.
140 | if kwargs.include?(:size) && image_processor_dict.include?(:size)
141 | image_processor_dict[:size] = kwargs.delete(:size)
142 | end
143 | if kwargs.include?(:crop_size) && image_processor_dict.include?(:crop_size)
144 | image_processor_dict[:crop_size] = kwargs.delete(:crop_size)
145 | end
146 |
147 | image_processor = new(**image_processor_dict)
148 |
149 | # Update image_processor with kwargs if needed
150 | to_remove = []
151 | kwargs.each do |key, value|
152 | if image_processor.instance_variable_defined?("@#{key}")
153 | image_processor.instance_variable_set("@#{key}", value)
154 | to_remove << key
155 | end
156 | end
157 | to_remove.each do |key|
158 | kwargs.delete(key)
159 | end
160 |
161 | Transformers.logger.info("Image processor #{image_processor}")
162 | if return_unused_kwargs
163 | [image_processor, kwargs]
164 | else
165 | image_processor
166 | end
167 | end
168 | end
169 | end
170 |
--------------------------------------------------------------------------------
/lib/transformers/image_processing_utils.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The HuggingFace Inc. team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | module Transformers
16 | class BaseImageProcessor < ImageProcessingMixin
17 | def initialize(**kwargs)
18 | super(**kwargs)
19 | end
20 |
21 | def call(images, **kwargs)
22 | preprocess(images, **kwargs)
23 | end
24 |
25 | def preprocess(images, **kwargs)
26 | raise NotImplementedError, "Each image processor must implement its own preprocess method"
27 | end
28 |
29 | def rescale(
30 | image,
31 | scale,
32 | data_format: nil,
33 | input_data_format: nil,
34 | **kwargs
35 | )
36 | ImageTransforms.rescale(image, scale, data_format: data_format, input_data_format: input_data_format, **kwargs)
37 | end
38 |
39 | def normalize(
40 | image,
41 | mean,
42 | std,
43 | data_format: nil,
44 | input_data_format: nil,
45 | **kwargs
46 | )
47 | ImageTransforms.normalize(
48 | image, mean, std, data_format: data_format, input_data_format: input_data_format, **kwargs
49 | )
50 | end
51 | end
52 |
53 | module ImageProcessingUtils
54 | def self.get_size_dict(size)
55 | if !size.is_a?(Hash)
56 | size_dict = {height: size, width: size}
57 | else
58 | size_dict = size
59 | end
60 | size_dict
61 | end
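# e.g. get_size_dict(224)                        # => {height: 224, width: 224}
#      get_size_dict({height: 224, width: 336})  # => passed through unchanged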
62 | end
63 | end
64 |
--------------------------------------------------------------------------------
/lib/transformers/image_transforms.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The HuggingFace Inc. team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | module Transformers
16 | module ImageTransforms
17 | def self.to_channel_dimension_format(
18 | image,
19 | channel_dim,
20 | input_channel_dim: nil
21 | )
22 | if !image.is_a?(Numo::NArray)
23 | raise ArgumentError, "Input image must be of type Numo::NArray, got #{image.class.name}"
24 | end
25 |
26 | if input_channel_dim.nil?
27 | input_channel_dim = ImageUtils.infer_channel_dimension_format(image)
28 | end
29 |
30 | target_channel_dim = ChannelDimension.new(channel_dim).to_s
31 | if input_channel_dim == target_channel_dim
32 | return image
33 | end
34 |
35 | if target_channel_dim == ChannelDimension::FIRST
36 | image = image.transpose(2, 0, 1)
37 | elsif target_channel_dim == ChannelDimension::LAST
38 | image = image.transpose(1, 2, 0)
39 | else
40 | raise ArgumentError, "Unsupported channel dimension format: #{channel_dim}"
41 | end
42 |
43 | image
44 | end
45 |
46 | def self.rescale(
47 | image,
48 | scale,
49 | data_format: nil,
50 | dtype: Numo::SFloat,
51 | input_data_format: nil
52 | )
53 | if !image.is_a?(Numo::NArray)
54 | raise ArgumentError, "Input image must be of type Numo::NArray, got #{image.class.name}"
55 | end
56 |
57 | rescaled_image = image * scale
58 | if !data_format.nil?
59 | rescaled_image = to_channel_dimension_format(rescaled_image, data_format, input_channel_dim: input_data_format)
60 | end
61 |
62 | rescaled_image = rescaled_image.cast_to(dtype)
63 |
64 | rescaled_image
65 | end
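# Example (sketch): scale a uint8 image into [0, 1] floats
#   pixels = Numo::UInt8.new(224, 224, 3).fill(128)
#   ImageTransforms.rescale(pixels, 1 / 255.0) # => Numo::SFloat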
66 |
67 | def self.resize(
68 | image,
69 | size,
70 | resample: nil,
71 | reducing_gap: nil,
72 | data_format: nil,
73 | return_numpy: true,
74 | input_data_format: nil
75 | )
76 | resample = !resample.nil? ? resample : nil # PILImageResampling.BILINEAR
77 |
78 | if size.length != 2
79 | raise ArgumentError, "size must have 2 elements"
80 | end
81 |
82 | # For all transformations, we want to keep the same data format as the input image unless otherwise specified.
83 | # The resized image from PIL will always have channels last, so find the input format first.
84 | if input_data_format.nil?
85 | input_data_format = ImageUtils.infer_channel_dimension_format(image)
86 | end
87 | data_format = data_format.nil? ? input_data_format : data_format
88 |
89 | # To maintain backwards compatibility with the resizing done in previous image feature extractors, we use
90 | # the pillow library to resize the image and then convert back to numpy
91 | do_rescale = false
92 | if !image.is_a?(Vips::Image)
93 | do_rescale = _rescale_for_pil_conversion(image)
94 | image = to_pil_image(image, do_rescale: do_rescale, input_data_format: input_data_format)
95 | end
96 | height, width = size
97 | # TODO support resample
98 | resized_image = image.thumbnail_image(width, height: height, size: :force)
99 |
100 | if return_numpy
101 | resized_image = ImageUtils.to_numo_array(resized_image)
102 | # If the input image channel dimension was of size 1, then it is dropped when converting to a PIL image
103 | # so we need to add it back if necessary.
104 | resized_image = resized_image.ndim == 2 ? resized_image.expand_dims(-1) : resized_image
105 | # The image is always in channels last format after converting from a PIL image
106 | resized_image = to_channel_dimension_format(
107 | resized_image, data_format, input_channel_dim: ChannelDimension::LAST
108 | )
109 |         # If an image was rescaled to be in the range [0, 255] before converting to a Vips image, then we need to
110 |         # rescale it back to the original range.
111 | resized_image = do_rescale ? rescale(resized_image, 1 / 255.0) : resized_image
112 | end
113 | resized_image
114 | end
115 |
116 | def self.normalize(
117 | image,
118 | mean,
119 | std,
120 | data_format: nil,
121 | input_data_format: nil
122 | )
123 | if !image.is_a?(Numo::NArray)
124 | raise ArgumentError, "image must be a numpy array"
125 | end
126 |
127 | if input_data_format.nil?
128 |         input_data_format = ImageUtils.infer_channel_dimension_format(image)
129 | end
130 |
131 | channel_axis = ImageUtils.get_channel_dimension_axis(image, input_data_format: input_data_format)
132 | num_channels = image.shape[channel_axis]
133 |
134 | # We cast to float32 to avoid errors that can occur when subtracting uint8 values.
135 |       # We keep the original dtype if it is already a float type.
136 | if !image.is_a?(Numo::SFloat) && !image.is_a?(Numo::DFloat)
137 | image = image.cast_to(Numo::SFloat)
138 | end
139 |
140 | if mean.is_a?(Enumerable)
141 | if mean.length != num_channels
142 | raise ArgumentError, "mean must have #{num_channels} elements if it is an iterable, got #{mean.length}"
143 | end
144 | else
145 | mean = [mean] * num_channels
146 | end
147 | mean = Numo::DFloat.cast(mean)
148 |
149 | if std.is_a?(Enumerable)
150 | if std.length != num_channels
151 | raise ArgumentError, "std must have #{num_channels} elements if it is an iterable, got #{std.length}"
152 | end
153 | else
154 | std = [std] * num_channels
155 | end
156 | std = Numo::DFloat.cast(std)
157 |
158 | if input_data_format == ChannelDimension::LAST
159 | image = (image - mean) / std
160 | else
161 | image = ((image.transpose - mean) / std).transpose
162 | end
163 |
164 |       image = !data_format.nil? ? to_channel_dimension_format(image, data_format, input_channel_dim: input_data_format) : image
165 | image
166 | end
167 |
168 | def self.to_pil_image(
169 | image,
170 | do_rescale: nil,
171 | input_data_format: nil
172 | )
173 | if image.is_a?(Vips::Image)
174 | return image
175 | end
176 |
177 | # Convert all tensors to numo arrays before converting to Vips image
178 | if !image.is_a?(Numo::NArray)
179 | raise ArgumentError, "Input image type not supported: #{image.class.name}"
180 | end
181 |
182 | # If the channel has been moved to first dim, we put it back at the end.
183 | image = to_channel_dimension_format(image, ChannelDimension::LAST, input_channel_dim: input_data_format)
184 |
185 |       # If there is a single channel, it could be squeezed here, but the Vips conversion below expects a channels-last array with an explicit band dimension.
186 | # image = image.shape[-1] == 1 ? image.squeeze(-1) : image
187 |
188 | # Rescale the image to be between 0 and 255 if needed.
189 | do_rescale = do_rescale.nil? ? _rescale_for_pil_conversion(image) : do_rescale
190 |
191 | if do_rescale
192 | image = rescale(image, 255)
193 | end
194 |
195 | image = image.cast_to(Numo::UInt8)
196 | Vips::Image.new_from_memory(image.to_binary, image.shape[1], image.shape[0], image.shape[2], :uchar)
197 | end
198 |
199 | def self._rescale_for_pil_conversion(image)
200 | if image.is_a?(Numo::UInt8)
201 | do_rescale = false
202 | else
203 | raise Todo
204 | end
205 | do_rescale
206 | end
207 | end
208 | end
209 |
--------------------------------------------------------------------------------
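
A minimal usage sketch for the transforms above, assuming the ruby-vips and numo-narray backends are installed; `photo.jpg` is a hypothetical local file and the 224 / 0.5 values are illustrative ViT-style settings, not library defaults:

    require "vips"
    require "transformers-rb"

    # Load the image and convert it to a channels-last Numo::UInt8 array.
    image = Vips::Image.new_from_file("photo.jpg")
    pixels = Transformers::ImageUtils.to_numo_array(image)

    # Resize via Vips, rescale to [0, 1], then normalize and move channels first.
    resized = Transformers::ImageTransforms.resize(pixels, [224, 224])
    scaled = Transformers::ImageTransforms.rescale(resized, 1 / 255.0)
    normalized = Transformers::ImageTransforms.normalize(
      scaled, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5],
      data_format: Transformers::ChannelDimension::FIRST
    )
    p normalized.shape # => [3, 224, 224]
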
/lib/transformers/image_utils.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The HuggingFace Inc. team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | module Transformers
16 | class ChannelDimension < ExplicitEnum
17 | FIRST = "channels_first"
18 | LAST = "channels_last"
19 | end
20 |
21 | module ImageUtils
22 | def self.load_image(image, timeout: nil)
23 | Utils.requires_backends(__method__, ["vision"])
24 | if image.is_a?(URI)
25 | require "open-uri"
26 |
27 | image = Vips::Image.new_from_buffer(image.read(open_timeout: timeout, read_timeout: timeout), "")
28 | elsif image.is_a?(String) && File.exist?(image)
29 | image = Vips::Image.new_from_file(image)
30 | elsif image.is_a?(Vips::Image)
31 | image = image
32 | else
33 | raise ArgumentError, "Incorrect format used for image"
34 | end
35 | image
36 | end
37 |
38 | def self.validate_preprocess_arguments(
39 | do_rescale: nil,
40 | rescale_factor: nil,
41 | do_normalize: nil,
42 | image_mean: nil,
43 | image_std: nil,
44 | do_pad: nil,
45 | size_divisibility: nil,
46 | do_center_crop: nil,
47 | crop_size: nil,
48 | do_resize: nil,
49 | size: nil,
50 | resample: nil
51 | )
52 | if do_rescale && rescale_factor.nil?
53 | raise ArgumentError, "`rescale_factor` must be specified if `do_rescale` is `true`."
54 | end
55 |
56 | if do_pad && size_divisibility.nil?
57 | # Here, size_divisor might be passed as the value of size
58 | raise ArgumentError, "Depending on the model, `size_divisibility`, `size_divisor`, `pad_size` or `size` must be specified if `do_pad` is `true`."
59 | end
60 |
61 | if do_normalize && (image_mean.nil? || image_std.nil?)
62 | raise ArgumentError, "`image_mean` and `image_std` must both be specified if `do_normalize` is `true`."
63 | end
64 |
65 | if do_center_crop && crop_size.nil?
66 | raise ArgumentError, "`crop_size` must be specified if `do_center_crop` is `true`."
67 | end
68 |
69 | if do_resize && (size.nil? || resample.nil?)
70 | raise ArgumentError, "`size` and `resample` must be specified if `do_resize` is `true`."
71 | end
72 | end
73 |
74 | def self.make_list_of_images(images, expected_ndims: 3)
75 | # TODO improve
76 | images.is_a?(Array) ? images : [images]
77 | end
78 |
79 | def self.to_numo_array(img)
80 | Numo::UInt8.from_binary(img.write_to_memory, [img.height, img.width, img.bands])
81 | end
82 |
83 | def self.infer_channel_dimension_format(
84 | image, num_channels: nil
85 | )
86 | num_channels = !num_channels.nil? ? num_channels : [1, 3]
87 | num_channels = num_channels.is_a?(Integer) ? [num_channels] : num_channels
88 |
89 | if image.ndim == 3
90 | first_dim, last_dim = 0, 2
91 | elsif image.ndim == 4
92 | first_dim, last_dim = 1, 3
93 | else
94 | raise ArgumentError, "Unsupported number of image dimensions: #{image.ndim}"
95 | end
96 |
97 | if num_channels.include?(image.shape[first_dim]) && num_channels.include?(image.shape[last_dim])
98 | Transformers.logger.warn(
99 | "The channel dimension is ambiguous. Got image shape #{image.shape}. Assuming channels are the first dimension."
100 | )
101 | return ChannelDimension::FIRST
102 | elsif num_channels.include?(image.shape[first_dim])
103 | return ChannelDimension::FIRST
104 | elsif num_channels.include?(image.shape[last_dim])
105 | return ChannelDimension::LAST
106 | end
107 | raise ArgumentError, "Unable to infer channel dimension format"
108 | end
109 |
110 | def self.get_channel_dimension_axis(
111 | image, input_data_format: nil
112 | )
113 | if input_data_format.nil?
114 | input_data_format = infer_channel_dimension_format(image)
115 | end
116 | if input_data_format == ChannelDimension::FIRST
117 | return image.ndim - 3
118 | elsif input_data_format == ChannelDimension::LAST
119 | return image.ndim - 1
120 | end
121 | raise ArgumentError, "Unsupported data format: #{input_data_format}"
122 | end
123 |
124 | def self.is_vips_image(img)
125 | Utils.is_vision_available && img.is_a?(Vips::Image)
126 | end
127 |
128 | def self.is_valid_image(img)
129 | is_vips_image(img) || is_numo_array(img) || is_torch_tensor(img)
130 | end
131 |
132 | def self.valid_images(imgs)
133 |       # If we have a list of images, make sure every image is valid
134 | if imgs.is_a?(Array)
135 | imgs.each do |img|
136 | if !valid_images(img)
137 | return false
138 | end
139 | end
140 |       # If it is not an array, we have been given a single image or a batched tensor of images
141 | elsif !is_valid_image(imgs)
142 | return false
143 | end
144 | true
145 | end
146 |
147 | def self.is_scaled_image(image)
148 | if image.is_a?(Numo::UInt8)
149 | return false
150 | end
151 |
152 | # It's possible the image has pixel values in [0, 255] but is of floating type
153 | image.min >= 0 && image.max <= 1
154 | end
155 |
156 | def self.validate_kwargs(valid_processor_keys:, captured_kwargs:)
157 | unused_keys = Set.new(captured_kwargs).difference(Set.new(valid_processor_keys))
158 | if unused_keys.any?
159 | unused_key_str = unused_keys.join(", ")
160 | # TODO raise a warning here instead of simply logging?
161 | Transformers.logger.warn("Unused or unrecognized kwargs: #{unused_key_str}.")
162 | end
163 | end
164 | end
165 | end
166 |
--------------------------------------------------------------------------------
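
A quick sketch of how the channel-format helpers above behave on plain Numo arrays (an illustration only, not part of the library's test suite):

    require "transformers-rb"

    hwc = Numo::UInt8.zeros(224, 224, 3) # height x width x channels
    chw = Numo::UInt8.zeros(3, 224, 224) # channels x height x width

    p Transformers::ImageUtils.infer_channel_dimension_format(hwc) # => "channels_last"
    p Transformers::ImageUtils.infer_channel_dimension_format(chw) # => "channels_first"

    # The axis helper turns the detected format into a concrete dimension index.
    p Transformers::ImageUtils.get_channel_dimension_axis(hwc) # => 2
    p Transformers::ImageUtils.get_channel_dimension_axis(chw) # => 0
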
/lib/transformers/modeling_outputs.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | module Transformers
16 | class BaseModelOutput < ModelOutput
17 | attribute :last_hidden_state
18 | attribute :hidden_states
19 | attribute :attentions
20 | end
21 |
22 | class BaseModelOutputWithPooling < ModelOutput
23 | attribute :last_hidden_state
24 | attribute :pooler_output
25 | attribute :hidden_states
26 | attribute :attentions
27 | end
28 |
29 | class BaseModelOutputWithPoolingAndCrossAttentions < ModelOutput
30 | attribute :last_hidden_state
31 | attribute :pooler_output
32 | attribute :hidden_states
33 | attribute :past_key_values
34 | attribute :attentions
35 | attribute :cross_attentions
36 | end
37 |
38 | class BaseModelOutputWithPastAndCrossAttentions < ModelOutput
39 | attribute :last_hidden_state
40 | attribute :past_key_values
41 | attribute :hidden_states
42 | attribute :attentions
43 | attribute :cross_attentions
44 | end
45 |
46 | class MaskedLMOutput < ModelOutput
47 | attribute :loss
48 | attribute :logits
49 | attribute :hidden_states
50 | attribute :attentions
51 | end
52 |
53 | class SequenceClassifierOutput < ModelOutput
54 | attribute :loss
55 | attribute :logits
56 | attribute :hidden_states
57 | attribute :attentions
58 | end
59 |
60 | class TokenClassifierOutput < ModelOutput
61 | attribute :loss
62 | attribute :logits
63 | attribute :hidden_states
64 | attribute :attentions
65 | end
66 |
67 | class QuestionAnsweringModelOutput < ModelOutput
68 | attribute :loss
69 | attribute :start_logits
70 | attribute :end_logits
71 | attribute :hidden_states
72 | attribute :attentions
73 | end
74 |
75 | class ImageClassifierOutput < ModelOutput
76 | attribute :loss
77 | attribute :logits
78 | attribute :hidden_states
79 | attribute :attentions
80 | end
81 | end
82 |
--------------------------------------------------------------------------------
/lib/transformers/models/auto/auto_factory.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The HuggingFace Inc. team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | module Transformers
16 | class BaseAutoModelClass
17 | extend ClassAttribute
18 |
19 | class_attribute :_model_mapping
20 |
21 | class << self
22 | def from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
23 | config = kwargs.delete(:config)
24 | trust_remote_code = kwargs.delete(:trust_remote_code)
25 | hub_kwargs_names = [
26 | :cache_dir,
27 | :force_download,
28 | :local_files_only,
29 | :proxies,
30 | :resume_download,
31 | :revision,
32 | :subfolder,
33 | :use_auth_token,
34 | :token
35 | ]
36 | hub_kwargs = hub_kwargs_names.select { |k| kwargs.key?(k) }.to_h { |name| [name, kwargs.delete(name)] }
37 | code_revision = kwargs.delete(:code_revision)
38 | commit_hash = kwargs.delete(:_commit_hash)
39 |
40 | if commit_hash.nil?
41 | if !config.is_a?(PretrainedConfig)
42 | # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible
43 | resolved_config_file = Utils::Hub.cached_file(
44 | pretrained_model_name_or_path,
45 | CONFIG_NAME,
46 | _raise_exceptions_for_gated_repo: false,
47 | _raise_exceptions_for_missing_entries: false,
48 | _raise_exceptions_for_connection_errors: false,
49 | **hub_kwargs
50 | )
51 | commit_hash = Utils::Hub.extract_commit_hash(resolved_config_file, commit_hash)
52 | else
53 | raise Todo
54 | end
55 | end
56 |
57 | if !config.is_a?(PretrainedConfig)
58 | config, kwargs =
59 | AutoConfig.from_pretrained(
60 | pretrained_model_name_or_path,
61 | return_unused_kwargs: true,
62 | trust_remote_code: trust_remote_code,
63 | code_revision: code_revision,
64 | _commit_hash: commit_hash,
65 | **hub_kwargs,
66 | **kwargs
67 | )
68 | end
69 |
70 | model_class = _get_model_class(config, _model_mapping)
71 | model_class.from_pretrained(
72 | pretrained_model_name_or_path, *model_args, config: config, **hub_kwargs, **kwargs
73 | )
74 | end
75 |
76 | private
77 |
78 | def _get_model_class(config, model_mapping)
79 | supported_models = model_mapping[config.class.name.split("::").last]
80 | if !supported_models.is_a?(Array)
81 | return supported_models
82 | end
83 |
84 | raise Todo
85 | end
86 | end
87 | end
88 |
89 | class LazyAutoMapping
90 | def initialize(config_mapping, model_mapping)
91 | @config_mapping = config_mapping
92 | @reverse_config_mapping = config_mapping.invert
93 | @model_mapping = model_mapping
94 | @modules = {}
95 | end
96 |
97 | def [](key)
98 | model_type = @reverse_config_mapping[key]
99 | if @model_mapping[model_type]
100 | model_name = @model_mapping[model_type]
101 | return _load_attr_from_module(model_type, model_name)
102 | end
103 |
104 | raise KeyError, key
105 | end
106 |
107 | def include?(key)
108 | self[key]
109 | true
110 | rescue KeyError
111 | false
112 | end
113 |
114 | private
115 |
116 | def _load_attr_from_module(model_type, attr)
117 | module_name = model_type_to_module_name(model_type)
118 | if !@modules.include?(module_name)
119 | @modules[module_name] = Transformers.const_get(module_name.split("-").map(&:capitalize).join)
120 | end
121 | getattribute_from_module(@modules[module_name], attr)
122 | end
123 |
124 | def getattribute_from_module(mod, attr)
125 | if attr.nil?
126 | nil
127 | elsif attr.is_a?(Array)
128 | attr.map { |a| mod.const_get(a) }
129 | else
130 | mod.const_get(attr)
131 | end
132 | end
133 |
134 | def model_type_to_module_name(key)
135 | key
136 | end
137 | end
138 | end
139 |
--------------------------------------------------------------------------------
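
A small sketch of how LazyAutoMapping resolves a config class name to a model class: the reverse config mapping recovers the model type, and the matching module (here Transformers::Bert) is only loaded on first lookup. The single-entry mappings below are illustrative:

    require "transformers-rb"

    mapping = Transformers::LazyAutoMapping.new(
      {"bert" => "BertConfig"}, # model type -> config class name
      {"bert" => "BertModel"}   # model type -> model class name
    )

    p mapping.include?("BertConfig") # => true
    p mapping["BertConfig"]          # => Transformers::Bert::BertModel
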
/lib/transformers/models/auto/configuration_auto.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The HuggingFace Inc. team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | module Transformers
16 | CONFIG_MAPPING_NAMES = {
17 | "bert" => "BertConfig",
18 | "deberta-v2" => "DebertaV2Config",
19 | "distilbert" => "DistilBertConfig",
20 | "mpnet" => "MPNetConfig",
21 | "vit" => "ViTConfig",
22 | "xlm-roberta" => "XLMRobertaConfig"
23 | }
24 |
25 | class LazyConfigMapping
26 | def initialize(mapping)
27 | @mapping = mapping
28 | @extra_content = {}
29 | @modules = {}
30 | end
31 |
32 | def [](key)
33 | value = @mapping.fetch(key)
34 | module_name = model_type_to_module_name(key)
35 | if !@modules.include?(module_name)
36 | @modules[module_name] = Transformers.const_get(module_name.split("-").map(&:capitalize).join)
37 | end
38 | @modules[module_name].const_get(value)
39 | end
40 |
41 | def model_type_to_module_name(key)
42 | key
43 | end
44 | end
45 |
46 | CONFIG_MAPPING = LazyConfigMapping.new(CONFIG_MAPPING_NAMES)
47 |
48 | class AutoConfig
49 | def self.from_pretrained(pretrained_model_name_or_path, **kwargs)
50 | kwargs[:_from_auto] = true
51 | kwargs[:name_or_path] = pretrained_model_name_or_path
52 | _trust_remote_code = kwargs.delete(:trust_remote_code)
53 | _code_revision = kwargs.delete(:code_revision)
54 |
55 | config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
56 | if config_dict[:model_type]
57 | config_class = CONFIG_MAPPING[config_dict[:model_type]]
58 | config_class.from_dict(config_dict, **unused_kwargs)
59 | else
60 | raise Todo
61 | end
62 | end
63 | end
64 | end
65 |
--------------------------------------------------------------------------------
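
For illustration, CONFIG_MAPPING resolves a `model_type` string to its config class lazily, and AutoConfig.from_pretrained dispatches on the `model_type` found in config.json. The model id below is illustrative, and fetching it requires access to the Hugging Face Hub (or a warm cache):

    require "transformers-rb"

    p Transformers::CONFIG_MAPPING["deberta-v2"]
    # => Transformers::DebertaV2::DebertaV2Config

    config = Transformers::AutoConfig.from_pretrained("bert-base-uncased")
    p config.class # => Transformers::Bert::BertConfig
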
/lib/transformers/models/auto/feature_extraction_auto.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The HuggingFace Inc. team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | module Transformers
16 | FEATURE_EXTRACTOR_MAPPING_NAMES = {
17 | }
18 |
19 | FEATURE_EXTRACTOR_MAPPING = LazyAutoMapping.new(CONFIG_MAPPING_NAMES, FEATURE_EXTRACTOR_MAPPING_NAMES)
20 | end
21 |
--------------------------------------------------------------------------------
/lib/transformers/models/auto/image_processing_auto.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The HuggingFace Inc. team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | module Transformers
16 | IMAGE_PROCESSOR_MAPPING_NAMES = {
17 | "vit" => ["ViTImageProcessor"]
18 | }
19 |
20 | IMAGE_PROCESSOR_MAPPING = LazyAutoMapping.new(CONFIG_MAPPING_NAMES, IMAGE_PROCESSOR_MAPPING_NAMES)
21 |
22 | class AutoImageProcessor
23 | def self.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
24 | config = kwargs.delete(:config)
25 | use_fast = kwargs.delete(:use_fast)
26 | trust_remote_code = kwargs.delete(:trust_remote_code)
27 | kwargs[:_from_auto] = true
28 |
29 | config_dict, _ = ImageProcessingMixin.get_image_processor_dict(pretrained_model_name_or_path, **kwargs)
30 | image_processor_class = config_dict[:image_processor_type]
31 | image_processor_auto_map = nil
32 | if (config_dict[:auto_map] || {}).include?("AutoImageProcessor")
33 | image_processor_auto_map = config_dict[:auto_map]["AutoImageProcessor"]
34 | end
35 |
36 | # If we still don't have the image processor class, check if we're loading from a previous feature extractor config
37 | # and if so, infer the image processor class from there.
38 | if image_processor_class.nil? && image_processor_auto_map.nil?
39 | feature_extractor_class = config_dict.delete(:feature_extractor_type)
40 | if !feature_extractor_class.nil?
41 | image_processor_class = feature_extractor_class.sub("FeatureExtractor", "ImageProcessor")
42 | end
43 | if (config_dict[:auto_map] || {}).include?("AutoFeatureExtractor")
44 | feature_extractor_auto_map = config_dict[:auto_map]["AutoFeatureExtractor"]
45 | image_processor_auto_map = feature_extractor_auto_map.sub("FeatureExtractor", "ImageProcessor")
46 | end
47 | end
48 |
49 | # If we don't find the image processor class in the image processor config, let's try the model config.
50 | if image_processor_class.nil? && image_processor_auto_map.nil?
51 | if !config.is_a?(PretrainedConfig)
52 | config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
53 | end
54 |       # It could be in `config.image_processor_type`
55 | image_processor_class = config.instance_variable_get(:@image_processor_type)
56 | end
57 |
58 | if !image_processor_class.nil?
59 | raise Todo
60 | end
61 |
62 | has_remote_code = !image_processor_auto_map.nil?
63 | has_local_code = !image_processor_class.nil? || IMAGE_PROCESSOR_MAPPING.include?(config.class.name.split("::").last)
64 | trust_remote_code = DynamicModuleUtils.resolve_trust_remote_code(
65 | trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
66 | )
67 |
68 | if !image_processor_auto_map.nil? && !image_processor_auto_map.is_a?(Array)
69 | raise Todo
70 | end
71 |
72 | if has_remote_code && trust_remote_code
73 | raise Todo
74 | elsif !image_processor_class.nil?
75 | return image_processor_class.from_dict(config_dict, **kwargs)
76 | # Last try: we use the IMAGE_PROCESSOR_MAPPING.
77 | elsif IMAGE_PROCESSOR_MAPPING.include?(config.class.name.split("::").last)
78 | image_processor_tuple = IMAGE_PROCESSOR_MAPPING[config.class.name.split("::").last]
79 |
80 | image_processor_class_py, image_processor_class_fast = image_processor_tuple
81 |
82 | if !use_fast && !image_processor_class_fast.nil?
83 | _warning_fast_image_processor_available(image_processor_class_fast)
84 | end
85 |
86 | if image_processor_class_fast && (use_fast || image_processor_class_py.nil?)
87 | return image_processor_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
88 | else
89 | if !image_processor_class_py.nil?
90 | return image_processor_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
91 | else
92 | raise ArgumentError,
93 | "This image processor cannot be instantiated. Please make sure you have `Pillow` installed."
94 | end
95 | end
96 | end
97 |
98 | raise ArgumentError,
99 | "Unrecognized image processor in #{pretrained_model_name_or_path}. Should have a " +
100 | "`image_processor_type` key in its #{IMAGE_PROCESSOR_NAME} of #{CONFIG_NAME}, or one of the following " +
101 | "`model_type` keys in its #{CONFIG_NAME}: #{IMAGE_PROCESSOR_MAPPING_NAMES.keys.join(", ")}"
102 | end
103 | end
104 | end
105 |
--------------------------------------------------------------------------------
/lib/transformers/models/auto/modeling_auto.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The HuggingFace Inc. team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | module Transformers
16 | MODEL_MAPPING_NAMES = {
17 | "bert" => "BertModel",
18 | "deberta-v2" => "DebertaV2Model",
19 | "distilbert" => "DistilBertModel",
20 | "mpnet" => "MPNetModel",
21 | "vit" => "ViTModel",
22 | "xlm-roberta" => "XLMRobertaModel"
23 | }
24 |
25 | MODEL_FOR_MASKED_LM_MAPPING_NAMES = {
26 | "bert" => "BertForMaskedLM",
27 | "mpnet" => "MPNetForMaskedLM"
28 | }
29 |
30 | MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = {
31 | "bert" => "BertForSequenceClassification",
32 | "deberta-v2" => "DebertaV2ForSequenceClassification",
33 | "distilbert" => "DistilBertForSequenceClassification",
34 | "xlm-roberta" => "XLMRobertaForSequenceClassification"
35 | }
36 |
37 | MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = {
38 | "distilbert" => "DistilBertForQuestionAnswering"
39 | }
40 |
41 | MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = {
42 | "vit" => "ViTForImageClassification"
43 | }
44 |
45 | MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = {
46 | "bert" => "BertForTokenClassification"
47 | }
48 |
49 | MODEL_MAPPING = LazyAutoMapping.new(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
50 | MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = LazyAutoMapping.new(
51 | CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
52 | )
53 | MODEL_FOR_MASKED_LM_MAPPING = LazyAutoMapping.new(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES)
54 | MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = LazyAutoMapping.new(
55 | CONFIG_MAPPING_NAMES, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
56 | )
57 | MODEL_FOR_QUESTION_ANSWERING_MAPPING = LazyAutoMapping.new(
58 | CONFIG_MAPPING_NAMES, MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
59 | )
60 | MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = LazyAutoMapping.new(
61 | CONFIG_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
62 | )
63 |
64 | class AutoModel < BaseAutoModelClass
65 | self._model_mapping = MODEL_MAPPING
66 | end
67 |
68 | class AutoModelForMaskedLM < BaseAutoModelClass
69 | self._model_mapping = MODEL_FOR_MASKED_LM_MAPPING
70 | end
71 |
72 | class AutoModelForSequenceClassification < BaseAutoModelClass
73 | self._model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
74 | end
75 |
76 | class AutoModelForQuestionAnswering < BaseAutoModelClass
77 | self._model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING
78 | end
79 |
80 | class AutoModelForTokenClassification < BaseAutoModelClass
81 | self._model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
82 | end
83 |
84 | class AutoModelForImageClassification < BaseAutoModelClass
85 | self._model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
86 | end
87 | end
88 |
--------------------------------------------------------------------------------
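
A hedged usage sketch: AutoModel.from_pretrained resolves the config, looks the model class up in MODEL_MAPPING, and delegates to that class's own from_pretrained. The model ids are illustrative, and loading them requires torch-rb plus Hub access (or cached downloads):

    require "transformers-rb"

    model = Transformers::AutoModel.from_pretrained("bert-base-uncased")
    p model.class # => Transformers::Bert::BertModel

    classifier = Transformers::AutoModelForSequenceClassification.from_pretrained(
      "distilbert-base-uncased-finetuned-sst-2-english"
    )
    p classifier.class # => Transformers::Distilbert::DistilBertForSequenceClassification
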
/lib/transformers/models/auto/tokenization_auto.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The HuggingFace Inc. team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | module Transformers
16 | TOKENIZER_MAPPING_NAMES = {
17 | "bert" => ["BertTokenizer", "BertTokenizerFast"],
18 | "deberta-v2" => ["DebertaV2TokenizerFast"],
19 | "distilbert" => ["DistilBertTokenizer", "DistilBertTokenizerFast"],
20 | "mpnet" => ["MPNetTokenizerFast"],
21 | "xlm-roberta" => ["XLMRobertaTokenizerFast"]
22 | }
23 |
24 | TOKENIZER_MAPPING = LazyAutoMapping.new(CONFIG_MAPPING_NAMES, TOKENIZER_MAPPING_NAMES)
25 |
26 | class AutoTokenizer
27 | class << self
28 | def from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
29 | config = kwargs.delete(:config)
30 | kwargs[:_from_auto] = true
31 |
32 | use_fast = kwargs.delete(:use_fast) { true }
33 | tokenizer_type = kwargs.delete(:tokenizer_type) { nil }
34 | trust_remote_code = kwargs.delete(:trust_remote_code)
35 |
36 | if !tokenizer_type.nil?
37 | raise Todo
38 | end
39 |
40 | tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
41 | if tokenizer_config.include?("_commit_hash")
42 | kwargs[:_commit_hash] = tokenizer_config["_commit_hash"]
43 | end
44 | config_tokenizer_class = tokenizer_config["tokenizer_class"]
45 | _tokenizer_auto_map = nil
46 | if tokenizer_config["auto_map"]
47 | raise Todo
48 | end
49 |
50 | # If that did not work, let's try to use the config.
51 | if config_tokenizer_class.nil?
52 | if !config.is_a?(PretrainedConfig)
53 | config = AutoConfig.from_pretrained(
54 | pretrained_model_name_or_path, trust_remote_code: trust_remote_code, **kwargs
55 | )
56 | config_tokenizer_class = config.tokenizer_class
57 | # if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
58 | # tokenizer_auto_map = config.auto_map["AutoTokenizer"]
59 | end
60 | end
61 |
62 | if !config_tokenizer_class.nil?
63 | tokenizer_class = nil
64 | if use_fast && !config_tokenizer_class.end_with?("Fast")
65 | tokenizer_class_candidate = "#{config_tokenizer_class}Fast"
66 | tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
67 | end
68 | if tokenizer_class.nil?
69 | tokenizer_class_candidate = config_tokenizer_class
70 | tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
71 | end
72 | if tokenizer_class.nil?
73 | raise ArgumentError, "Tokenizer class #{tokenizer_class_candidate} does not exist or is not currently imported."
74 | end
75 | return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
76 | end
77 |
78 | model_type = config_class_to_model_type(config.class.name.split("::").last)
79 | if !model_type.nil?
80 | tokenizer_class_py, tokenizer_class_fast = TOKENIZER_MAPPING[config.class.name.split("::").last]
81 | if tokenizer_class_fast && (use_fast || tokenizer_class_py.nil?)
82 | return tokenizer_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
83 | else
84 | if !tokenizer_class_py.nil?
85 | return tokenizer_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
86 | else
87 | raise ArgumentError, "This tokenizer cannot be instantiated. Please make sure you have `sentencepiece` installed in order to use this tokenizer."
88 | end
89 | end
90 | end
91 |
92 | raise ArgumentError, "Unrecognized configuration class #{config.class.name} to build an AutoTokenizer."
93 | end
94 |
95 | private
96 |
97 | def tokenizer_class_from_name(class_name)
98 | if class_name == "PreTrainedTokenizerFast"
99 | return PreTrainedTokenizerFast
100 | end
101 |
102 | TOKENIZER_MAPPING_NAMES.each do |module_name, tokenizers|
103 | if tokenizers.include?(class_name)
104 | cls = Transformers.const_get(module_name.split("-").map(&:capitalize).join).const_get(class_name)
105 | raise Error, "Invalid tokenizer class: #{class_name}" unless cls < PreTrainedTokenizer || cls < PreTrainedTokenizerFast
106 | return cls
107 | end
108 | end
109 |
110 | raise Todo
111 | end
112 |
113 | def get_tokenizer_config(
114 | pretrained_model_name_or_path,
115 | cache_dir: nil,
116 | force_download: false,
117 | resume_download: false,
118 | proxies: nil,
119 | token: nil,
120 | revision: nil,
121 | local_files_only: false,
122 | subfolder: "",
123 | **kwargs
124 | )
125 | commit_hash = kwargs[:_commit_hash]
126 | resolved_config_file = Utils::Hub.cached_file(
127 | pretrained_model_name_or_path,
128 | TOKENIZER_CONFIG_FILE,
129 | cache_dir: cache_dir,
130 | force_download: force_download,
131 | resume_download: resume_download,
132 | proxies: proxies,
133 | token: token,
134 | revision: revision,
135 | local_files_only: local_files_only,
136 | subfolder: subfolder,
137 | _raise_exceptions_for_gated_repo: false,
138 | _raise_exceptions_for_missing_entries: false,
139 | _raise_exceptions_for_connection_errors: false,
140 | _commit_hash: commit_hash
141 | )
142 | if resolved_config_file.nil?
143 | Transformers.logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
144 | return {}
145 | end
146 | commit_hash = Utils::Hub.extract_commit_hash(resolved_config_file, commit_hash)
147 |
148 | result = JSON.load_file(resolved_config_file)
149 | result["_commit_hash"] = commit_hash
150 | result
151 | end
152 |
153 | def config_class_to_model_type(config)
154 | CONFIG_MAPPING_NAMES.each do |key, cls|
155 | if cls == config
156 | return key
157 | end
158 | end
159 | nil
160 | end
161 | end
162 | end
163 | end
164 |
--------------------------------------------------------------------------------
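
A usage sketch for the dispatch above: the tokenizer class named in tokenizer_config.json (BertTokenizer for this illustrative model id) is upgraded to its fast variant because `use_fast` defaults to true. Fetching the files needs Hub access or a warm cache:

    require "transformers-rb"

    tokenizer = Transformers::AutoTokenizer.from_pretrained("bert-base-uncased")
    p tokenizer.class # => Transformers::Bert::BertTokenizerFast
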
/lib/transformers/models/bert/configuration_bert.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | module Transformers
17 | module Bert
18 | class BertConfig < PretrainedConfig
19 | self.model_type = "bert"
20 |
21 | attr_reader :vocab_size, :hidden_size, :num_hidden_layers, :num_attention_heads,
22 | :intermediate_size, :hidden_act, :hidden_dropout_prob, :attention_probs_dropout_prob,
23 | :max_position_embeddings, :type_vocab_size, :initializer_range, :layer_norm_eps,
24 | :position_embedding_type, :use_cache, :classifier_dropout
25 |
26 | def initialize(
27 | vocab_size: 30522,
28 | hidden_size: 768,
29 | num_hidden_layers: 12,
30 | num_attention_heads: 12,
31 | intermediate_size: 3072,
32 | hidden_act: "gelu",
33 | hidden_dropout_prob: 0.1,
34 | attention_probs_dropout_prob: 0.1,
35 | max_position_embeddings: 512,
36 | type_vocab_size: 2,
37 | initializer_range: 0.02,
38 | layer_norm_eps: 1e-12,
39 | pad_token_id: 0,
40 | position_embedding_type: "absolute",
41 | use_cache: true,
42 | classifier_dropout: nil,
43 | **kwargs
44 | )
45 | super(pad_token_id: pad_token_id, **kwargs)
46 |
47 | @vocab_size = vocab_size
48 | @hidden_size = hidden_size
49 | @num_hidden_layers = num_hidden_layers
50 | @num_attention_heads = num_attention_heads
51 | @hidden_act = hidden_act
52 | @intermediate_size = intermediate_size
53 | @hidden_dropout_prob = hidden_dropout_prob
54 | @attention_probs_dropout_prob = attention_probs_dropout_prob
55 | @max_position_embeddings = max_position_embeddings
56 | @type_vocab_size = type_vocab_size
57 | @initializer_range = initializer_range
58 | @layer_norm_eps = layer_norm_eps
59 | @position_embedding_type = position_embedding_type
60 | @use_cache = use_cache
61 | @classifier_dropout = classifier_dropout
62 | end
63 | end
64 | end
65 | end
66 |
--------------------------------------------------------------------------------
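
The defaults above mirror bert-base, so for a quick local experiment only the dimensions that differ need to be passed (the values below are illustrative):

    require "transformers-rb"

    config = Transformers::Bert::BertConfig.new(hidden_size: 128, num_hidden_layers: 2)
    p config.hidden_size       # => 128
    p config.num_hidden_layers # => 2
    p config.vocab_size        # => 30522 (unchanged default)
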
/lib/transformers/models/bert/tokenization_bert.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | module Transformers
16 | module Bert
17 | class BertTokenizer < PreTrainedTokenizer
18 | class BasicTokenizer
19 | attr_reader :do_lower_case, :tokenize_chinese_chars, :never_split, :strip_accents, :do_split_on_punc
20 |
21 | def initialize(
22 | do_lower_case: true,
23 | never_split: nil,
24 | tokenize_chinese_chars: true,
25 | strip_accents: nil,
26 | do_split_on_punc: true
27 | )
28 | if never_split.nil?
29 | never_split = []
30 | end
31 | @do_lower_case = do_lower_case
32 | @never_split = Set.new(never_split)
33 | @tokenize_chinese_chars = tokenize_chinese_chars
34 | @strip_accents = strip_accents
35 | @do_split_on_punc = do_split_on_punc
36 | end
37 | end
38 |
39 | class WordpieceTokenizer
40 | def initialize(vocab:, unk_token:, max_input_chars_per_word: 100)
41 | @vocab = vocab
42 | @unk_token = unk_token
43 | @max_input_chars_per_word = max_input_chars_per_word
44 | end
45 | end
46 |
47 | attr_reader :vocab, :basic_tokenizer
48 |
49 | def initialize(
50 | vocab_file:,
51 | do_lower_case: true,
52 | do_basic_tokenize: true,
53 | never_split: nil,
54 | unk_token: "[UNK]",
55 | sep_token: "[SEP]",
56 | pad_token: "[PAD]",
57 | cls_token: "[CLS]",
58 | mask_token: "[MASK]",
59 | tokenize_chinese_chars: true,
60 | strip_accents: nil,
61 | **kwargs
62 | )
63 | if !File.exist?(vocab_file)
64 | raise ArgumentError,
65 | "Can't find a vocabulary file at path '#{vocab_file}'. To load the vocabulary from a Google pretrained" +
66 | " model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
67 | end
68 | @vocab = load_vocab(vocab_file)
69 | @ids_to_tokens = @vocab.invert
70 | @do_basic_tokenize = do_basic_tokenize
71 | if do_basic_tokenize
72 | @basic_tokenizer =
73 | BasicTokenizer.new(
74 | do_lower_case: do_lower_case,
75 | never_split: never_split,
76 | tokenize_chinese_chars: tokenize_chinese_chars,
77 | strip_accents: strip_accents
78 | )
79 | end
80 |
81 | @wordpiece_tokenizer = WordpieceTokenizer.new(vocab: @vocab, unk_token: unk_token.to_s)
82 |
83 | super(
84 | do_lower_case: do_lower_case,
85 | do_basic_tokenize: do_basic_tokenize,
86 | never_split: never_split,
87 | unk_token: unk_token,
88 | sep_token: sep_token,
89 | pad_token: pad_token,
90 | cls_token: cls_token,
91 | mask_token: mask_token,
92 | tokenize_chinese_chars: tokenize_chinese_chars,
93 | strip_accents: strip_accents,
94 | **kwargs
95 | )
96 | end
97 |
98 | def _convert_token_to_id(token)
99 | @vocab.fetch(token, @vocab.fetch(@unk_token))
100 | end
101 |
102 | private
103 |
104 | def load_vocab(vocab_file)
105 | vocab = {}
106 | tokens = File.readlines(vocab_file)
107 | tokens.each_with_index do |token, index|
108 | token = token.chomp
109 | vocab[token] = index
110 | end
111 | vocab
112 | end
113 | end
114 | end
115 | end
116 |
--------------------------------------------------------------------------------
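
A standalone sketch of the vocabulary format `load_vocab` expects (one token per line, the line index becomes the id) and of the `[UNK]` fallback used by `_convert_token_to_id`; the tokens below are made up:

    require "tempfile"

    vocab_file = Tempfile.new("vocab")
    vocab_file.write(["[PAD]", "[UNK]", "[CLS]", "[SEP]", "ruby"].join("\n"))
    vocab_file.close

    vocab = {}
    File.readlines(vocab_file.path).each_with_index do |token, index|
      vocab[token.chomp] = index
    end

    p vocab["ruby"]                                # => 4
    p vocab.fetch("missing", vocab.fetch("[UNK]")) # => 1, the same fallback as _convert_token_to_id
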
/lib/transformers/models/bert/tokenization_bert_fast.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | module Transformers
16 | module Bert
17 | class BertTokenizerFast < PreTrainedTokenizerFast
18 | VOCAB_FILES_NAMES = {vocab_file: "vocab.txt", tokenizer_file: "tokenizer.json"}
19 |
20 | self.vocab_files_names = VOCAB_FILES_NAMES
21 | self.slow_tokenizer_class = BertTokenizer
22 |
23 | def initialize(
24 | vocab_file: nil,
25 | tokenizer_file: nil,
26 | do_lower_case: true,
27 | unk_token: "[UNK]",
28 | sep_token: "[SEP]",
29 | pad_token: "[PAD]",
30 | cls_token: "[CLS]",
31 | mask_token: "[MASK]",
32 | tokenize_chinese_chars: true,
33 | strip_accents: nil,
34 | **kwargs
35 | )
36 | super(
37 | vocab_file,
38 | tokenizer_file: tokenizer_file,
39 | do_lower_case: do_lower_case,
40 | unk_token: unk_token,
41 | sep_token: sep_token,
42 | pad_token: pad_token,
43 | cls_token: cls_token,
44 | mask_token: mask_token,
45 | tokenize_chinese_chars: tokenize_chinese_chars,
46 | strip_accents: strip_accents,
47 | **kwargs
48 | )
49 | end
50 | end
51 | end
52 | end
53 |
--------------------------------------------------------------------------------
/lib/transformers/models/deberta_v2/configuration_deberta_v2.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2020, Microsoft and the HuggingFace Inc. team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | module Transformers
16 | module DebertaV2
17 | class DebertaV2Config < PretrainedConfig
18 | self.model_type = "deberta-v2"
19 |
20 | attr_reader :vocab_size, :hidden_size, :num_hidden_layers, :num_attention_heads,
21 | :intermediate_size, :hidden_act, :hidden_dropout_prob, :attention_probs_dropout_prob,
22 | :max_position_embeddings, :type_vocab_size, :initializer_range, :layer_norm_eps,
23 | :relative_attention, :max_relative_positions, :pad_token_id, :position_biased_input,
24 | :pos_att_type, :pooler_dropout, :pooler_hidden_act, :pooler_hidden_size
25 |
26 | def initialize(
27 | vocab_size: 128100,
28 | hidden_size: 1536,
29 | num_hidden_layers: 24,
30 | num_attention_heads: 24,
31 | intermediate_size: 6144,
32 | hidden_act: "gelu",
33 | hidden_dropout_prob: 0.1,
34 | attention_probs_dropout_prob: 0.1,
35 | max_position_embeddings: 512,
36 | type_vocab_size: 0,
37 | initializer_range: 0.02,
38 | layer_norm_eps: 1e-07,
39 | relative_attention: false,
40 | max_relative_positions: -1,
41 | pad_token_id: 0,
42 | position_biased_input: true,
43 | pos_att_type: nil,
44 | pooler_dropout: 0,
45 | pooler_hidden_act: "gelu",
46 | **kwargs
47 | )
48 | super(**kwargs)
49 |
50 | @hidden_size = hidden_size
51 | @num_hidden_layers = num_hidden_layers
52 | @num_attention_heads = num_attention_heads
53 | @intermediate_size = intermediate_size
54 | @hidden_act = hidden_act
55 | @hidden_dropout_prob = hidden_dropout_prob
56 | @attention_probs_dropout_prob = attention_probs_dropout_prob
57 | @max_position_embeddings = max_position_embeddings
58 | @type_vocab_size = type_vocab_size
59 | @initializer_range = initializer_range
60 | @relative_attention = relative_attention
61 | @max_relative_positions = max_relative_positions
62 | @pad_token_id = pad_token_id
63 | @position_biased_input = position_biased_input
64 |
65 | # Backwards compatibility
66 | if pos_att_type.is_a?(String)
67 | pos_att_type = pos_att_type.downcase.split("|").map { |x| x.strip }
68 | end
69 |
70 | @pos_att_type = pos_att_type
71 | @vocab_size = vocab_size
72 | @layer_norm_eps = layer_norm_eps
73 |
74 | @pooler_hidden_size = kwargs[:pooler_hidden_size] || hidden_size
75 | @pooler_dropout = pooler_dropout
76 | @pooler_hidden_act = pooler_hidden_act
77 | end
78 | end
79 | end
80 | end
81 |
--------------------------------------------------------------------------------
/lib/transformers/models/deberta_v2/tokenization_deberta_v2_fast.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Microsoft and the HuggingFace Inc. team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | module Transformers
16 | module DebertaV2
17 | class DebertaV2TokenizerFast < PreTrainedTokenizerFast
18 | VOCAB_FILES_NAMES = {vocab_file: "spm.model", tokenizer_file: "tokenizer.json"}
19 |
20 | self.vocab_files_names = VOCAB_FILES_NAMES
21 | # self.slow_tokenizer_class = DebertaV2Tokenizer
22 |
23 | def initialize(
24 | vocab_file: nil,
25 | tokenizer_file: nil,
26 | do_lower_case: false,
27 | split_by_punct: false,
28 | bos_token: "[CLS]",
29 | eos_token: "[SEP]",
30 | unk_token: "[UNK]",
31 | sep_token: "[SEP]",
32 | pad_token: "[PAD]",
33 | cls_token: "[CLS]",
34 | mask_token: "[MASK]",
35 | **kwargs
36 | )
37 | super(vocab_file, tokenizer_file: tokenizer_file, do_lower_case: do_lower_case, bos_token: bos_token, eos_token: eos_token, unk_token: unk_token, sep_token: sep_token, pad_token: pad_token, cls_token: cls_token, mask_token: mask_token, split_by_punct: split_by_punct, **kwargs)
38 |
39 | @do_lower_case = do_lower_case
40 | @split_by_punct = split_by_punct
41 | @vocab_file = vocab_file
42 | end
43 |
44 | def can_save_slow_tokenizer
45 | @vocab_file ? File.exist?(@vocab_file) : false
46 | end
47 |
48 | def build_inputs_with_special_tokens(token_ids_0, token_ids_1: nil)
49 | if token_ids_1.nil?
50 | return [@cls_token_id] + token_ids_0 + [@sep_token_id]
51 | end
52 | cls = [@cls_token_id]
53 | sep = [@sep_token_id]
54 | cls + token_ids_0 + sep + token_ids_1 + sep
55 | end
56 |
57 | def get_special_tokens_mask(token_ids_0, token_ids_1: nil, already_has_special_tokens: false)
58 | if already_has_special_tokens
59 | return super(token_ids_0: token_ids_0, token_ids_1: token_ids_1, already_has_special_tokens: true)
60 | end
61 |
62 | if !token_ids_1.nil?
63 | return [1] + ([0] * token_ids_0.length) + [1] + ([0] * token_ids_1.length) + [1]
64 | end
65 | [1] + ([0] * token_ids_0.length) + [1]
66 | end
67 |
68 | def create_token_type_ids_from_sequences(token_ids_0, token_ids_1: nil)
69 | sep = [@sep_token_id]
70 | cls = [@cls_token_id]
71 | if token_ids_1.nil?
72 |           return [0] * (cls + token_ids_0 + sep).length
73 |         end
74 |         ([0] * (cls + token_ids_0 + sep).length) + ([1] * (token_ids_1 + sep).length)
75 | end
76 | end
77 | end
78 | end
79 |
--------------------------------------------------------------------------------
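
The helpers above produce the usual [CLS] ... [SEP] (... [SEP]) layout; a small worked example with made-up ids (1 for [CLS], 2 for [SEP]):

    cls, sep = [1], [2]
    ids_a = [10, 11, 12]
    ids_b = [20, 21]

    # build_inputs_with_special_tokens
    p cls + ids_a + sep               # => [1, 10, 11, 12, 2]
    p cls + ids_a + sep + ids_b + sep # => [1, 10, 11, 12, 2, 20, 21, 2]

    # create_token_type_ids_from_sequences: zeros for the first segment, ones for the second
    p [0] * (cls + ids_a + sep).length                                  # => [0, 0, 0, 0, 0]
    p ([0] * (cls + ids_a + sep).length) + ([1] * (ids_b + sep).length) # => [0, 0, 0, 0, 0, 1, 1, 1]
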
/lib/transformers/models/distilbert/configuration_distilbert.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | module Transformers
16 | module Distilbert
17 | class DistilBertConfig < PretrainedConfig
18 | self.model_type = "distilbert"
19 | self.attribute_map = {
20 | hidden_size: "dim",
21 | num_attention_heads: "n_heads",
22 | num_hidden_layers: "n_layers"
23 | }
24 |
25 | attr_reader :vocab_size, :max_position_embeddings, :sinusoidal_pos_embds, :n_layers, :n_heads,
26 | :dim, :hidden_dim, :dropout, :attention_dropout, :activation, :initializer_range, :qa_dropout,
27 | :seq_classif_dropout, :pad_token_id
28 |
29 | def initialize(
30 | vocab_size: 30522,
31 | max_position_embeddings: 512,
32 | sinusoidal_pos_embds: false,
33 | n_layers: 6,
34 | n_heads: 12,
35 | dim: 768,
36 | hidden_dim: 4 * 768,
37 | dropout: 0.1,
38 | attention_dropout: 0.1,
39 | activation: "gelu",
40 | initializer_range: 0.02,
41 | qa_dropout: 0.1,
42 | seq_classif_dropout: 0.2,
43 | pad_token_id: 0,
44 | **kwargs
45 | )
46 | @vocab_size = vocab_size
47 | @max_position_embeddings = max_position_embeddings
48 | @sinusoidal_pos_embds = sinusoidal_pos_embds
49 | @n_layers = n_layers
50 | @n_heads = n_heads
51 | @dim = dim
52 | @hidden_dim = hidden_dim
53 | @dropout = dropout
54 | @attention_dropout = attention_dropout
55 | @activation = activation
56 | @initializer_range = initializer_range
57 | @qa_dropout = qa_dropout
58 | @seq_classif_dropout = seq_classif_dropout
59 | super(**kwargs, pad_token_id: pad_token_id)
60 | end
61 | end
62 | end
63 | end
64 |
--------------------------------------------------------------------------------
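
DistilBERT names its sizes `dim`, `n_heads` and `n_layers`; `attribute_map` is there to expose them under the shared `hidden_size` / `num_attention_heads` / `num_hidden_layers` names as well, assuming PretrainedConfig resolves those aliases the way the Python library does. A small illustrative instantiation:

    require "transformers-rb"

    config = Transformers::Distilbert::DistilBertConfig.new(dim: 512, n_layers: 4)
    p config.dim      # => 512
    p config.n_layers # => 4
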
/lib/transformers/models/distilbert/tokenization_distilbert.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The HuggingFace Inc. team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | module Transformers
16 | module Distilbert
17 | class DistilBertTokenizer < PreTrainedTokenizer
18 | VOCAB_FILES_NAMES = {vocab_file: "vocab.txt"}
19 |
20 | self.vocab_files_names = VOCAB_FILES_NAMES
21 | self.model_input_names = ["input_ids", "attention_mask"]
22 |
23 | class BasicTokenizer
24 | attr_reader :do_lower_case, :tokenize_chinese_chars, :strip_accents
25 |
26 | def initialize(
27 | do_lower_case: true,
28 | never_split: nil,
29 | tokenize_chinese_chars: true,
30 | strip_accents: nil,
31 | do_split_on_punc: true
32 | )
33 | if never_split.nil?
34 | never_split = []
35 | end
36 | @do_lower_case = do_lower_case
37 | @never_split = Set.new(never_split)
38 | @tokenize_chinese_chars = tokenize_chinese_chars
39 | @strip_accents = strip_accents
40 | @do_split_on_punc = do_split_on_punc
41 | end
42 | end
43 |
44 | class WordpieceTokenizer
45 | def initialize(vocab:, unk_token:, max_input_chars_per_word: 100)
46 | @vocab = vocab
47 | @unk_token = unk_token
48 | @max_input_chars_per_word = max_input_chars_per_word
49 | end
50 | end
51 |
52 | attr_reader :vocab, :basic_tokenizer
53 |
54 | def initialize(
55 | vocab_file:,
56 | do_lower_case: true,
57 | do_basic_tokenize: true,
58 | never_split: nil,
59 | unk_token: "[UNK]",
60 | sep_token: "[SEP]",
61 | pad_token: "[PAD]",
62 | cls_token: "[CLS]",
63 | mask_token: "[MASK]",
64 | tokenize_chinese_chars: true,
65 | strip_accents: nil,
66 | **kwargs
67 | )
68 | @vocab = load_vocab(vocab_file)
69 | @ids_to_tokens = @vocab.invert
70 | @do_basic_tokenize = do_basic_tokenize
71 | if do_basic_tokenize
72 | @basic_tokenizer =
73 | BasicTokenizer.new(
74 | do_lower_case: do_lower_case,
75 | never_split: never_split,
76 | tokenize_chinese_chars: tokenize_chinese_chars,
77 | strip_accents: strip_accents
78 | )
79 | end
80 | @wordpiece_tokenizer = WordpieceTokenizer.new(vocab: @vocab, unk_token: unk_token.to_s)
81 |
82 | super(
83 | do_lower_case: do_lower_case,
84 | do_basic_tokenize: do_basic_tokenize,
85 | never_split: never_split,
86 | unk_token: unk_token,
87 | sep_token: sep_token,
88 | pad_token: pad_token,
89 | cls_token: cls_token,
90 | mask_token: mask_token,
91 | tokenize_chinese_chars: tokenize_chinese_chars,
92 | strip_accents: strip_accents,
93 | **kwargs
94 | )
95 | end
96 |
97 | def _convert_token_to_id(token)
98 | @vocab.fetch(token, @vocab.fetch(@unk_token))
99 | end
100 |
101 | private
102 |
103 | def load_vocab(vocab_file)
104 | vocab = {}
105 | tokens = File.readlines(vocab_file)
106 | tokens.each_with_index do |token, index|
107 | token = token.chomp
108 | vocab[token] = index
109 | end
110 | vocab
111 | end
112 | end
113 | end
114 | end
115 |
--------------------------------------------------------------------------------
/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The HuggingFace Inc. team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | module Transformers
16 | module Distilbert
17 | class DistilBertTokenizerFast < PreTrainedTokenizerFast
18 | VOCAB_FILES_NAMES = {vocab_file: "vocab.txt", tokenizer_file: "tokenizer.json"}
19 |
20 | self.vocab_files_names = VOCAB_FILES_NAMES
21 | self.model_input_names = ["input_ids", "attention_mask"]
22 | self.slow_tokenizer_class = DistilBertTokenizer
23 |
24 | def initialize(
25 | vocab_file: nil,
26 | tokenizer_file: nil,
27 | do_lower_case: true,
28 | unk_token: "[UNK]",
29 | sep_token: "[SEP]",
30 | pad_token: "[PAD]",
31 | cls_token: "[CLS]",
32 | mask_token: "[MASK]",
33 | tokenize_chinese_chars: true,
34 | strip_accents: nil,
35 | **kwargs
36 | )
37 | super(
38 | vocab_file,
39 | tokenizer_file: tokenizer_file,
40 | do_lower_case: do_lower_case,
41 | unk_token: unk_token,
42 | sep_token: sep_token,
43 | pad_token: pad_token,
44 | cls_token: cls_token,
45 | mask_token: mask_token,
46 | tokenize_chinese_chars: tokenize_chinese_chars,
47 | strip_accents: strip_accents,
48 | **kwargs
49 | )
50 |
51 | if @backend_tokenizer
52 | raise Todo
53 | end
54 |
55 | @do_lower_case = do_lower_case
56 | end
57 |
58 | def build_inputs_with_special_tokens(token_ids_0, token_ids_1 = nil)
59 | raise Todo
60 | end
61 |
62 | def create_token_type_ids_from_sequences(token_ids_0, token_ids_1 = nil)
63 | raise Todo
64 | end
65 |
66 | def save_vocabulary(save_directory, filename_prefix: nil)
67 | raise Todo
68 | end
69 | end
70 | end
71 | end
72 |
--------------------------------------------------------------------------------
/lib/transformers/models/mpnet/configuration_mpnet.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The HuggingFace Inc. team, Microsoft Corporation.
2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | module Transformers
17 | module Mpnet
18 | class MPNetConfig < PretrainedConfig
19 | self.model_type = "mpnet"
20 |
21 | attr_reader :vocab_size, :hidden_size, :num_hidden_layers, :num_attention_heads,
22 | :intermediate_size, :hidden_act, :hidden_dropout_prob, :attention_probs_dropout_prob,
23 | :max_position_embeddings, :initializer_range, :layer_norm_eps, :relative_attention_num_buckets,
24 | :pad_token_id, :bos_token_id, :eos_token_id
25 |
26 | def initialize(
27 | vocab_size: 30527,
28 | hidden_size: 768,
29 | num_hidden_layers: 12,
30 | num_attention_heads: 12,
31 | intermediate_size: 3072,
32 | hidden_act: "gelu",
33 | hidden_dropout_prob: 0.1,
34 | attention_probs_dropout_prob: 0.1,
35 | max_position_embeddings: 512,
36 | initializer_range: 0.02,
37 | layer_norm_eps: 1e-12,
38 | relative_attention_num_buckets: 32,
39 | pad_token_id: 1,
40 | bos_token_id: 0,
41 | eos_token_id: 2,
42 | **kwargs
43 | )
44 | super(pad_token_id: pad_token_id, bos_token_id: bos_token_id, eos_token_id: eos_token_id, **kwargs)
45 |
46 | @vocab_size = vocab_size
47 | @hidden_size = hidden_size
48 | @num_hidden_layers = num_hidden_layers
49 | @num_attention_heads = num_attention_heads
50 | @hidden_act = hidden_act
51 | @intermediate_size = intermediate_size
52 | @hidden_dropout_prob = hidden_dropout_prob
53 | @attention_probs_dropout_prob = attention_probs_dropout_prob
54 | @max_position_embeddings = max_position_embeddings
55 | @initializer_range = initializer_range
56 | @layer_norm_eps = layer_norm_eps
57 | @relative_attention_num_buckets = relative_attention_num_buckets
58 | end
59 | end
60 | end
61 | end
62 |
--------------------------------------------------------------------------------
/lib/transformers/models/mpnet/tokenization_mpnet_fast.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The HuggingFace Inc. team, Microsoft Corporation.
2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | module Transformers
17 | module Mpnet
18 | class MPNetTokenizerFast < PreTrainedTokenizerFast
19 | VOCAB_FILES_NAMES = {vocab_file: "vocab.txt", tokenizer_file: "tokenizer.json"}
20 |
21 | self.vocab_files_names = VOCAB_FILES_NAMES
22 | # self.slow_tokenizer_class = MPNetTokenizer
23 | self.model_input_names = ["input_ids", "attention_mask"]
24 |
25 | def initialize(
26 | vocab_file: nil,
27 | tokenizer_file: nil,
28 | do_lower_case: true,
29 | bos_token: "<s>",
30 | eos_token: "</s>",
31 | sep_token: "</s>",
32 | cls_token: "<s>",
33 | unk_token: "[UNK]",
34 | pad_token: "<pad>",
35 | mask_token: "<mask>",
36 | tokenize_chinese_chars: true,
37 | strip_accents: nil,
38 | **kwargs
39 | )
40 | bos_token = bos_token.is_a?(String) ? Tokenizers::AddedToken.new(bos_token, lstrip: false, rstrip: false) : bos_token
41 | eos_token = eos_token.is_a?(String) ? Tokenizers::AddedToken.new(eos_token, lstrip: false, rstrip: false) : eos_token
42 | sep_token = sep_token.is_a?(String) ? Tokenizers::AddedToken.new(sep_token, lstrip: false, rstrip: false) : sep_token
43 | cls_token = cls_token.is_a?(String) ? Tokenizers::AddedToken.new(cls_token, lstrip: false, rstrip: false) : cls_token
44 | unk_token = unk_token.is_a?(String) ? Tokenizers::AddedToken.new(unk_token, lstrip: false, rstrip: false) : unk_token
45 | pad_token = pad_token.is_a?(String) ? Tokenizers::AddedToken.new(pad_token, lstrip: false, rstrip: false) : pad_token
46 |
47 | # Mask token behaves like a normal word, i.e. includes the space before it
48 | mask_token = mask_token.is_a?(String) ? Tokenizers::AddedToken.new(mask_token, lstrip: true, rstrip: false) : mask_token
49 |
50 | super(vocab_file, tokenizer_file: tokenizer_file, do_lower_case: do_lower_case, bos_token: bos_token, eos_token: eos_token, sep_token: sep_token, cls_token: cls_token, unk_token: unk_token, pad_token: pad_token, mask_token: mask_token, tokenize_chinese_chars: tokenize_chinese_chars, strip_accents: strip_accents, **kwargs)
51 |
52 | # TODO support
53 | # pre_tok_state = JSON.parse(backend_tokenizer.normalizer.__getstate__)
54 | # if (pre_tok_state["lowercase"] || do_lower_case) != do_lower_case || (pre_tok_state["strip_accents"] || strip_accents) != strip_accents
55 | # pre_tok_class = getattr(normalizers, pre_tok_state.delete("type"))
56 | # pre_tok_state["lowercase"] = do_lower_case
57 | # pre_tok_state["strip_accents"] = strip_accents
58 | # @normalizer = pre_tok_class(**pre_tok_state)
59 | # end
60 |
61 | @do_lower_case = do_lower_case
62 | end
63 |
64 | def mask_token
65 | if @mask_token.nil?
66 | if @verbose
67 | Transformers.logger.error("Using mask_token, but it is not set yet.")
68 | end
69 | return nil
70 | end
71 | @mask_token.to_s
72 | end
73 |
74 | def mask_token=(value)
75 | # Mask token behaves like a normal word, i.e. includes the space before it
76 | # So we set lstrip to true
77 | value = value.is_a?(String) ? Tokenizers::AddedToken.new(value, lstrip: true, rstrip: false) : value
78 | @mask_token = value
79 | end
80 |
81 | def build_inputs_with_special_tokens(token_ids_0, token_ids_1: nil)
82 | output = [@bos_token_id] + token_ids_0 + [@eos_token_id]
83 | if token_ids_1.nil?
84 | return output
85 | end
86 |
87 | output + [@eos_token_id] + token_ids_1 + [@eos_token_id]
88 | end
89 |
90 | def create_token_type_ids_from_sequences(token_ids_0, token_ids_1: nil)
91 | sep = [@sep_token_id]
92 | cls = [@cls_token_id]
93 |
94 | if token_ids_1.nil?
95 | return [0] * (cls + token_ids_0 + sep).length
96 | end
97 | [0] * (cls + token_ids_0 + sep + sep + token_ids_1 + sep).length
98 | end
99 |
100 | def save_vocabulary(save_directory, filename_prefix: nil)
101 | files = @tokenizer.model.save(save_directory, name: filename_prefix)
102 | Array(files)
103 | end
104 | end
105 | end
106 | end
107 |
--------------------------------------------------------------------------------
/lib/transformers/models/vit/configuration_vit.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google AI and The HuggingFace Inc. team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | module Transformers
16 | module Vit
17 | class ViTConfig < PretrainedConfig
18 | self.model_type = "vit"
19 |
20 | attr_reader :hidden_size, :num_hidden_layers, :num_attention_heads, :intermediate_size,
21 | :hidden_act, :hidden_dropout_prob, :attention_probs_dropout_prob, :initializer_range,
22 | :layer_norm_eps, :image_size, :patch_size, :num_channels, :qkv_bias, :encoder_stride
23 |
24 | def initialize(
25 | hidden_size: 768,
26 | num_hidden_layers: 12,
27 | num_attention_heads: 12,
28 | intermediate_size: 3072,
29 | hidden_act: "gelu",
30 | hidden_dropout_prob: 0.0,
31 | attention_probs_dropout_prob: 0.0,
32 | initializer_range: 0.02,
33 | layer_norm_eps: 1e-12,
34 | image_size: 224,
35 | patch_size: 16,
36 | num_channels: 3,
37 | qkv_bias: true,
38 | encoder_stride: 16,
39 | **kwargs
40 | )
41 | super(**kwargs)
42 |
43 | @hidden_size = hidden_size
44 | @num_hidden_layers = num_hidden_layers
45 | @num_attention_heads = num_attention_heads
46 | @intermediate_size = intermediate_size
47 | @hidden_act = hidden_act
48 | @hidden_dropout_prob = hidden_dropout_prob
49 | @attention_probs_dropout_prob = attention_probs_dropout_prob
50 | @initializer_range = initializer_range
51 | @layer_norm_eps = layer_norm_eps
52 | @image_size = image_size
53 | @patch_size = patch_size
54 | @num_channels = num_channels
55 | @qkv_bias = qkv_bias
56 | @encoder_stride = encoder_stride
57 | end
58 | end
59 | end
60 | end
61 |
--------------------------------------------------------------------------------
/lib/transformers/models/vit/image_processing_vit.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | module Transformers
16 | module Vit
17 | class ViTImageProcessor < BaseImageProcessor
18 | def initialize(
19 | do_resize: true,
20 | size: nil,
21 | resample: :bilinear,
22 | do_rescale: true,
23 | rescale_factor: 1 / 255.0,
24 | do_normalize: true,
25 | image_mean: nil,
26 | image_std: nil,
27 | **kwargs
28 | )
29 | super(**kwargs)
30 | size = !size.nil? ? size : {height: 224, width: 224}
31 | size = ImageProcessingUtils.get_size_dict(size)
32 | @do_resize = do_resize
33 | @do_rescale = do_rescale
34 | @do_normalize = do_normalize
35 | @size = size
36 | @resample = resample
37 | @rescale_factor = rescale_factor
38 | @image_mean = !image_mean.nil? ? image_mean : IMAGENET_STANDARD_MEAN
39 | @image_std = !image_std.nil? ? image_std : IMAGENET_STANDARD_STD
40 | @valid_processor_keys = [
41 | :images,
42 | :do_resize,
43 | :size,
44 | :resample,
45 | :do_rescale,
46 | :rescale_factor,
47 | :do_normalize,
48 | :image_mean,
49 | :image_std,
50 | :return_tensors,
51 | :data_format,
52 | :input_data_format
53 | ]
54 | end
55 |
56 | def resize(
57 | image,
58 | size,
59 | resample: :bilinear,
60 | data_format: nil,
61 | input_data_format: nil,
62 | **kwargs
63 | )
64 | size = ImageProcessingUtils.get_size_dict(size)
65 | if !size.include?(:height) || !size.include?(:width)
66 | raise ArgumentError, "The `size` dictionary must contain the keys `height` and `width`. Got #{size.keys}"
67 | end
68 | output_size = [size[:height], size[:width]]
69 | ImageTransforms.resize(
70 | image,
71 | output_size,
72 | resample: resample,
73 | data_format: data_format,
74 | input_data_format: input_data_format,
75 | **kwargs
76 | )
77 | end
78 |
79 | def preprocess(
80 | images,
81 | do_resize: nil,
82 | size: nil,
83 | resample: nil,
84 | do_rescale: nil,
85 | rescale_factor: nil,
86 | do_normalize: nil,
87 | image_mean: nil,
88 | image_std: nil,
89 | return_tensors: nil,
90 | data_format: ChannelDimension::FIRST,
91 | input_data_format: nil,
92 | **kwargs
93 | )
94 | do_resize = !do_resize.nil? ? do_resize : @do_resize
95 | do_rescale = !do_rescale.nil? ? do_rescale : @do_rescale
96 | do_normalize = !do_normalize.nil? ? do_normalize : @do_normalize
97 | resample = !resample.nil? ? resample : @resample
98 | rescale_factor = !rescale_factor.nil? ? rescale_factor : @rescale_factor
99 | image_mean = !image_mean.nil? ? image_mean : @image_mean
100 | image_std = !image_std.nil? ? image_std : @image_std
101 |
102 | size = !size.nil? ? size : @size
103 | size_dict = ImageProcessingUtils.get_size_dict(size)
104 |
105 | images = ImageUtils.make_list_of_images(images)
106 |
107 | ImageUtils.validate_kwargs(captured_kwargs: kwargs.keys, valid_processor_keys: @valid_processor_keys)
108 |
109 | if !ImageUtils.valid_images(images)
110 | raise ArgumentError,
111 | "Invalid image type. Must be of type Vips::Image, Numo::NArray, or Torch::Tensor."
112 | end
113 | ImageUtils.validate_preprocess_arguments(
114 | do_rescale: do_rescale,
115 | rescale_factor: rescale_factor,
116 | do_normalize: do_normalize,
117 | image_mean: image_mean,
118 | image_std: image_std,
119 | do_resize: do_resize,
120 | size: size,
121 | resample: resample
122 | )
123 |
124 | # All transformations expect numo arrays.
125 | images = images.map { |image| ImageUtils.to_numo_array(image) }
126 |
127 | if ImageUtils.is_scaled_image(images[0]) && do_rescale
128 | Transformers.logger.warn(
129 | "It looks like you are trying to rescale already rescaled images. If the input" +
130 | " images have pixel values between 0 and 1, set `do_rescale: false` to avoid rescaling them again."
131 | )
132 | end
133 |
134 | if input_data_format.nil?
135 | # We assume that all images have the same channel dimension format.
136 | input_data_format = ImageUtils.infer_channel_dimension_format(images[0])
137 | end
138 |
139 | if do_resize
140 | images =
141 | images.map do |image|
142 | resize(image, size_dict, resample: resample, input_data_format: input_data_format)
143 | end
144 | end
145 |
146 | if do_rescale
147 | images =
148 | images.map do |image|
149 | rescale(image, rescale_factor, input_data_format: input_data_format)
150 | end
151 | end
152 |
153 | if do_normalize
154 | images =
155 | images.map do |image|
156 | normalize(image, image_mean, image_std, input_data_format: input_data_format)
157 | end
158 | end
159 |
160 | images =
161 | images.map do |image|
162 | ImageTransforms.to_channel_dimension_format(image, data_format, input_channel_dim: input_data_format)
163 | end
164 |
165 | data = {pixel_values: images}
166 | BatchFeature.new(data: data, tensor_type: return_tensors)
167 | end
168 | end
169 | end
170 | end
171 |
--------------------------------------------------------------------------------
/lib/transformers/models/xlm_roberta/configuration_xlm_roberta.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | module Transformers
17 | module XlmRoberta
18 | class XLMRobertaConfig < PretrainedConfig
19 | self.model_type = "xlm-roberta"
20 |
21 | attr_reader :vocab_size, :hidden_size, :num_hidden_layers, :num_attention_heads,
22 | :intermediate_size, :hidden_act, :hidden_dropout_prob, :attention_probs_dropout_prob,
23 | :max_position_embeddings, :type_vocab_size, :initializer_range, :layer_norm_eps,
24 | :pad_token_id, :bos_token_id, :eos_token_id, :position_embedding_type, :use_cache,
25 | :classifier_dropout
26 |
27 | def initialize(
28 | vocab_size: 30522,
29 | hidden_size: 768,
30 | num_hidden_layers: 12,
31 | num_attention_heads: 12,
32 | intermediate_size: 3072,
33 | hidden_act: "gelu",
34 | hidden_dropout_prob: 0.1,
35 | attention_probs_dropout_prob: 0.1,
36 | max_position_embeddings: 512,
37 | type_vocab_size: 2,
38 | initializer_range: 0.02,
39 | layer_norm_eps: 1e-12,
40 | pad_token_id: 1,
41 | bos_token_id: 0,
42 | eos_token_id: 2,
43 | position_embedding_type: "absolute",
44 | use_cache: true,
45 | classifier_dropout: nil,
46 | **kwargs
47 | )
48 | super(pad_token_id: pad_token_id, bos_token_id: bos_token_id, eos_token_id: eos_token_id, **kwargs)
49 |
50 | @vocab_size = vocab_size
51 | @hidden_size = hidden_size
52 | @num_hidden_layers = num_hidden_layers
53 | @num_attention_heads = num_attention_heads
54 | @hidden_act = hidden_act
55 | @intermediate_size = intermediate_size
56 | @hidden_dropout_prob = hidden_dropout_prob
57 | @attention_probs_dropout_prob = attention_probs_dropout_prob
58 | @max_position_embeddings = max_position_embeddings
59 | @type_vocab_size = type_vocab_size
60 | @initializer_range = initializer_range
61 | @layer_norm_eps = layer_norm_eps
62 | @position_embedding_type = position_embedding_type
63 | @use_cache = use_cache
64 | @classifier_dropout = classifier_dropout
65 | end
66 | end
67 | end
68 | end
69 |
--------------------------------------------------------------------------------
/lib/transformers/models/xlm_roberta/tokenization_xlm_roberta_fast.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | module Transformers
16 | module XlmRoberta
17 | class XLMRobertaTokenizerFast < PreTrainedTokenizerFast
18 | VOCAB_FILES_NAMES = {vocab_file: "sentencepiece.bpe.model", tokenizer_file: "tokenizer.json"}
19 |
20 | self.vocab_files_names = VOCAB_FILES_NAMES
21 | self.model_input_names = ["input_ids", "attention_mask"]
22 | # self.slow_tokenizer_class = XLMRobertaTokenizer
23 |
24 | def initialize(
25 | vocab_file: nil,
26 | tokenizer_file: nil,
27 | bos_token: "<s>",
28 | eos_token: "</s>",
29 | sep_token: "</s>",
30 | cls_token: "<s>",
31 | unk_token: "<unk>",
32 | pad_token: "<pad>",
33 | mask_token: "<mask>",
34 | **kwargs
35 | )
36 | # Mask token behaves like a normal word, i.e. includes the space before it
37 | mask_token = mask_token.is_a?(String) ? Tokenizers::AddedToken.new(mask_token, lstrip: true, rstrip: false) : mask_token
38 |
39 | super(vocab_file, tokenizer_file: tokenizer_file, bos_token: bos_token, eos_token: eos_token, sep_token: sep_token, cls_token: cls_token, unk_token: unk_token, pad_token: pad_token, mask_token: mask_token, **kwargs)
40 |
41 | @vocab_file = vocab_file
42 | end
43 |
44 | def can_save_slow_tokenizer
45 | @vocab_file ? File.exist?(@vocab_file) : false
46 | end
47 |
48 | def build_inputs_with_special_tokens(token_ids_0, token_ids_1: nil)
49 | if token_ids_1.nil?
50 | return [@cls_token_id] + token_ids_0 + [@sep_token_id]
51 | end
52 | cls = [@cls_token_id]
53 | sep = [@sep_token_id]
54 | cls + token_ids_0 + sep + sep + token_ids_1 + sep
55 | end
56 |
57 | def create_token_type_ids_from_sequences(token_ids_0, token_ids_1: nil)
58 | sep = [@sep_token_id]
59 | cls = [@cls_token_id]
60 |
61 | if token_ids_1.nil?
62 | return [0] * (cls + token_ids_0 + sep).length
63 | end
64 | [0] * (cls + token_ids_0 + sep + sep + token_ids_1 + sep).length
65 | end
66 | end
67 | end
68 | end
69 |
--------------------------------------------------------------------------------
/lib/transformers/pipelines/embedding.rb:
--------------------------------------------------------------------------------
1 | module Transformers
2 | class EmbeddingPipeline < Pipeline
3 | def _sanitize_parameters(**kwargs)
4 | [{}, {}, kwargs]
5 | end
6 |
7 | def preprocess(inputs)
8 | @tokenizer.(inputs, return_tensors: @framework)
9 | end
10 |
11 | def _forward(model_inputs)
12 | {
13 | last_hidden_state: @model.(**model_inputs)[0],
14 | attention_mask: model_inputs[:attention_mask]
15 | }
16 | end
17 |
18 | def postprocess(model_outputs, pooling: "mean", normalize: true)
19 | output = model_outputs[:last_hidden_state]
20 |
21 | case pooling
22 | when "none"
23 | # do nothing
24 | when "mean"
25 | output = mean_pooling(output, model_outputs[:attention_mask])
26 | when "cls"
27 | output = output[0.., 0]
28 | else
29 | raise Error, "Pooling method '#{pooling}' not supported."
30 | end
31 |
32 | if normalize
33 | output = Torch::NN::Functional.normalize(output, p: 2, dim: 1)
34 | end
35 |
36 | output[0].to_a
37 | end
38 |
39 | private
40 |
41 | def mean_pooling(output, attention_mask)
42 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(output.size).float
43 | Torch.sum(output * input_mask_expanded, 1) / Torch.clamp(input_mask_expanded.sum(1), min: 1e-9)
44 | end
45 | end
46 | end
47 |
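A minimal usage sketch for the EmbeddingPipeline above. The checkpoint name and input text are illustrative only; the pooling: and normalize: options are forwarded to postprocess through _sanitize_parameters.

    embed = Transformers.pipeline("embedding", "sentence-transformers/all-MiniLM-L6-v2")
    embed.("The quick brown fox")                                     # mean pooling + L2 normalization (defaults)
    embed.("The quick brown fox", pooling: "cls", normalize: false)   # first-token (CLS) hidden state, unnormalized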
--------------------------------------------------------------------------------
/lib/transformers/pipelines/feature_extraction.rb:
--------------------------------------------------------------------------------
1 | module Transformers
2 | class FeatureExtractionPipeline < Pipeline
3 | def _sanitize_parameters(truncation: nil, tokenize_kwargs: nil, return_tensors: nil, **kwargs)
4 | if tokenize_kwargs.nil?
5 | tokenize_kwargs = {}
6 | end
7 |
8 | if !truncation.nil?
9 | if tokenize_kwargs.include?(:truncation)
10 | raise ArgumentError,
11 | "truncation parameter defined twice (given as keyword argument as well as in tokenize_kwargs)"
12 | end
13 | tokenize_kwargs[:truncation] = truncation
14 | end
15 |
16 | preprocess_params = tokenize_kwargs
17 |
18 | postprocess_params = {}
19 | if !return_tensors.nil?
20 | postprocess_params[:return_tensors] = return_tensors
21 | end
22 |
23 | [preprocess_params, {}, postprocess_params]
24 | end
25 |
26 | def preprocess(inputs, **tokenize_kwargs)
27 | model_inputs = @tokenizer.(inputs, return_tensors: @framework, **tokenize_kwargs)
28 | model_inputs
29 | end
30 |
31 | def _forward(model_inputs)
32 | model_outputs = @model.(**model_inputs)
33 | model_outputs
34 | end
35 |
36 | def postprocess(model_outputs, return_tensors: false)
37 | # [0] is the first available tensor, logits or last_hidden_state.
38 | if return_tensors
39 | model_outputs[0]
40 | elsif @framework == "pt"
41 | model_outputs[0].to_a
42 | elsif @framework == "tf"
43 | raise Todo
44 | end
45 | end
46 | end
47 | end
48 |
--------------------------------------------------------------------------------
/lib/transformers/pipelines/image_classification.rb:
--------------------------------------------------------------------------------
1 | module Transformers
2 | class ClassificationFunction < ExplicitEnum
3 | SIGMOID = "sigmoid"
4 | SOFTMAX = "softmax"
5 | NONE = "none"
6 | end
7 |
8 | class ImageClassificationPipeline < Pipeline
9 | extend ClassAttribute
10 |
11 | class_attribute :function_to_apply, ClassificationFunction::NONE
12 |
13 | def initialize(*args, **kwargs)
14 | super(*args, **kwargs)
15 | Utils.requires_backends(self, "vision")
16 | check_model_type(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES)
17 | end
18 |
19 | def _sanitize_parameters(top_k: nil, function_to_apply: nil, timeout: nil)
20 | preprocess_params = {}
21 | if !timeout.nil?
22 | preprocess_params[:timeout] = timeout
23 | end
24 | postprocess_params = {}
25 | if !top_k.nil?
26 | postprocess_params[:top_k] = top_k
27 | end
28 | if function_to_apply.is_a?(String)
29 | function_to_apply = ClassificationFunction.new(function_to_apply.downcase).to_s
30 | end
31 | if !function_to_apply.nil?
32 | postprocess_params[:function_to_apply] = function_to_apply
33 | end
34 | [preprocess_params, {}, postprocess_params]
35 | end
36 |
37 | def preprocess(image, timeout: nil)
38 | image = ImageUtils.load_image(image, timeout: timeout)
39 | model_inputs = @image_processor.(image, return_tensors: @framework)
40 | if @framework == "pt"
41 | # TODO
42 | # model_inputs = model_inputs.to(torch_dtype)
43 | end
44 | model_inputs
45 | end
46 |
47 | def _forward(model_inputs)
48 | model_outputs = @model.(**model_inputs)
49 | model_outputs
50 | end
51 |
52 | def postprocess(model_outputs, function_to_apply: nil, top_k: 5)
53 | if function_to_apply.nil?
54 | if @model.config.problem_type == "multi_label_classification" || @model.config.num_labels == 1
55 | function_to_apply = ClassificationFunction::SIGMOID
56 | elsif @model.config.problem_type == "single_label_classification" || @model.config.num_labels > 1
57 | function_to_apply = ClassificationFunction::SOFTMAX
58 | elsif @model.config.instance_variable_defined?(:@function_to_apply) && function_to_apply.nil?
59 | function_to_apply = @model.config.function_to_apply
60 | else
61 | function_to_apply = ClassificationFunction::NONE
62 | end
63 | end
64 |
65 | if top_k > @model.config.num_labels
66 | top_k = @model.config.num_labels
67 | end
68 |
69 | outputs = model_outputs[:logits][0]
70 | if @framework == "pt" && [Torch.bfloat16, Torch.float16].include?(outputs.dtype)
71 | outputs = outputs.to(Torch.float32).numo
72 | else
73 | outputs = outputs.numo
74 | end
75 |
76 | if function_to_apply == ClassificationFunction::SIGMOID
77 | scores = sigmoid(outputs)
78 | elsif function_to_apply == ClassificationFunction::SOFTMAX
79 | scores = softmax(outputs)
80 | elsif function_to_apply == ClassificationFunction::NONE
81 | scores = outputs
82 | else
83 | raise ArgumentError, "Unrecognized `function_to_apply` argument: #{function_to_apply}"
84 | end
85 |
86 | dict_scores =
87 | scores.to_a.map.with_index do |score, i|
88 | {label: @model.config.id2label[i], score: score}
89 | end
90 | dict_scores.sort_by! { |x| -x[:score] }
91 | if !top_k.nil?
92 | dict_scores = dict_scores[...top_k]
93 | end
94 |
95 | dict_scores
96 | end
97 |
98 | private
99 |
100 | def sigmoid(_outputs)
101 | 1.0 / (1.0 + Numo::NMath.exp(-_outputs))
102 | end
103 |
104 | def softmax(_outputs)
105 | maxes = _outputs.max(axis: -1, keepdims: true)
106 | shifted_exp = Numo::NMath.exp(_outputs - maxes)
107 | shifted_exp / shifted_exp.sum(axis: -1, keepdims: true)
108 | end
109 | end
110 | end
111 |
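A hedged usage sketch for the ImageClassificationPipeline above, assuming the task is registered as "image-classification" in pipelines/_init.rb; the checkpoint and image path are placeholders.

    classifier = Transformers.pipeline("image-classification", "google/vit-base-patch16-224")
    classifier.("path/to/image.jpg", top_k: 3)
    # => [{label: "...", score: 0.9...}, ...] sorted by descending score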
--------------------------------------------------------------------------------
/lib/transformers/pipelines/image_feature_extraction.rb:
--------------------------------------------------------------------------------
1 | module Transformers
2 | class ImageFeatureExtractionPipeline < Pipeline
3 | def _sanitize_parameters(image_processor_kwargs: nil, return_tensors: nil, pool: nil, **kwargs)
4 | preprocess_params = image_processor_kwargs.nil? ? {} : image_processor_kwargs
5 |
6 | postprocess_params = {}
7 | if !pool.nil?
8 | postprocess_params[:pool] = pool
9 | end
10 | if !return_tensors.nil?
11 | postprocess_params[:return_tensors] = return_tensors
12 | end
13 |
14 | if kwargs.include?(:timeout)
15 | preprocess_params[:timeout] = kwargs[:timeout]
16 | end
17 |
18 | [preprocess_params, {}, postprocess_params]
19 | end
20 |
21 | def preprocess(image, timeout: nil, **image_processor_kwargs)
22 | image = ImageUtils.load_image(image, timeout: timeout)
23 | model_inputs = @image_processor.(image, return_tensors: @framework, **image_processor_kwargs)
24 | if @framework == "pt"
25 | # TODO
26 | # model_inputs = model_inputs.to(torch_dtype)
27 | end
28 | model_inputs
29 | end
30 |
31 | def _forward(model_inputs)
32 | model_outputs = @model.(**model_inputs)
33 | model_outputs
34 | end
35 |
36 | def postprocess(model_outputs, pool: nil, return_tensors: false)
37 | pool = !pool.nil? ? pool : false
38 |
39 | if pool
40 | raise Todo
41 | else
42 | # [0] is the first available tensor, logits or last_hidden_state.
43 | outputs = model_outputs[0]
44 | end
45 |
46 | if return_tensors
47 | return outputs
48 | end
49 | if @framework == "pt"
50 | outputs.to_a
51 | else
52 | raise Todo
53 | end
54 | end
55 | end
56 | end
57 |
--------------------------------------------------------------------------------
/lib/transformers/pipelines/pt_utils.rb:
--------------------------------------------------------------------------------
1 | module Transformers
2 | class PipelineDataset < Torch::Utils::Data::Dataset
3 | def initialize(dataset, process, params)
4 | @dataset = dataset
5 | @process = process
6 | @params = params
7 | end
8 |
9 | def size
10 | @dataset.size
11 | end
12 |
13 | def [](i)
14 | item = @dataset[i]
15 | processed = @process.(item, **@params)
16 | processed
17 | end
18 | end
19 |
20 | class PipelineIterator < Torch::Utils::Data::IterableDataset
21 | def initialize(loader, infer, params, loader_batch_size: nil)
22 | @loader = loader
23 | @infer = infer
24 | @params = params
25 | if loader_batch_size == 1
26 | # Let's spare some time by deactivating batch handling altogether
27 | loader_batch_size = nil
28 | end
29 | @loader_batch_size = loader_batch_size
30 |
31 | # Internal bookkeeping
32 | @loader_batch_index = nil
33 | @loader_batch_data = nil
34 | end
35 |
36 | def size
37 | @loader.size
38 | end
39 |
40 | def [](i)
41 | @infer.(@loader[i], **@params)
42 | end
43 |
44 | def each
45 | @iterator = @loader
46 |
47 | @iterator.each do |item|
48 | processed = @infer.(item, **@params)
49 | yield processed
50 | end
51 | end
52 | end
53 | end
54 |
--------------------------------------------------------------------------------
/lib/transformers/pipelines/reranking.rb:
--------------------------------------------------------------------------------
1 | module Transformers
2 | class RerankingPipeline < Pipeline
3 | def _sanitize_parameters(**kwargs)
4 | [{}, {}, kwargs]
5 | end
6 |
7 | def preprocess(inputs)
8 | @tokenizer.(
9 | [inputs[:query]] * inputs[:documents].length,
10 | text_pair: inputs[:documents],
11 | return_tensors: @framework,
12 | padding: true
13 | )
14 | end
15 |
16 | def _forward(model_inputs)
17 | model_outputs = @model.(**model_inputs)
18 | model_outputs
19 | end
20 |
21 | def call(query, documents)
22 | super({query: query, documents: documents})
23 | end
24 |
25 | def postprocess(model_outputs)
26 | model_outputs[0]
27 | .sigmoid
28 | .squeeze
29 | .to_a
30 | .map.with_index { |s, i| {index: i, score: s} }
31 | .sort_by { |v| -v[:score] }
32 | end
33 | end
34 | end
35 |
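A usage sketch for the RerankingPipeline above, assuming the task is registered as "reranking"; the checkpoint is an example. Note that call wraps the query and documents into the hash shape expected by preprocess.

    ranker = Transformers.pipeline("reranking", "mixedbread-ai/mxbai-rerank-base-v1")
    ranker.("How many people live in London?", [
      "Around 9 million people live in London.",
      "London is known for its museums."
    ])
    # => [{index: 0, score: 0.99...}, {index: 1, score: 0.01...}] sorted by descending score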
--------------------------------------------------------------------------------
/lib/transformers/pipelines/text_classification.rb:
--------------------------------------------------------------------------------
1 | module Transformers
2 | class TextClassificationPipeline < Pipeline
3 | def initialize(*args, **kwargs)
4 | super
5 |
6 | check_model_type(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES)
7 | end
8 |
9 | private
10 |
11 | def _sanitize_parameters(return_all_scores: nil, function_to_apply: nil, top_k: "", **tokenizer_kwargs)
12 | # Using "" as the default argument because we're going to use `top_k: nil` in user code to declare
13 | # "No top_k"
14 | preprocess_params = tokenizer_kwargs
15 |
16 | postprocess_params = {}
17 | if @model.config.respond_to?(:return_all_scores) && return_all_scores.nil?
18 | return_all_scores = @model.config.return_all_scores
19 | end
20 |
21 | if top_k.is_a?(Integer) || top_k.nil?
22 | postprocess_params[:top_k] = top_k
23 | postprocess_params[:_legacy] = false
24 | elsif !return_all_scores.nil?
25 | warn(
26 | "`return_all_scores` is now deprecated; if you want similar functionality, use `top_k: nil` instead of" +
27 | " `return_all_scores: true` or `top_k: 1` instead of `return_all_scores: false`.",
28 | )
29 | if return_all_scores
30 | postprocess_params[:top_k] = nil
31 | else
32 | postprocess_params[:top_k] = 1
33 | end
34 | end
35 |
36 | if function_to_apply.is_a?(String)
37 | function_to_apply = ClassificationFunction.new(function_to_apply.downcase).to_s
38 | end
39 |
40 | if !function_to_apply.nil?
41 | postprocess_params[:function_to_apply] = function_to_apply
42 | end
43 | [preprocess_params, {}, postprocess_params]
44 | end
45 |
46 | def preprocess(inputs, **tokenizer_kwargs)
47 | return_tensors = @framework
48 | if inputs.is_a?(Hash)
49 | return @tokenizer.(**inputs, return_tensors: return_tensors, **tokenizer_kwargs)
50 | elsif inputs.is_a?(Array) && inputs.length == 1 && inputs[0].is_a?(Array) && inputs[0].length == 2
51 | # It used to be valid to use a list of list of list for text pairs, keeping this path for BC
52 | return @tokenizer.(
53 | inputs[0][0], text_pair: inputs[0][1], return_tensors: return_tensors, **tokenizer_kwargs
54 | )
55 | elsif inputs.is_a?(Array)
56 | # This is likely an invalid usage of the pipeline attempting to pass text pairs.
57 | raise ArgumentError,
58 | "The pipeline received invalid inputs, if you are trying to send text pairs, you can try to send a" +
59 | ' dictionary `{"text": "My text", "text_pair": "My pair"}` in order to send a text pair.'
60 | end
61 | @tokenizer.(inputs, return_tensors: return_tensors, **tokenizer_kwargs)
62 | end
63 |
64 | def _forward(model_inputs)
65 | @model.(**model_inputs)
66 | end
67 |
68 | def postprocess(model_outputs, function_to_apply: nil, top_k: 1, _legacy: true)
69 | if function_to_apply.nil?
70 | if @model.config.problem_type == "multi_label_classification" || @model.config.num_labels == 1
71 | function_to_apply = ClassificationFunction::SIGMOID
72 | elsif @model.config.problem_type == "single_label_classification" || @model.config.num_labels > 1
73 | function_to_apply = ClassificationFunction::SOFTMAX
74 | elsif @model.config.instance_variable_defined?(:@function_to_apply) && function_to_apply.nil?
75 | function_to_apply = @model.config.function_to_apply
76 | else
77 | function_to_apply = ClassificationFunction::NONE
78 | end
79 | end
80 |
81 | outputs = model_outputs["logits"][0]
82 | outputs = outputs.numo
83 |
84 | if function_to_apply == ClassificationFunction::SIGMOID
85 | scores = sigmoid(outputs)
86 | elsif function_to_apply == ClassificationFunction::SOFTMAX
87 | scores = softmax(outputs)
88 | elsif function_to_apply == ClassificationFunction::NONE
89 | scores = outputs
90 | else
91 | raise ArgumentError, "Unrecognized `function_to_apply` argument: #{function_to_apply}"
92 | end
93 |
94 | if top_k == 1 && _legacy
95 | return {label: @model.config.id2label[scores.argmax], score: scores.max}
96 | end
97 |
98 | dict_scores =
99 | scores.to_a.map.with_index do |score, i|
100 | {label: @model.config.id2label[i], score: score}
101 | end
102 | if !_legacy
103 | dict_scores.sort_by! { |x| -x[:score] }
104 | if !top_k.nil?
105 | dict_scores = dict_scores.first(top_k)
106 | end
107 | end
108 | dict_scores
109 | end
110 |
111 | private
112 |
113 | def sigmoid(_outputs)
114 | 1.0 / (1.0 + Numo::NMath.exp(-_outputs))
115 | end
116 |
117 | def softmax(_outputs)
118 | maxes = _outputs.max(axis: -1, keepdims: true)
119 | shifted_exp = Numo::NMath.exp(_outputs - maxes)
120 | shifted_exp / shifted_exp.sum(axis: -1, keepdims: true)
121 | end
122 | end
123 | end
124 |
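A usage sketch for the TextClassificationPipeline above; the task name follows the usual Hugging Face convention and the checkpoint is illustrative. Passing top_k: nil opts out of the legacy single-result shape and returns every label with its score.

    classifier = Transformers.pipeline("text-classification", "distilbert-base-uncased-finetuned-sst-2-english")
    classifier.("We are very happy to show you this library.")             # => {label: "POSITIVE", score: 0.99...}
    classifier.("We are very happy to show you this library.", top_k: nil) # all labels, sorted by score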
--------------------------------------------------------------------------------
/lib/transformers/ruby_utils.rb:
--------------------------------------------------------------------------------
1 | module Transformers
2 | module ClassAttribute
3 | def class_attribute(name, default = nil)
4 | singleton_class.attr_writer name
5 | var = "@#{name}"
6 | instance_variable_set(var, default)
7 | singleton_class.define_method(name) do
8 | # ancestors includes current module
9 | ancestors.find { |c| c.instance_variable_defined?(var) }.instance_variable_get(var)
10 | end
11 | define_method(name) do
12 | self.class.send(name)
13 | end
14 | end
15 | end
16 |
17 | module Copy
18 | def self.deepcopy(value, memo = {})
19 | key = value.object_id
20 | if !memo.key?(key)
21 | copy = value.dup
22 | memo[key] = copy
23 | if value.is_a?(Hash)
24 | copy.transform_keys! { |k| deepcopy(k, memo) }
25 | copy.transform_values! { |v| deepcopy(v, memo) }
26 | elsif value.is_a?(Array)
27 | copy.map! { |v| deepcopy(v, memo) }
28 | end
29 | end
30 | memo[key]
31 | end
32 | end
33 | end
34 |
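The class_attribute helper above provides an inheritable class-level value that subclasses can override without affecting the parent, plus an instance reader that delegates to the class. A small illustration with hypothetical class names:

    class BasePipeline
      extend Transformers::ClassAttribute
      class_attribute :model_input_names, ["input_ids"]
    end

    class PairPipeline < BasePipeline
      self.model_input_names = ["input_ids", "attention_mask"]
    end

    BasePipeline.model_input_names      # => ["input_ids"]
    PairPipeline.model_input_names      # => ["input_ids", "attention_mask"]
    PairPipeline.new.model_input_names  # instance reader delegates to the class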
--------------------------------------------------------------------------------
/lib/transformers/sentence_transformer.rb:
--------------------------------------------------------------------------------
1 | module Transformers
2 | # TODO remove in 0.2.0
3 | class SentenceTransformer
4 | def initialize(model_id)
5 | @model_id = model_id
6 | @model = Transformers.pipeline("embedding", model_id)
7 | end
8 |
9 | def encode(sentences)
10 | # TODO check modules.json
11 | if [
12 | "sentence-transformers/all-MiniLM-L6-v2",
13 | "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
14 | ].include?(@model_id)
15 | @model.(sentences)
16 | else
17 | @model.(sentences, pooling: "cls", normalize: false)
18 | end
19 | end
20 | end
21 | end
22 |
--------------------------------------------------------------------------------
/lib/transformers/tokenization_utils.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The HuggingFace Inc. team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | module Transformers
16 | class PreTrainedTokenizer < PreTrainedTokenizerBase
17 | def initialize(**kwargs)
18 | # 2. init `@added_tokens_decoder` if the child class did not
19 | if !instance_variable_defined?(:@added_tokens_decoder)
20 | @added_tokens_decoder = {}
21 | end
22 |
23 | # 3. if an `added_tokens_decoder` is passed, we are loading from a saved tokenizer, so we overwrite
24 | @added_tokens_decoder.merge!(kwargs.delete(:added_tokens_decoder) { {} })
25 | @added_tokens_encoder = @added_tokens_decoder.to_h { |k, v| [k.content, v] }
26 |
27 | # 4. init the parent class
28 | super(**kwargs)
29 | end
30 |
31 | def is_fast
32 | false
33 | end
34 |
35 | def vocab_size
36 | raise NotImplementedError
37 | end
38 |
39 | def tokenize(text, **kwargs)
40 | raise Todo
41 | end
42 |
43 | def _encode_plus(
44 | text:,
45 | text_pair: nil,
46 | add_special_tokens: true,
47 | padding_strategy: PaddingStrategy::DO_NOT_PAD,
48 | truncation_strategy: TruncationStrategy::DO_NOT_TRUNCATE,
49 | max_length: nil,
50 | stride: 0,
51 | is_split_into_words: false,
52 | pad_to_multiple_of: nil,
53 | return_tensors: nil,
54 | return_token_type_ids: nil,
55 | return_attention_mask: nil,
56 | return_overflowing_tokens: false,
57 | return_special_tokens_mask: false,
58 | return_offsets_mapping: false,
59 | return_length: false,
60 | verbose: true,
61 | **kwargs
62 | )
63 | get_input_ids = lambda do |text|
64 | if text.is_a?(String)
65 | tokens = tokenize(text, **kwargs)
66 | convert_tokens_to_ids(tokens)
67 | elsif text.is_a?(Array) && text.length > 0 && text[0].is_a?(String)
68 | if is_split_into_words
69 | raise Todo
70 | else
71 | convert_tokens_to_ids(text)
72 | end
73 | elsif text.is_a?(Array) && text.length > 0 && text[0].is_a?(Integer)
74 | text
75 | else
76 | if is_split_into_words
77 | raise ArgumentError,
78 | "Input #{text} is not valid. Should be a string or a list/tuple of strings when" +
79 | " `is_split_into_words: true`."
80 | else
81 | raise ArgumentError,
82 | "Input #{text} is not valid. Should be a string, a list/tuple of strings or a list/tuple of" +
83 | " integers."
84 | end
85 | end
86 | end
87 |
88 | if return_offsets_mapping
89 | raise RuntimeError,
90 | "return_offsets_mapping is not available when using Ruby tokenizers. " +
91 | "To use this feature, change your tokenizer to one deriving from " +
92 | "Transformers::PreTrainedTokenizerFast. " +
93 | "More information on available tokenizers at " +
94 | "https://github.com/huggingface/transformers/pull/2674"
95 | end
96 |
97 | first_ids = get_input_ids.(text)
98 | second_ids = !text_pair.nil? ? get_input_ids.(text_pair) : nil
99 |
100 | prepare_for_model(
101 | first_ids,
102 | pair_ids: second_ids,
103 | add_special_tokens: add_special_tokens,
104 | padding: padding_strategy,
105 | truncation: truncation_strategy,
106 | max_length: max_length,
107 | stride: stride,
108 | pad_to_multiple_of: pad_to_multiple_of,
109 | return_tensors: return_tensors,
110 | prepend_batch_axis: true,
111 | return_attention_mask: return_attention_mask,
112 | return_token_type_ids: return_token_type_ids,
113 | return_overflowing_tokens: return_overflowing_tokens,
114 | return_special_tokens_mask: return_special_tokens_mask,
115 | return_length: return_length,
116 | verbose: verbose
117 | )
118 | end
119 |
120 | def convert_tokens_to_ids(tokens)
121 | if tokens.nil?
122 | return nil
123 | end
124 |
125 | if tokens.is_a?(String)
126 | return _convert_token_to_id_with_added_voc(tokens)
127 | end
128 |
129 | ids = []
130 | tokens.each do |token|
131 | ids << _convert_token_to_id_with_added_voc(token)
132 | end
133 | ids
134 | end
135 |
136 | def _convert_token_to_id_with_added_voc(token)
137 | if token.nil?
138 | return nil
139 | end
140 |
141 | if @added_tokens_encoder.include?(token)
142 | return @added_tokens_encoder[token]
143 | end
144 | _convert_token_to_id(token)
145 | end
146 |
147 | def _convert_token_to_id(token)
148 | raise NotImplementedError
149 | end
150 | end
151 | end
152 |
--------------------------------------------------------------------------------
/lib/transformers/torch_utils.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | module Transformers
16 | module TorchUtils
17 | def self.apply_chunking_to_forward(forward_fn, chunk_size, chunk_dim, *input_tensors)
18 | if chunk_size > 0
19 | raise Todo
20 | end
21 |
22 | forward_fn.(*input_tensors)
23 | end
24 | end
25 | end
26 |
--------------------------------------------------------------------------------
/lib/transformers/utils/_init.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | module Transformers
16 | WEIGHTS_NAME = "pytorch_model.bin"
17 | WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
18 | TF2_WEIGHTS_NAME = "tf_model.h5"
19 | TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json"
20 | TF_WEIGHTS_NAME = "model.ckpt"
21 | FLAX_WEIGHTS_NAME = "flax_model.msgpack"
22 | FLAX_WEIGHTS_INDEX_NAME = "flax_model.msgpack.index.json"
23 | SAFE_WEIGHTS_NAME = "model.safetensors"
24 | SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
25 | CONFIG_NAME = "config.json"
26 | FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
27 | IMAGE_PROCESSOR_NAME = FEATURE_EXTRACTOR_NAME
28 | PROCESSOR_NAME = "processor_config.json"
29 | GENERATION_CONFIG_NAME = "generation_config.json"
30 | MODEL_CARD_NAME = "modelcard.json"
31 | end
32 |
--------------------------------------------------------------------------------
/lib/transformers/utils/generic.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | module Transformers
16 | class ModelOutput
17 | def self.attributes
18 | @attributes ||= []
19 | end
20 |
21 | def self.attribute(attribute)
22 | attributes << attribute.to_sym
23 |
24 | define_method(attribute) do
25 | self[attribute]
26 | end
27 | end
28 |
29 | def initialize(**kwargs)
30 | @data = kwargs
31 | end
32 |
33 | def [](k)
34 | if k.is_a?(String) || k.is_a?(Symbol)
35 | @data[k.to_sym]
36 | else
37 | to_tuple[k]
38 | end
39 | end
40 |
41 | def to_tuple
42 | self.class.attributes.map { |k| @data[k] }.compact
43 | end
44 | end
45 |
46 | class ExplicitEnum
47 | def initialize(value)
48 | expected = self.class.constants.map { |k| self.class.const_get(k) }
49 | unless expected.include?(value)
50 | raise ArgumentError, "#{value} is not a valid #{self.class.name}, please select one of #{expected.inspect}"
51 | end
52 | @value = value
53 | end
54 |
55 | def to_s
56 | @value
57 | end
58 | end
59 |
60 | class PaddingStrategy < ExplicitEnum
61 | LONGEST = "longest"
62 | MAX_LENGTH = "max_length"
63 | DO_NOT_PAD = "do_not_pad"
64 | end
65 |
66 | class TensorType < ExplicitEnum
67 | PYTORCH = "pt"
68 | TENSORFLOW = "tf"
69 | NUMPY = "np"
70 | JAX = "jax"
71 | MLX = "mlx"
72 | end
73 |
74 | module Utils
75 | def self.infer_framework(model_class)
76 | if model_class < Torch::NN::Module
77 | "pt"
78 | else
79 | raise TypeError, "Could not infer framework from class #{model_class}."
80 | end
81 | end
82 |
83 | def self._is_numo(x)
84 | x.is_a?(Numo::NArray)
85 | end
86 |
87 | def self.is_numo_array(x)
88 | _is_numo(x)
89 | end
90 |
91 | def self._is_torch(x)
92 | x.is_a?(Torch::Tensor)
93 | end
94 |
95 | def self.is_torch_tensor(x)
96 | _is_torch(x)
97 | end
98 |
99 | def self._is_torch_device(x)
100 | x.is_a?(Torch::Device)
101 | end
102 |
103 | def self.is_torch_device(x)
104 | _is_torch_device(x)
105 | end
106 | end
107 | end
108 |
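ModelOutput above supports attribute readers, string or symbol keys, and positional indexing (nil attributes are skipped by to_tuple). A short sketch with a hypothetical subclass:

    class ExampleOutput < Transformers::ModelOutput
      attribute :logits
      attribute :hidden_states
    end

    out = ExampleOutput.new(logits: [[0.1, 0.9]])
    out.logits     # => [[0.1, 0.9]]
    out["logits"]  # string and symbol keys resolve to the same entry
    out[0]         # => [[0.1, 0.9]] -- first non-nil attribute via to_tuple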
--------------------------------------------------------------------------------
/lib/transformers/utils/hub.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | module Transformers
16 | module Utils
17 | module Hub
18 | IS_OFFLINE_MODE = HfHub::HF_HUB_OFFLINE
19 |
20 | PYTORCH_PRETRAINED_BERT_CACHE = ENV.fetch("PYTORCH_PRETRAINED_BERT_CACHE", HfHub::HF_HUB_CACHE)
21 | PYTORCH_TRANSFORMERS_CACHE = ENV.fetch("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
22 | TRANSFORMERS_CACHE = ENV.fetch("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
23 |
24 | DEFAULT_ENDPOINT = "https://huggingface.co"
25 | HUGGINGFACE_CO_RESOLVE_ENDPOINT = ENV.fetch("HF_ENDPOINT", DEFAULT_ENDPOINT)
26 |
27 | class << self
28 | def is_offline_mode
29 | IS_OFFLINE_MODE
30 | end
31 |
32 | def is_remote_url(url_or_filename)
33 | url_or_filename.is_a?(URI)
34 | end
35 |
36 | def http_user_agent(user_agent = nil)
37 | ua = "transformers.rb/#{Transformers::VERSION}; ruby/#{RUBY_VERSION.to_f}"
38 | if user_agent.is_a?(Hash)
39 | ua += "; " + user_agent.map { |k, v| "#{k}/#{v}" }.join("; ")
40 | elsif user_agent.is_a?(String)
41 | ua += "; " + user_agent
42 | end
43 | ua
44 | end
45 |
46 | def extract_commit_hash(resolved_file, commit_hash)
47 | if resolved_file.nil? || !commit_hash.nil?
48 | return commit_hash
49 | end
50 | search = /snapshots\/([^\/]+)/.match(resolved_file)
51 | if search.nil?
52 | return nil
53 | end
54 | commit_hash = search[1]
55 | HfHub::REGEX_COMMIT_HASH.match(commit_hash) ? commit_hash : nil
56 | end
57 |
58 | def cached_file(
59 | path_or_repo_id,
60 | filename,
61 | cache_dir: nil,
62 | force_download: false,
63 | resume_download: false,
64 | proxies: nil,
65 | token: nil,
66 | revision: nil,
67 | local_files_only: false,
68 | subfolder: "",
69 | repo_type: nil,
70 | user_agent: nil,
71 | _raise_exceptions_for_gated_repo: true,
72 | _raise_exceptions_for_missing_entries: true,
73 | _raise_exceptions_for_connection_errors: true,
74 | _commit_hash: nil,
75 | **deprecated_kwargs
76 | )
77 | if is_offline_mode && !local_files_only
78 | Transformers.logger.info "Offline mode: forcing local_files_only: true"
79 | local_files_only = true
80 | end
81 | if subfolder.nil?
82 | subfolder = ""
83 | end
84 |
85 | path_or_repo_id = path_or_repo_id.to_s
86 | full_filename = File.join(subfolder, filename)
87 | if Dir.exist?(path_or_repo_id)
88 | raise Todo
89 | end
90 |
91 | if cache_dir.nil?
92 | cache_dir = TRANSFORMERS_CACHE
93 | end
94 | if cache_dir.is_a?(Pathname)
95 | cache_dir = cache_dir.to_s
96 | end
97 |
98 | if !_commit_hash.nil? && !force_download
99 | # If the file is cached under that commit hash, we return it directly.
100 | resolved_file =
101 | HfHub.try_to_load_from_cache(
102 | path_or_repo_id, full_filename, cache_dir: cache_dir, revision: _commit_hash, repo_type: repo_type
103 | )
104 | if !resolved_file.nil?
105 | if resolved_file != HfHub::CACHED_NO_EXIST
106 | return resolved_file
107 | elsif !_raise_exceptions_for_missing_entries
108 | return nil
109 | else
110 | raise EnvironmentError, "Could not locate #{full_filename} inside #{path_or_repo_id}."
111 | end
112 | end
113 | end
114 |
115 | user_agent = http_user_agent(user_agent)
116 |
117 | resolved_file = nil
118 | begin
119 | resolved_file =
120 | HfHub.hf_hub_download(
121 | path_or_repo_id,
122 | filename,
123 | subfolder: subfolder.length == 0 ? nil : subfolder,
124 | repo_type: repo_type,
125 | revision: revision,
126 | cache_dir: cache_dir,
127 | user_agent: user_agent,
128 | force_download: force_download,
129 | proxies: proxies,
130 | resume_download: resume_download,
131 | token: token,
132 | local_files_only: local_files_only
133 | )
134 | rescue => e
135 | raise e if _raise_exceptions_for_missing_entries
136 | end
137 | resolved_file
138 | end
139 |
140 | def has_file(
141 | path_or_repo,
142 | filename,
143 | revision: nil,
144 | proxies: nil,
145 | token: nil,
146 | local_files_only: false,
147 | cache_dir: nil,
148 | repo_type: nil,
149 | **deprecated_kwargs
150 | )
151 | # If path to local directory, check if the file exists
152 | if Dir.exist?(path_or_repo)
153 | return File.exist?(File.join(path_or_repo, filename))
154 | end
155 |
156 | # Else it's a repo => let's check if the file exists in local cache or on the Hub
157 |
158 | # Check if file exists in cache
159 | # This information might be outdated so it's best to also make a HEAD call (if allowed).
160 | cached_path = HfHub.try_to_load_from_cache(
161 | path_or_repo,
162 | filename,
163 | revision: revision,
164 | repo_type: repo_type,
165 | cache_dir: cache_dir
166 | )
167 | has_file_in_cache = cached_path.is_a?(String)
168 |
169 | # If local_files_only, don't try the HEAD call
170 | if local_files_only
171 | return has_file_in_cache
172 | end
173 |
174 | # Check if the file exists
175 | begin
176 | HfHub._request_wrapper(
177 | "HEAD",
178 | HfHub.hf_hub_url(path_or_repo, filename, revision: revision, repo_type: repo_type),
179 | headers: HfHub.build_hf_headers(token: token, user_agent: http_user_agent),
180 | allow_redirects: false,
181 | proxies: proxies,
182 | timeout: 10
183 | )
184 | true
185 | rescue HfHub::OfflineModeIsEnabled
186 | has_file_in_cache
187 | rescue HfHub::GatedRepoError => e
188 | Transformers.logger.error(e)
189 | raise EnvironmentError,
190 | "#{path_or_repo} is a gated repository. Make sure to request access at " +
191 | "https://huggingface.co/#{path_or_repo} and pass a token having permission to this repo either by " +
192 | "logging in with `huggingface-cli login` or by passing `token=`."
193 | rescue HfHub::RepositoryNotFoundError => e
194 | Transformers.logger.error(e)
195 | raise EnvironmentError,
196 | "#{path_or_repo} is not a local folder or a valid repository name on 'https://hf.co'."
197 | rescue HfHub::RevisionNotFoundError => e
198 | Transformers.logger.error(e)
199 | raise EnvironmentError,
200 | "#{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this " +
201 | "model name. Check the model page at 'https://huggingface.co/#{path_or_repo}' for available revisions."
202 | rescue HfHub::EntryNotFoundError
203 | false # File does not exist
204 | end
205 | end
206 | end
207 | end
208 | end
209 | end
210 |
--------------------------------------------------------------------------------
/lib/transformers/utils/import_utils.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | module Transformers
16 | module Utils
17 | ENV_VARS_TRUE_VALUES = ["1", "ON", "YES", "TRUE"]
18 |
19 | def self.requires_backends(obj, backends)
20 | if !backends.is_a?(Array)
21 | backends = [backends]
22 | end
23 |
24 | name = obj.is_a?(Symbol) ? obj : obj.class.name
25 |
26 | checks = backends.map { |backend| BACKENDS_MAPPING.fetch(backend) }
27 | failed = checks.filter_map { |available, msg| format(msg, name) if !available.call }
28 | if failed.any?
29 | raise Error, failed.join("")
30 | end
31 | end
32 |
33 | def self.is_vision_available
34 | defined?(Vips)
35 | end
36 |
37 | VISION_IMPORT_ERROR = <<~MSG
38 | %s requires the `ruby-vips` gem
39 | MSG
40 |
41 | BACKENDS_MAPPING = {
42 | "vision" => [singleton_method(:is_vision_available), VISION_IMPORT_ERROR]
43 | }
44 | end
45 | end
46 |
--------------------------------------------------------------------------------
/lib/transformers/utils/logging.rb:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Optuna, Hugging Face
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | module Transformers
16 | class << self
17 | attr_accessor :logger
18 | end
19 |
20 | # TODO add detail
21 | LOG_LEVELS = {
22 | "debug" => Logger::DEBUG,
23 | "info" => Logger::INFO,
24 | "warning" => Logger::WARN,
25 | "error" => Logger::ERROR,
26 | "critical" => Logger::FATAL
27 | }
28 |
29 | DEFAULT_LOG_LEVEL = Logger::WARN
30 |
31 | def self._get_default_logging_level
32 | env_level_str = ENV["TRANSFORMERS_VERBOSITY"]
33 | if env_level_str
34 | if LOG_LEVELS.include?(env_level_str)
35 | return LOG_LEVELS[env_level_str]
36 | else
37 | warn(
38 | "Unknown option TRANSFORMERS_VERBOSITY=#{env_level_str}, " +
39 | "has to be one of: #{LOG_LEVELS.keys.join(", ")}"
40 | )
41 | end
42 | end
43 | DEFAULT_LOG_LEVEL
44 | end
45 |
46 | self.logger = begin
47 | logger = Logger.new(STDERR)
48 | logger.level = _get_default_logging_level
49 | logger.formatter = proc { |severity, datetime, progname, msg| "#{msg}\n" }
50 | logger
51 | end
52 | end
53 |
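Two ways to adjust verbosity follow directly from the code above; a brief sketch (the script name is illustrative):

    # Option 1: set the environment variable before the gem is loaded
    # (accepted values: debug, info, warning, error, critical)
    #   TRANSFORMERS_VERBOSITY=debug ruby my_script.rb

    # Option 2: change the logger level at runtime
    require "transformers-rb"
    Transformers.logger.level = Logger::ERROR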
--------------------------------------------------------------------------------
/lib/transformers/version.rb:
--------------------------------------------------------------------------------
1 | module Transformers
2 | VERSION = "0.1.6"
3 | end
4 |
--------------------------------------------------------------------------------
/licenses/LICENSE-huggingface-hub.txt:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/licenses/LICENSE-sentence-transformers.txt:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "{}"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2019 Nils Reimers
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/licenses/NOTICE-sentence-transformers.txt:
--------------------------------------------------------------------------------
1 | -------------------------------------------------------------------------------
2 | Copyright 2019
3 | Ubiquitous Knowledge Processing (UKP) Lab
4 | Technische Universität Darmstadt
5 | -------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/test/model_test.rb:
--------------------------------------------------------------------------------
1 | require_relative "test_helper"
2 |
3 | class ModelTest < Minitest::Test
4 | def setup
5 | skip if ci?
6 | end
7 |
8 | # https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
9 | def test_all_mini_lm
10 | sentences = ["This is an example sentence", "Each sentence is converted"]
11 |
12 | model = Transformers.pipeline("embedding", "sentence-transformers/all-MiniLM-L6-v2")
13 | embeddings = model.(sentences)
14 |
15 | assert_elements_in_delta [0.067657, 0.063496, 0.048713], embeddings[0][..2]
16 | assert_elements_in_delta [0.086439, 0.10276, 0.0053946], embeddings[1][..2]
17 | end
18 |
19 | # https://huggingface.co/sentence-transformers/multi-qa-MiniLM-L6-cos-v1
20 | def test_multi_qa_minilm
21 | query = "How many people live in London?"
22 | docs = ["Around 9 Million people live in London", "London is known for its financial district"]
23 |
24 | model = Transformers.pipeline("embedding", "sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
25 | query_embedding = model.(query)
26 | doc_embeddings = model.(docs)
27 | scores = doc_embeddings.map { |e| e.zip(query_embedding).sum { |d, q| d * q } }
28 | doc_score_pairs = docs.zip(scores).sort_by { |d, s| -s }
29 |
30 | assert_equal "Around 9 Million people live in London", doc_score_pairs[0][0]
31 | assert_in_delta 0.9156, doc_score_pairs[0][1]
32 | assert_equal "London is known for its financial district", doc_score_pairs[1][0]
33 | assert_in_delta 0.4948, doc_score_pairs[1][1]
34 | end
35 |
36 | # https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L6-v2
37 | def test_paraphrase_minilm
38 | sentences = ["This is an example sentence", "Each sentence is converted"]
39 |
40 | model = Transformers.pipeline("embedding", "sentence-transformers/paraphrase-MiniLM-L6-v2")
41 | embeddings = model.(sentences, normalize: false)
42 |
43 | assert_elements_in_delta [0.067359, 0.783935, 0.270018], embeddings[0][..2]
44 | assert_elements_in_delta [0.122117, 0.670228, 0.317166], embeddings[1][..2]
45 | end
46 |
47 | # https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1
48 | def test_mxbai_embed
49 | query_prefix = "Represent this sentence for searching relevant passages: "
50 |
51 | input = [
52 | query_prefix + "puppy",
53 | "The dog is barking",
54 | "The cat is purring"
55 | ]
56 |
57 | model = Transformers.pipeline("embedding", "mixedbread-ai/mxbai-embed-large-v1")
58 | embeddings = model.(input, pooling: "cls", normalize: false)
59 |
60 | assert_elements_in_delta [-0.00624076, 0.12864432, 0.5248165], embeddings[0][..2]
61 | assert_elements_in_delta [-0.61227727, 1.4060247, -0.04079155], embeddings[-1][..2]
62 | end
63 |
64 | # https://huggingface.co/thenlper/gte-small
65 | def test_gte_small
66 | sentences = ["That is a happy person", "That is a very happy person"]
67 |
68 | model = Transformers.pipeline("embedding", "thenlper/gte-small")
69 | embeddings = model.(sentences)
70 |
71 | assert_elements_in_delta [-0.05316979, 0.01044252, 0.06194701], embeddings[0][..2]
72 | assert_elements_in_delta [-0.05246907, 0.03752426, 0.07344585], embeddings[-1][..2]
73 | end
74 |
75 | # https://huggingface.co/intfloat/e5-base-v2
76 | def test_e5_base
77 | doc_prefix = "passage: "
78 | query_prefix = "query: "
79 |
80 | input = [
81 | doc_prefix + "Ruby is a programming language created by Matz",
82 | query_prefix + "Ruby creator"
83 | ]
84 |
85 | model = Transformers.pipeline("embedding", "intfloat/e5-base-v2")
86 | embeddings = model.(input)
87 |
88 | assert_elements_in_delta [-0.00596662, -0.03730119, -0.0703470], embeddings[0][..2]
89 | assert_elements_in_delta [0.00298353, -0.04421991, -0.0591884], embeddings[-1][..2]
90 | end
91 |
92 | # https://huggingface.co/BAAI/bge-base-en-v1.5
93 | def test_bge_base
94 | query_prefix = "Represent this sentence for searching relevant passages: "
95 |
96 | input = [
97 | "The dog is barking",
98 | "The cat is purring",
99 | query_prefix + "puppy"
100 | ]
101 |
102 | model = Transformers.pipeline("embedding", "BAAI/bge-base-en-v1.5")
103 | embeddings = model.(input)
104 |
105 | assert_elements_in_delta [-0.07482512, -0.0770234, 0.03398684], embeddings[1][..2]
106 | assert_elements_in_delta [0.00029264, -0.0619305, -0.06199387], embeddings[-1][..2]
107 | end
108 |
109 | # https://huggingface.co/Snowflake/snowflake-arctic-embed-m-v1.5
110 | def test_snowflake_arctic_embed
111 | query_prefix = "Represent this sentence for searching relevant passages: "
112 |
113 | input = [
114 | "The dog is barking",
115 | "The cat is purring",
116 | query_prefix + "puppy"
117 | ]
118 |
119 | model = Transformers.pipeline("embedding", "Snowflake/snowflake-arctic-embed-m-v1.5")
120 | embeddings = model.(input, pooling: "cls")
121 |
122 | assert_elements_in_delta [0.03239886, 0.0009998, 0.08401278], embeddings[0][..2]
123 | assert_elements_in_delta [-0.02530634, -0.02715422, 0.01218867], embeddings[-1][..2]
124 | end
125 |
126 | # https://huggingface.co/sentence-transformers/all-mpnet-base-v2
127 | def test_all_mpnet
128 | sentences = ["This is an example sentence", "Each sentence is converted"]
129 |
130 | model = Transformers.pipeline("embedding", "sentence-transformers/all-mpnet-base-v2")
131 | embeddings = model.(sentences)
132 |
133 | assert_elements_in_delta [0.02250263, -0.07829167, -0.02303071], embeddings[0][..2]
134 | assert_elements_in_delta [0.04170236, 0.00109747, -0.01553415], embeddings[1][..2]
135 | end
136 |
137 | # https://huggingface.co/opensearch-project/opensearch-neural-sparse-encoding-v1
138 | def test_opensearch
139 | docs = ["The dog is barking", "The cat is purring", "The bear is growling"]
140 |
141 | model_id = "opensearch-project/opensearch-neural-sparse-encoding-v1"
142 | model = Transformers::AutoModelForMaskedLM.from_pretrained(model_id)
143 | tokenizer = Transformers::AutoTokenizer.from_pretrained(model_id)
144 | special_token_ids = tokenizer.special_tokens_map.map { |_, token| tokenizer.vocab[token] }
145 |
146 | feature = tokenizer.(docs, padding: true, truncation: true, return_tensors: "pt", return_token_type_ids: false)
147 | output = model.(**feature)[0]
148 |
149 |     values, _ = Torch.max(output * feature[:attention_mask].unsqueeze(-1), dim: 1) # max-pool logits over the sequence dimension, masking padded positions
150 |     values = Torch.log(1 + Torch.relu(values)) # log(1 + relu) maps logits to non-negative sparse term weights
151 |     values[0.., special_token_ids] = 0 # zero out weights for special tokens across all documents
152 | embeddings = values.to_a
153 |
154 | assert_equal 74, embeddings[0].count { |v| v != 0 }
155 | assert_equal 77, embeddings[1].count { |v| v != 0 }
156 | assert_equal 102, embeddings[2].count { |v| v != 0 }
157 | end
158 |
159 | # https://huggingface.co/mixedbread-ai/mxbai-rerank-base-v1
160 | def test_mxbai_rerank
161 | query = "How many people live in London?"
162 | docs = ["Around 9 Million people live in London", "London is known for its financial district"]
163 |
164 | model = Transformers.pipeline("reranking", "mixedbread-ai/mxbai-rerank-base-v1")
165 | result = model.(query, docs)
166 |
167 | assert_equal 0, result[0][:index]
168 | assert_in_delta 0.984, result[0][:score]
169 |
170 | assert_equal 1, result[1][:index]
171 | assert_in_delta 0.139, result[1][:score]
172 | end
173 |
174 | # https://huggingface.co/BAAI/bge-reranker-base
175 | def test_bge_reranker
176 | query = "How many people live in London?"
177 | docs = ["Around 9 Million people live in London", "London is known for its financial district"]
178 |
179 | model = Transformers.pipeline("reranking", "BAAI/bge-reranker-base")
180 | result = model.(query, docs)
181 |
182 | assert_equal 0, result[0][:index]
183 | assert_in_delta 0.996, result[0][:score]
184 |
185 | assert_equal 1, result[1][:index]
186 | assert_in_delta 0.000158, result[1][:score], 0.000001
187 | end
188 | end
189 |
--------------------------------------------------------------------------------
/test/pipeline_test.rb:
--------------------------------------------------------------------------------
1 | require_relative "test_helper"
2 |
3 | class PipelineTest < Minitest::Test
4 | def test_ner
5 | ner = Transformers.pipeline("ner")
6 | result = ner.("Ruby is a programming language created by Matz")
7 | assert_equal 3, result.size
8 | assert_equal "I-MISC", result[0][:entity]
9 | assert_in_delta 0.96, result[0][:score]
10 | assert_equal 1, result[0][:index]
11 | assert_equal "Ruby", result[0][:word]
12 | assert_equal 0, result[0][:start]
13 | assert_equal 4, result[0][:end]
14 | end
15 |
16 | def test_ner_aggregation_strategy
17 | ner = Transformers.pipeline("ner", aggregation_strategy: "simple")
18 | result = ner.("Ruby is a programming language created by Matz")
19 | assert_equal 2, result.size
20 |
21 | assert_equal "MISC", result[0][:entity_group]
22 | assert_in_delta 0.9608, result[0][:score]
23 | assert_equal "Ruby", result[0][:word]
24 | assert_equal 0, result[0][:start]
25 | assert_equal 4, result[0][:end]
26 |
27 | assert_equal "PER", result[1][:entity_group]
28 | assert_in_delta 0.9496, result[1][:score]
29 | assert_equal "Matz", result[1][:word]
30 | assert_equal 42, result[1][:start]
31 | assert_equal 46, result[1][:end]
32 | end
33 |
34 | def test_sentiment_analysis
35 | classifier = Transformers.pipeline("sentiment-analysis")
36 | result = classifier.("We are very happy to show you the 🤗 Transformers library.")
37 | assert_equal "POSITIVE", result[:label]
38 | assert_in_delta 0.9998, result[:score]
39 |
40 | result = classifier.(["We are very happy to show you the 🤗 Transformers library.", "We hope you don't hate it."])
41 | assert_equal "POSITIVE", result[0][:label]
42 | assert_in_delta 0.9998, result[0][:score]
43 | assert_equal "NEGATIVE", result[1][:label]
44 | assert_in_delta 0.5309, result[1][:score]
45 | end
46 |
47 | def test_question_answering
48 | qa = Transformers.pipeline("question-answering")
49 | result = qa.(question: "Who invented Ruby?", context: "Ruby is a programming language created by Matz")
50 | assert_in_delta 0.998, result[:score]
51 | assert_equal 42, result[:start]
52 | assert_equal 46, result[:end]
53 | assert_equal "Matz", result[:answer]
54 |
55 | result = qa.("Who invented Ruby?", "Ruby is a programming language created by Matz")
56 | assert_equal "Matz", result[:answer]
57 | end
58 |
59 | def test_feature_extraction
60 | fe = Transformers.pipeline("feature-extraction")
61 | result = fe.("We are very happy to show you the 🤗 Transformers library.")
62 | assert_in_delta 0.454, result[0][0][0]
63 | end
64 |
65 | def test_embedding
66 | sentences = ["This is an example sentence", "Each sentence is converted"]
67 | embed = Transformers.pipeline("embedding")
68 | embeddings = embed.(sentences)
69 | assert_elements_in_delta [0.067657, 0.063496, 0.048713], embeddings[0][..2]
70 | assert_elements_in_delta [0.086439, 0.10276, 0.0053946], embeddings[1][..2]
71 | end
72 |
73 | def test_reranking
74 | query = "How many people live in London?"
75 | docs = ["Around 9 Million people live in London", "London is known for its financial district"]
76 | rerank = Transformers.pipeline("reranking")
77 | result = rerank.(query, docs)
78 | assert_equal 2, result.size
79 | assert_equal 0, result[0][:index]
80 | assert_in_delta 0.984, result[0][:score]
81 | assert_equal 1, result[1][:index]
82 | assert_in_delta 0.139, result[1][:score]
83 | end
84 |
85 | def test_image_classification
86 | classifier = Transformers.pipeline("image-classification")
87 | result = classifier.("test/support/pipeline-cat-chonk.jpeg")
88 | assert_equal "lynx, catamount", result[0][:label]
89 | assert_in_delta 0.433, result[0][:score], 0.01
90 | assert_equal "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor", result[1][:label]
91 | assert_in_delta 0.035, result[1][:score], 0.01
92 | end
93 |
94 | def test_image_feature_extraction
95 | fe = Transformers.pipeline("image-feature-extraction")
96 | result = fe.("test/support/pipeline-cat-chonk.jpeg")
97 | assert_in_delta 0.868, result[0][0][0], 0.01
98 | end
99 |
100 | def test_device
101 | skip unless mac?
102 |
103 | sentences = ["This is an example sentence", "Each sentence is converted"]
104 | embed = Transformers.pipeline("embedding", device: "mps")
105 | embeddings = embed.(sentences)
106 | assert_elements_in_delta [0.067657, 0.063496, 0.048713], embeddings[0][..2]
107 | assert_elements_in_delta [0.086439, 0.10276, 0.0053946], embeddings[1][..2]
108 | end
109 |
110 | def test_pipeline_input_works_with_more_than_ten
111 | embedding = Transformers.pipeline("embedding")
112 | 11.times do
113 | result = embedding.("Ruby is a programming language created by Matz")
114 | assert_instance_of(Array, result)
115 | end
116 | end
117 | end
118 |
--------------------------------------------------------------------------------
/test/test_helper.rb:
--------------------------------------------------------------------------------
1 | require "bundler/setup"
2 | Bundler.require(:default)
3 | require "minitest/autorun"
4 |
5 | unless ENV["TRANSFORMERS_VERBOSITY"]
6 | Transformers.logger.level = Logger::ERROR
7 | end
8 |
9 | Transformers.fast_init = true
10 |
11 | class Minitest::Test
12 | def assert_elements_in_delta(expected, actual)
13 | assert_equal expected.size, actual.size
14 | expected.zip(actual) do |exp, act|
15 | assert_in_delta exp, act
16 | end
17 | end
18 |
19 | def ci?
20 | ENV["CI"]
21 | end
22 |
23 | def mac?
24 | RbConfig::CONFIG["host_os"] =~ /darwin/i
25 | end
26 | end
27 |
--------------------------------------------------------------------------------
/test/tokenizer_test.rb:
--------------------------------------------------------------------------------
1 | require_relative "test_helper"
2 |
3 | class TokenizerTest < Minitest::Test
4 | def test_auto_tokenizer
5 | model_name = "nlptown/bert-base-multilingual-uncased-sentiment"
6 | tokenizer = Transformers::AutoTokenizer.from_pretrained(model_name)
7 |
8 | encoding = tokenizer.("We are very happy to show you the 🤗 Transformers library.")
9 | assert_equal [101, 11312, 10320, 12495, 19308, 10114, 11391, 10855, 10103, 100, 58263, 13299, 119, 102], encoding[:input_ids]
10 | assert_equal [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], encoding[:attention_mask]
11 | end
12 | end
13 |
--------------------------------------------------------------------------------
/transformers-rb.gemspec:
--------------------------------------------------------------------------------
1 | require_relative "lib/transformers/version"
2 |
3 | Gem::Specification.new do |spec|
4 | spec.name = "transformers-rb"
5 | spec.version = Transformers::VERSION
6 | spec.summary = "State-of-the-art transformers for Ruby"
7 | spec.homepage = "https://github.com/ankane/transformers-ruby"
8 | spec.license = "Apache-2.0"
9 |
10 | spec.author = "Andrew Kane"
11 | spec.email = "andrew@ankane.org"
12 |
13 | spec.files = Dir["*.{md,txt}", "{lib,licenses}/**/*"]
14 | spec.require_path = "lib"
15 |
16 | spec.required_ruby_version = ">= 3.1"
17 |
18 | spec.add_dependency "logger"
19 | spec.add_dependency "numo-narray", ">= 0.9.2"
20 | spec.add_dependency "safetensors", ">= 0.1.1"
21 | spec.add_dependency "tokenizers", ">= 0.5.3"
22 | spec.add_dependency "torch-rb", ">= 0.17.1"
23 | end
24 |
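For reference, a sketch of consuming the gem from an application Gemfile; the source line is the standard Bundler default and the version constraint is illustrative:

    # Gemfile
    source "https://rubygems.org"
    gem "transformers-rb", ">= 0.1.6"

    # then run: bundle install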
--------------------------------------------------------------------------------