├── .github └── workflows │ └── build.yml ├── .gitignore ├── CHANGELOG.md ├── Gemfile ├── LICENSE.txt ├── README.md ├── Rakefile ├── lib ├── transformers-rb.rb ├── transformers.rb └── transformers │ ├── activations.rb │ ├── configuration_utils.rb │ ├── convert_slow_tokenizer.rb │ ├── data │ └── processors │ │ └── squad.rb │ ├── dynamic_module_utils.rb │ ├── feature_extraction_utils.rb │ ├── hf_hub │ ├── constants.rb │ ├── errors.rb │ ├── file_download.rb │ └── utils │ │ ├── _errors.rb │ │ └── _headers.rb │ ├── image_processing_base.rb │ ├── image_processing_utils.rb │ ├── image_transforms.rb │ ├── image_utils.rb │ ├── modeling_outputs.rb │ ├── modeling_utils.rb │ ├── models │ ├── auto │ │ ├── auto_factory.rb │ │ ├── configuration_auto.rb │ │ ├── feature_extraction_auto.rb │ │ ├── image_processing_auto.rb │ │ ├── modeling_auto.rb │ │ └── tokenization_auto.rb │ ├── bert │ │ ├── configuration_bert.rb │ │ ├── modeling_bert.rb │ │ ├── tokenization_bert.rb │ │ └── tokenization_bert_fast.rb │ ├── deberta_v2 │ │ ├── configuration_deberta_v2.rb │ │ ├── modeling_deberta_v2.rb │ │ └── tokenization_deberta_v2_fast.rb │ ├── distilbert │ │ ├── configuration_distilbert.rb │ │ ├── modeling_distilbert.rb │ │ ├── tokenization_distilbert.rb │ │ └── tokenization_distilbert_fast.rb │ ├── mpnet │ │ ├── configuration_mpnet.rb │ │ ├── modeling_mpnet.rb │ │ └── tokenization_mpnet_fast.rb │ ├── vit │ │ ├── configuration_vit.rb │ │ ├── image_processing_vit.rb │ │ └── modeling_vit.rb │ └── xlm_roberta │ │ ├── configuration_xlm_roberta.rb │ │ ├── modeling_xlm_roberta.rb │ │ └── tokenization_xlm_roberta_fast.rb │ ├── pipelines │ ├── _init.rb │ ├── base.rb │ ├── embedding.rb │ ├── feature_extraction.rb │ ├── image_classification.rb │ ├── image_feature_extraction.rb │ ├── pt_utils.rb │ ├── question_answering.rb │ ├── reranking.rb │ ├── text_classification.rb │ └── token_classification.rb │ ├── ruby_utils.rb │ ├── sentence_transformer.rb │ ├── tokenization_utils.rb │ ├── tokenization_utils_base.rb │ ├── tokenization_utils_fast.rb │ ├── torch_utils.rb │ ├── utils │ ├── _init.rb │ ├── generic.rb │ ├── hub.rb │ ├── import_utils.rb │ └── logging.rb │ └── version.rb ├── licenses ├── LICENSE-huggingface-hub.txt ├── LICENSE-sentence-transformers.txt └── NOTICE-sentence-transformers.txt ├── test ├── model_test.rb ├── pipeline_test.rb ├── test_helper.rb └── tokenizer_test.rb └── transformers-rb.gemspec /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | on: [push, pull_request] 3 | jobs: 4 | build: 5 | runs-on: ubuntu-latest 6 | env: 7 | BUNDLE_BUILD__TORCH___RB: "--with-torch-dir=/home/runner/libtorch" 8 | LIBTORCH_VERSION: 2.5.0 9 | steps: 10 | - uses: actions/checkout@v4 11 | - uses: actions/cache@v4 12 | with: 13 | path: ~/libtorch 14 | key: libtorch-${{ env.LIBTORCH_VERSION }} 15 | id: cache-libtorch 16 | - name: Download LibTorch 17 | if: steps.cache-libtorch.outputs.cache-hit != 'true' 18 | run: | 19 | cd ~ 20 | wget -q -O libtorch.zip https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-$LIBTORCH_VERSION%2Bcpu.zip 21 | unzip -q libtorch.zip 22 | - uses: ruby/setup-ruby@v1 23 | with: 24 | ruby-version: 3.4 25 | bundler-cache: true 26 | - uses: actions/cache@v4 27 | with: 28 | path: ~/.cache/huggingface 29 | key: huggingface 30 | - run: sudo apt-get update && sudo apt-get install libvips 31 | - run: bundle exec rake download:files 32 | - run: bundle exec rake test 33 | 
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /.bundle/ 2 | /.yardoc 3 | /_yardoc/ 4 | /coverage/ 5 | /doc/ 6 | /pkg/ 7 | /spec/reports/ 8 | /test/support/ 9 | /tmp/ 10 | *.lock 11 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | ## 0.1.6 (2024-12-29) 2 | 3 | - Fixed error with failed HTTP requests 4 | - Fixed warning with Ruby 3.4 5 | 6 | ## 0.1.5 (2024-11-01) 7 | 8 | - Fixed error with pipelines when called more than 10 times 9 | - Fixed `device` option for pipelines 10 | - Fixed error with `reranking` pipeline 11 | 12 | ## 0.1.4 (2024-10-22) 13 | 14 | - Added `BertForSequenceClassification` 15 | 16 | ## 0.1.3 (2024-09-17) 17 | 18 | - Added `reranking` pipeline 19 | - Added DeBERTa-v2 20 | - Added MPNet 21 | - Added XLM-RoBERTa 22 | 23 | ## 0.1.2 (2024-09-10) 24 | 25 | - Fixed default revision for pipelines 26 | 27 | ## 0.1.1 (2024-08-29) 28 | 29 | - Added `embedding` pipeline 30 | - Added experimental `fast_init` option 31 | - Improved performance of loading models 32 | - Fixed error with `aggregation_strategy` option 33 | 34 | ## 0.1.0 (2024-08-19) 35 | 36 | - First release 37 | -------------------------------------------------------------------------------- /Gemfile: -------------------------------------------------------------------------------- 1 | source "https://rubygems.org" 2 | 3 | gemspec 4 | 5 | gem "rake" 6 | gem "minitest" 7 | gem "ruby-vips" 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Transformers.rb 2 | 3 | :slightly_smiling_face: State-of-the-art [transformers](https://github.com/huggingface/transformers) for Ruby 4 | 5 | For fast inference, check out [Informers](https://github.com/ankane/informers) :fire: 6 | 7 | [![Build Status](https://github.com/ankane/transformers-ruby/actions/workflows/build.yml/badge.svg)](https://github.com/ankane/transformers-ruby/actions) 8 | 9 | ## Installation 10 | 11 | First, [install Torch.rb](https://github.com/ankane/torch.rb#installation). 
12 | 13 | Then add this line to your application’s Gemfile: 14 | 15 | ```ruby 16 | gem "transformers-rb" 17 | ``` 18 | 19 | ## Getting Started 20 | 21 | - [Models](#models) 22 | - [Pipelines](#pipelines) 23 | 24 | ## Models 25 | 26 | Embedding 27 | 28 | - [sentence-transformers/all-MiniLM-L6-v2](#sentence-transformersall-MiniLM-L6-v2) 29 | - [sentence-transformers/multi-qa-MiniLM-L6-cos-v1](#sentence-transformersmulti-qa-MiniLM-L6-cos-v1) 30 | - [sentence-transformers/all-mpnet-base-v2](#sentence-transformersall-mpnet-base-v2) 31 | - [sentence-transformers/paraphrase-MiniLM-L6-v2](#sentence-transformersparaphrase-minilm-l6-v2) 32 | - [mixedbread-ai/mxbai-embed-large-v1](#mixedbread-aimxbai-embed-large-v1) 33 | - [thenlper/gte-small](#thenlpergte-small) 34 | - [intfloat/e5-base-v2](#intfloate5-base-v2) 35 | - [BAAI/bge-base-en-v1.5](#baaibge-base-en-v15) 36 | - [Snowflake/snowflake-arctic-embed-m-v1.5](#snowflakesnowflake-arctic-embed-m-v15) 37 | 38 | Sparse embedding 39 | 40 | - [opensearch-project/opensearch-neural-sparse-encoding-v1](#opensearch-projectopensearch-neural-sparse-encoding-v1) 41 | 42 | Reranking 43 | 44 | - [mixedbread-ai/mxbai-rerank-base-v1](#mixedbread-aimxbai-rerank-base-v1) 45 | - [BAAI/bge-reranker-base](#baaibge-reranker-base) 46 | 47 | ### sentence-transformers/all-MiniLM-L6-v2 48 | 49 | [Docs](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) 50 | 51 | ```ruby 52 | sentences = ["This is an example sentence", "Each sentence is converted"] 53 | 54 | model = Transformers.pipeline("embedding", "sentence-transformers/all-MiniLM-L6-v2") 55 | embeddings = model.(sentences) 56 | ``` 57 | 58 | ### sentence-transformers/multi-qa-MiniLM-L6-cos-v1 59 | 60 | [Docs](https://huggingface.co/sentence-transformers/multi-qa-MiniLM-L6-cos-v1) 61 | 62 | ```ruby 63 | query = "How many people live in London?" 
64 | docs = ["Around 9 Million people live in London", "London is known for its financial district"] 65 | 66 | model = Transformers.pipeline("embedding", "sentence-transformers/multi-qa-MiniLM-L6-cos-v1") 67 | query_embedding = model.(query) 68 | doc_embeddings = model.(docs) 69 | scores = doc_embeddings.map { |e| e.zip(query_embedding).sum { |d, q| d * q } } 70 | doc_score_pairs = docs.zip(scores).sort_by { |d, s| -s } 71 | ``` 72 | 73 | ### sentence-transformers/all-mpnet-base-v2 74 | 75 | [Docs](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) 76 | 77 | ```ruby 78 | sentences = ["This is an example sentence", "Each sentence is converted"] 79 | 80 | model = Transformers.pipeline("embedding", "sentence-transformers/all-mpnet-base-v2") 81 | embeddings = model.(sentences) 82 | ``` 83 | 84 | ### sentence-transformers/paraphrase-MiniLM-L6-v2 85 | 86 | [Docs](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L6-v2) 87 | 88 | ```ruby 89 | sentences = ["This is an example sentence", "Each sentence is converted"] 90 | 91 | model = Transformers.pipeline("embedding", "sentence-transformers/paraphrase-MiniLM-L6-v2") 92 | embeddings = model.(sentences) 93 | ``` 94 | 95 | ### mixedbread-ai/mxbai-embed-large-v1 96 | 97 | [Docs](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1) 98 | 99 | ```ruby 100 | query_prefix = "Represent this sentence for searching relevant passages: " 101 | 102 | input = [ 103 | "The dog is barking", 104 | "The cat is purring", 105 | query_prefix + "puppy" 106 | ] 107 | 108 | model = Transformers.pipeline("embedding", "mixedbread-ai/mxbai-embed-large-v1") 109 | embeddings = model.(input) 110 | ``` 111 | 112 | ### thenlper/gte-small 113 | 114 | [Docs](https://huggingface.co/thenlper/gte-small) 115 | 116 | ```ruby 117 | sentences = ["That is a happy person", "That is a very happy person"] 118 | 119 | model = Transformers.pipeline("embedding", "thenlper/gte-small") 120 | embeddings = model.(sentences) 121 | ``` 122 | 123 | ### intfloat/e5-base-v2 124 | 125 | [Docs](https://huggingface.co/intfloat/e5-base-v2) 126 | 127 | ```ruby 128 | doc_prefix = "passage: " 129 | query_prefix = "query: " 130 | 131 | input = [ 132 | doc_prefix + "Ruby is a programming language created by Matz", 133 | query_prefix + "Ruby creator" 134 | ] 135 | 136 | model = Transformers.pipeline("embedding", "intfloat/e5-base-v2") 137 | embeddings = model.(input) 138 | ``` 139 | 140 | ### BAAI/bge-base-en-v1.5 141 | 142 | [Docs](https://huggingface.co/BAAI/bge-base-en-v1.5) 143 | 144 | ```ruby 145 | query_prefix = "Represent this sentence for searching relevant passages: " 146 | 147 | input = [ 148 | "The dog is barking", 149 | "The cat is purring", 150 | query_prefix + "puppy" 151 | ] 152 | 153 | model = Transformers.pipeline("embedding", "BAAI/bge-base-en-v1.5") 154 | embeddings = model.(input) 155 | ``` 156 | 157 | ### Snowflake/snowflake-arctic-embed-m-v1.5 158 | 159 | [Docs](https://huggingface.co/Snowflake/snowflake-arctic-embed-m-v1.5) 160 | 161 | ```ruby 162 | query_prefix = "Represent this sentence for searching relevant passages: " 163 | 164 | input = [ 165 | "The dog is barking", 166 | "The cat is purring", 167 | query_prefix + "puppy" 168 | ] 169 | 170 | model = Transformers.pipeline("embedding", "Snowflake/snowflake-arctic-embed-m-v1.5") 171 | embeddings = model.(input, pooling: "cls") 172 | ``` 173 | 174 | ### opensearch-project/opensearch-neural-sparse-encoding-v1 175 | 176 | [Docs](https://huggingface.co/opensearch-project/opensearch-neural-sparse-encoding-v1) 177 
| 178 | ```ruby 179 | docs = ["The dog is barking", "The cat is purring", "The bear is growling"] 180 | 181 | model_id = "opensearch-project/opensearch-neural-sparse-encoding-v1" 182 | model = Transformers::AutoModelForMaskedLM.from_pretrained(model_id) 183 | tokenizer = Transformers::AutoTokenizer.from_pretrained(model_id) 184 | special_token_ids = tokenizer.special_tokens_map.map { |_, token| tokenizer.vocab[token] } 185 | 186 | feature = tokenizer.(docs, padding: true, truncation: true, return_tensors: "pt", return_token_type_ids: false) 187 | output = model.(**feature)[0] 188 | 189 | values, _ = Torch.max(output * feature[:attention_mask].unsqueeze(-1), dim: 1) 190 | values = Torch.log(1 + Torch.relu(values)) 191 | values[0.., special_token_ids] = 0 192 | embeddings = values.to_a 193 | ``` 194 | 195 | ### mixedbread-ai/mxbai-rerank-base-v1 196 | 197 | [Docs](https://huggingface.co/mixedbread-ai/mxbai-rerank-base-v1) 198 | 199 | ```ruby 200 | query = "How many people live in London?" 201 | docs = ["Around 9 Million people live in London", "London is known for its financial district"] 202 | 203 | model = Transformers.pipeline("reranking", "mixedbread-ai/mxbai-rerank-base-v1") 204 | result = model.(query, docs) 205 | ``` 206 | 207 | ### BAAI/bge-reranker-base 208 | 209 | [Docs](https://huggingface.co/BAAI/bge-reranker-base) 210 | 211 | ```ruby 212 | query = "How many people live in London?" 213 | docs = ["Around 9 Million people live in London", "London is known for its financial district"] 214 | 215 | model = Transformers.pipeline("reranking", "BAAI/bge-reranker-base") 216 | result = model.(query, docs) 217 | ``` 218 | 219 | ## Pipelines 220 | 221 | - [Text](#text) 222 | - [Vision](#vision) 223 | 224 | ### Text 225 | 226 | Embedding 227 | 228 | ```ruby 229 | embed = Transformers.pipeline("embedding") 230 | embed.("We are very happy to show you the 🤗 Transformers library.") 231 | ``` 232 | 233 | Reranking 234 | 235 | ```ruby 236 | rerank = Transformers.pipeline("reranking") 237 | rerank.("Who created Ruby?", ["Matz created Ruby", "Another doc"]) 238 | ``` 239 | 240 | Named-entity recognition 241 | 242 | ```ruby 243 | ner = Transformers.pipeline("ner") 244 | ner.("Ruby is a programming language created by Matz") 245 | ``` 246 | 247 | Sentiment analysis 248 | 249 | ```ruby 250 | classifier = Transformers.pipeline("sentiment-analysis") 251 | classifier.("We are very happy to show you the 🤗 Transformers library.") 252 | ``` 253 | 254 | Question answering 255 | 256 | ```ruby 257 | qa = Transformers.pipeline("question-answering") 258 | qa.(question: "Who invented Ruby?", context: "Ruby is a programming language created by Matz") 259 | ``` 260 | 261 | Feature extraction 262 | 263 | ```ruby 264 | extractor = Transformers.pipeline("feature-extraction") 265 | extractor.("We are very happy to show you the 🤗 Transformers library.") 266 | ``` 267 | 268 | ### Vision 269 | 270 | Image classification 271 | 272 | ```ruby 273 | classifier = Transformers.pipeline("image-classification") 274 | classifier.("image.jpg") 275 | ``` 276 | 277 | Image feature extraction 278 | 279 | ```ruby 280 | extractor = Transformers.pipeline("image-feature-extraction") 281 | extractor.("image.jpg") 282 | ``` 283 | 284 | ## API 285 | 286 | This library follows the [Transformers Python API](https://huggingface.co/docs/transformers/index). 
The following model architectures are currently supported: 287 | 288 | - BERT 289 | - DeBERTa-v2 290 | - DistilBERT 291 | - MPNet 292 | - ViT 293 | - XLM-RoBERTa 294 | 295 | ## History 296 | 297 | View the [changelog](https://github.com/ankane/transformers-ruby/blob/master/CHANGELOG.md) 298 | 299 | ## Contributing 300 | 301 | Everyone is encouraged to help improve this project. Here are a few ways you can help: 302 | 303 | - [Report bugs](https://github.com/ankane/transformers-ruby/issues) 304 | - Fix bugs and [submit pull requests](https://github.com/ankane/transformers-ruby/pulls) 305 | - Write, clarify, or fix documentation 306 | - Suggest or add new features 307 | 308 | To get started with development: 309 | 310 | ```sh 311 | git clone https://github.com/ankane/transformers-ruby.git 312 | cd transformers-ruby 313 | bundle install 314 | bundle exec rake download:files 315 | bundle exec rake test 316 | ``` 317 | -------------------------------------------------------------------------------- /Rakefile: -------------------------------------------------------------------------------- 1 | require "bundler/gem_tasks" 2 | require "rake/testtask" 3 | 4 | task default: :test 5 | Rake::TestTask.new do |t| 6 | t.libs << "test" 7 | t.pattern = FileList["test/**/*_test.rb"].exclude("test/model_test.rb") 8 | end 9 | 10 | def download_file(url) 11 | require "open-uri" 12 | 13 | file = File.basename(url) 14 | puts "Downloading #{file}..." 15 | dest = "test/support/#{file}" 16 | File.binwrite(dest, URI.parse(url).read) 17 | puts "Saved #{dest}" 18 | end 19 | 20 | namespace :download do 21 | task :files do 22 | Dir.mkdir("test/support") unless Dir.exist?("test/support") 23 | 24 | download_file("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg") 25 | end 26 | end 27 | -------------------------------------------------------------------------------- /lib/transformers-rb.rb: -------------------------------------------------------------------------------- 1 | require_relative "transformers" 2 | -------------------------------------------------------------------------------- /lib/transformers.rb: -------------------------------------------------------------------------------- 1 | # dependencies 2 | require "numo/narray" 3 | require "safetensors" 4 | require "tokenizers" 5 | require "torch-rb" 6 | 7 | # stdlib 8 | require "cgi" 9 | require "fileutils" 10 | require "io/console" 11 | require "json" 12 | require "logger" 13 | require "net/http" 14 | require "pathname" 15 | require "securerandom" 16 | require "set" 17 | require "uri" 18 | 19 | # modules 20 | require_relative "transformers/ruby_utils" 21 | require_relative "transformers/utils/generic" 22 | require_relative "transformers/activations" 23 | require_relative "transformers/dynamic_module_utils" 24 | require_relative "transformers/configuration_utils" 25 | require_relative "transformers/convert_slow_tokenizer" 26 | require_relative "transformers/feature_extraction_utils" 27 | require_relative "transformers/image_utils" 28 | require_relative "transformers/image_processing_base" 29 | require_relative "transformers/image_processing_utils" 30 | require_relative "transformers/image_transforms" 31 | require_relative "transformers/modeling_outputs" 32 | require_relative "transformers/modeling_utils" 33 | require_relative "transformers/sentence_transformer" 34 | require_relative "transformers/tokenization_utils_base" 35 | require_relative "transformers/tokenization_utils" 36 | require_relative 
"transformers/tokenization_utils_fast" 37 | require_relative "transformers/torch_utils" 38 | require_relative "transformers/version" 39 | 40 | # data 41 | require_relative "transformers/data/processors/squad" 42 | 43 | # hub 44 | require_relative "transformers/hf_hub/constants" 45 | require_relative "transformers/hf_hub/errors" 46 | require_relative "transformers/hf_hub/file_download" 47 | require_relative "transformers/hf_hub/utils/_errors" 48 | require_relative "transformers/hf_hub/utils/_headers" 49 | 50 | # models auto 51 | require_relative "transformers/models/auto/auto_factory" 52 | require_relative "transformers/models/auto/configuration_auto" 53 | require_relative "transformers/models/auto/feature_extraction_auto" 54 | require_relative "transformers/models/auto/image_processing_auto" 55 | require_relative "transformers/models/auto/modeling_auto" 56 | require_relative "transformers/models/auto/tokenization_auto" 57 | 58 | # models bert 59 | require_relative "transformers/models/bert/configuration_bert" 60 | require_relative "transformers/models/bert/modeling_bert" 61 | require_relative "transformers/models/bert/tokenization_bert" 62 | require_relative "transformers/models/bert/tokenization_bert_fast" 63 | 64 | # models deberta-v2 65 | require_relative "transformers/models/deberta_v2/configuration_deberta_v2" 66 | require_relative "transformers/models/deberta_v2/modeling_deberta_v2" 67 | require_relative "transformers/models/deberta_v2/tokenization_deberta_v2_fast" 68 | 69 | # models distilbert 70 | require_relative "transformers/models/distilbert/configuration_distilbert" 71 | require_relative "transformers/models/distilbert/modeling_distilbert" 72 | require_relative "transformers/models/distilbert/tokenization_distilbert" 73 | require_relative "transformers/models/distilbert/tokenization_distilbert_fast" 74 | 75 | # models mpnet 76 | require_relative "transformers/models/mpnet/configuration_mpnet" 77 | require_relative "transformers/models/mpnet/modeling_mpnet" 78 | require_relative "transformers/models/mpnet/tokenization_mpnet_fast" 79 | 80 | # models vit 81 | require_relative "transformers/models/vit/configuration_vit" 82 | require_relative "transformers/models/vit/image_processing_vit" 83 | require_relative "transformers/models/vit/modeling_vit" 84 | 85 | # models xml-roberta 86 | require_relative "transformers/models/xlm_roberta/configuration_xlm_roberta" 87 | require_relative "transformers/models/xlm_roberta/modeling_xlm_roberta" 88 | require_relative "transformers/models/xlm_roberta/tokenization_xlm_roberta_fast" 89 | 90 | # pipelines 91 | require_relative "transformers/pipelines/base" 92 | require_relative "transformers/pipelines/feature_extraction" 93 | require_relative "transformers/pipelines/embedding" 94 | require_relative "transformers/pipelines/image_classification" 95 | require_relative "transformers/pipelines/image_feature_extraction" 96 | require_relative "transformers/pipelines/pt_utils" 97 | require_relative "transformers/pipelines/question_answering" 98 | require_relative "transformers/pipelines/reranking" 99 | require_relative "transformers/pipelines/text_classification" 100 | require_relative "transformers/pipelines/token_classification" 101 | require_relative "transformers/pipelines/_init" 102 | 103 | # utils 104 | require_relative "transformers/utils/_init" 105 | require_relative "transformers/utils/import_utils" 106 | require_relative "transformers/utils/hub" 107 | require_relative "transformers/utils/logging" 108 | 109 | module Transformers 110 | class 
Error < StandardError; end 111 | 112 | class Todo < Error 113 | def message 114 | "not implemented yet" 115 | end 116 | end 117 | 118 | class << self 119 | # experimental 120 | attr_accessor :fast_init 121 | end 122 | self.fast_init = false 123 | end 124 | -------------------------------------------------------------------------------- /lib/transformers/activations.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | module Transformers 16 | class GELUActivation < Torch::NN::Module 17 | def initialize(use_gelu_python: false) 18 | super() 19 | if use_gelu_python 20 | @act = _gelu_python 21 | else 22 | @act = Torch::NN::Functional.method(:gelu) 23 | end 24 | end 25 | 26 | def _gelu_python(input) 27 | input * 0.5 * (1.0 + Torch.erf(input / Math.sqrt(2.0))) 28 | end 29 | 30 | def forward(input) 31 | @act.(input) 32 | end 33 | end 34 | 35 | class ClassInstantier 36 | def initialize(data) 37 | @data = data 38 | end 39 | 40 | def [](key) 41 | content = @data.fetch(key) 42 | cls, kwargs = content.is_a?(Array) ? content : [content, {}] 43 | cls.new(**kwargs) 44 | end 45 | end 46 | 47 | ACT2CLS = { 48 | "gelu" => GELUActivation 49 | } 50 | ACT2FN = ClassInstantier.new(ACT2CLS) 51 | 52 | module Activations 53 | def self.get_activation(activation_string) 54 | ACT2FN[activation_string] 55 | end 56 | end 57 | end 58 | -------------------------------------------------------------------------------- /lib/transformers/convert_slow_tokenizer.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The HuggingFace Inc. team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
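# ConvertSlowTokenizer turns a "slow" pure-Ruby tokenizer (currently BertTokenizer or
# DistilBertTokenizer) into a fast tokenizer backed by the tokenizers gem. A minimal
# usage sketch: the model name and the `encode`/`tokens` calls are illustrative
# assumptions from the tokenizers gem API, not something defined in this file.
#
#   slow = Transformers::BertTokenizer.from_pretrained("bert-base-uncased")
#   fast = Transformers::ConvertSlowTokenizer.convert_slow_tokenizer(slow)
#   fast.encode("Hello world").tokens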
14 | 15 | module Transformers 16 | module ConvertSlowTokenizer 17 | class Converter 18 | def initialize(original_tokenizer) 19 | @original_tokenizer = original_tokenizer 20 | end 21 | 22 | def converted 23 | raise NotImplementedError 24 | end 25 | end 26 | 27 | class BertConverter < Converter 28 | def converted 29 | vocab = @original_tokenizer.vocab 30 | tokenizer = Tokenizers::Tokenizer.new(Tokenizers::Models::WordPiece.new(vocab: vocab, unk_token: @original_tokenizer.unk_token.to_s)) 31 | 32 | tokenize_chinese_chars = false 33 | strip_accents = false 34 | do_lower_case = false 35 | if @original_tokenizer.basic_tokenizer 36 | tokenize_chinese_chars = @original_tokenizer.basic_tokenizer.tokenize_chinese_chars 37 | strip_accents = @original_tokenizer.basic_tokenizer.strip_accents 38 | do_lower_case = @original_tokenizer.basic_tokenizer.do_lower_case 39 | end 40 | 41 | tokenizer.normalizer = 42 | Tokenizers::Normalizers::BertNormalizer.new( 43 | clean_text: true, 44 | handle_chinese_chars: tokenize_chinese_chars, 45 | strip_accents: strip_accents, 46 | lowercase: do_lower_case, 47 | ) 48 | tokenizer.pre_tokenizer = Tokenizers::PreTokenizers::BertPreTokenizer.new 49 | 50 | cls = @original_tokenizer.cls_token.to_s 51 | sep = @original_tokenizer.sep_token.to_s 52 | cls_token_id = @original_tokenizer.cls_token_id 53 | sep_token_id = @original_tokenizer.sep_token_id 54 | 55 | tokenizer.post_processor = 56 | Tokenizers::Processors::TemplateProcessing.new( 57 | single: "#{cls}:0 $A:0 #{sep}:0", 58 | pair: "#{cls}:0 $A:0 #{sep}:0 $B:1 #{sep}:1", 59 | special_tokens: [ 60 | [cls, cls_token_id], 61 | [sep, sep_token_id] 62 | ] 63 | ) 64 | tokenizer.decoder = Tokenizers::Decoders::WordPiece.new(prefix: "##") 65 | 66 | tokenizer 67 | end 68 | end 69 | 70 | SLOW_TO_FAST_CONVERTERS = { 71 | "BertTokenizer" => BertConverter, 72 | "DistilBertTokenizer" => BertConverter 73 | } 74 | 75 | def self.convert_slow_tokenizer(transformer_tokenizer) 76 | tokenizer_class_name = transformer_tokenizer.class.name.split("::").last 77 | 78 | if !SLOW_TO_FAST_CONVERTERS.include?(tokenizer_class_name) 79 | raise ArgumentError, 80 | "An instance of tokenizer class #{tokenizer_class_name} cannot be converted in a Fast tokenizer instance." + 81 | " No converter was found. Currently available slow->fast convertors:" + 82 | " #{SLOW_TO_FAST_CONVERTERS.keys}" 83 | end 84 | 85 | converter_class = SLOW_TO_FAST_CONVERTERS.fetch(tokenizer_class_name) 86 | 87 | converter_class.new(transformer_tokenizer).converted 88 | end 89 | end 90 | end 91 | -------------------------------------------------------------------------------- /lib/transformers/data/processors/squad.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
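# SquadExample and SquadFeatures are the SQuAD-style containers consumed by the
# question-answering pipeline. A minimal construction sketch: the example strings and
# the nil qas_id/answer/title values are illustrative assumptions; arguments are
# positional and match the initializer below.
#
#   example = Transformers::SquadExample.new(
#     nil,                                              # qas_id
#     "Who created Ruby?",                              # question_text
#     "Ruby is a programming language created by Matz", # context_text
#     nil,                                              # answer_text
#     nil,                                              # start_position_character
#     nil                                               # title
#   )
#   example.question_text # => "Who created Ruby?"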
14 | 15 | module Transformers 16 | class SquadExample 17 | attr_reader :question_text, :context_text 18 | 19 | def initialize( 20 | qas_id, 21 | question_text, 22 | context_text, 23 | answer_text, 24 | start_position_character, 25 | title, 26 | answers: [], 27 | is_impossible: false 28 | ) 29 | @qas_id = qas_id 30 | @question_text = question_text 31 | @context_text = context_text 32 | @answer_text = answer_text 33 | @title = title 34 | @is_impossible = is_impossible 35 | @answers = answers 36 | 37 | @start_position, @end_position = 0, 0 38 | 39 | doc_tokens = [] 40 | char_to_word_offset = [] 41 | prev_is_whitespace = true 42 | 43 | # Split on whitespace so that different tokens may be attributed to their original position. 44 | @context_text.each_char do |c| 45 | if _is_whitespace(c) 46 | prev_is_whitespace = true 47 | else 48 | if prev_is_whitespace 49 | doc_tokens << c 50 | else 51 | doc_tokens[-1] += c 52 | end 53 | prev_is_whitespace = false 54 | end 55 | char_to_word_offset << (doc_tokens.length - 1) 56 | end 57 | 58 | @doc_tokens = doc_tokens 59 | @char_to_word_offset = char_to_word_offset 60 | 61 | # Start and end positions only has a value during evaluation. 62 | if !start_position_character.nil? && !is_impossible 63 | @start_position = char_to_word_offset[start_position_character] 64 | @end_position = char_to_word_offset[ 65 | [start_position_character + answer_text.length - 1, char_to_word_offset.length - 1].min 66 | ] 67 | end 68 | end 69 | 70 | def _is_whitespace(c) 71 | c == " " || c == "\t" || c == "\r" || c == "\n" || c.ord == 0x202F 72 | end 73 | end 74 | 75 | class SquadFeatures 76 | def initialize( 77 | input_ids:, 78 | attention_mask:, 79 | token_type_ids:, 80 | cls_index:, 81 | p_mask:, 82 | example_index:, 83 | unique_id:, 84 | paragraph_len:, 85 | token_is_max_context:, 86 | tokens:, 87 | token_to_orig_map:, 88 | start_position:, 89 | end_position:, 90 | is_impossible:, 91 | qas_id: nil, 92 | encoding: nil 93 | ) 94 | @input_ids = input_ids 95 | @attention_mask = attention_mask 96 | @token_type_ids = token_type_ids 97 | @cls_index = cls_index 98 | @p_mask = p_mask 99 | 100 | @example_index = example_index 101 | @unique_id = unique_id 102 | @paragraph_len = paragraph_len 103 | @token_is_max_context = token_is_max_context 104 | @tokens = tokens 105 | @token_to_orig_map = token_to_orig_map 106 | 107 | @start_position = start_position 108 | @end_position = end_position 109 | @is_impossible = is_impossible 110 | @qas_id = qas_id 111 | 112 | @encoding = encoding 113 | end 114 | end 115 | end 116 | -------------------------------------------------------------------------------- /lib/transformers/dynamic_module_utils.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The HuggingFace Inc. team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | module Transformers 16 | module DynamicModuleUtils 17 | # TODO improve 18 | def self.resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code) 19 | if trust_remote_code 20 | raise Error, "trust_remote_code not supported" 21 | end 22 | trust_remote_code 23 | end 24 | end 25 | end 26 | -------------------------------------------------------------------------------- /lib/transformers/feature_extraction_utils.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The HuggingFace Inc. team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | module Transformers 16 | class BatchFeature 17 | def initialize(data:, tensor_type:) 18 | @data = data 19 | convert_to_tensors(tensor_type: tensor_type) 20 | end 21 | 22 | def to_h 23 | @data 24 | end 25 | alias_method :to_hash, :to_h 26 | 27 | def [](item) 28 | @data[item] 29 | end 30 | 31 | def keys 32 | @data.keys 33 | end 34 | 35 | def values 36 | @data.values 37 | end 38 | 39 | def items 40 | @data 41 | end 42 | 43 | def _get_is_as_tensor_fns(tensor_type: nil) 44 | if tensor_type.nil? 45 | return [nil, nil] 46 | end 47 | 48 | as_tensor = lambda do |value| 49 | if value.is_a?(Array) && value.length > 0 && value[0].is_a?(Numo::NArray) 50 | value = Numo::NArray.cast(value) 51 | end 52 | Torch.tensor(value) 53 | end 54 | 55 | is_tensor = Torch.method(:tensor?) 56 | 57 | [is_tensor, as_tensor] 58 | end 59 | 60 | def convert_to_tensors(tensor_type: nil) 61 | if tensor_type.nil? 62 | return self 63 | end 64 | 65 | is_tensor, as_tensor = _get_is_as_tensor_fns(tensor_type: tensor_type) 66 | 67 | # Do the tensor conversion in batch 68 | items.each do |key, value| 69 | begin 70 | if !is_tensor.(value) 71 | tensor = as_tensor.(value) 72 | 73 | @data[key] = tensor 74 | end 75 | rescue 76 | if key == :overflowing_values 77 | raise ArgumentError, "Unable to create tensor returning overflowing values of different lengths." 78 | end 79 | raise ArgumentError, 80 | "Unable to create tensor, you should probably activate padding " + 81 | "with 'padding: true' to have batched tensors with the same length." 82 | end 83 | end 84 | 85 | self 86 | end 87 | 88 | def to(*args, **kwargs) 89 | new_data = {} 90 | device = kwargs[:device] 91 | # Check if the args are a device or a dtype 92 | if device.nil? && args.length > 0 93 | raise Todo 94 | end 95 | # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor` 96 | items.each do |k, v| 97 | # check if v is a floating point 98 | if Torch.floating_point?(v) 99 | # cast and send to device 100 | new_data[k] = v.to(*args, **kwargs) 101 | elsif !device.nil? 
102 | new_data[k] = v.to(device) 103 | else 104 | new_data[k] = v 105 | end 106 | end 107 | @data = new_data 108 | self 109 | end 110 | end 111 | end 112 | -------------------------------------------------------------------------------- /lib/transformers/hf_hub/constants.rb: -------------------------------------------------------------------------------- 1 | module Transformers 2 | module HfHub 3 | # Possible values for env variables 4 | 5 | ENV_VARS_TRUE_VALUES = ["1", "ON", "YES", "TRUE"] 6 | ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES + ["AUTO"] 7 | 8 | def self._is_true(value) 9 | if value.nil? 10 | return false 11 | end 12 | ENV_VARS_TRUE_VALUES.include?(value.upcase) 13 | end 14 | 15 | def self._as_int(value) 16 | if value.nil? 17 | return nil 18 | end 19 | value.to_i 20 | end 21 | 22 | # Constants for file downloads 23 | 24 | DEFAULT_ETAG_TIMEOUT = 10 25 | 26 | # Git-related constants 27 | 28 | DEFAULT_REVISION = "main" 29 | 30 | ENDPOINT = ENV["HF_ENDPOINT"] || "https://huggingface.co" 31 | 32 | HUGGINGFACE_CO_URL_TEMPLATE = ENDPOINT + "/%{repo_id}/resolve/%{revision}/%{filename}" 33 | HUGGINGFACE_HEADER_X_REPO_COMMIT = "x-repo-commit" 34 | HUGGINGFACE_HEADER_X_LINKED_ETAG = "x-linked-etag" 35 | HUGGINGFACE_HEADER_X_LINKED_SIZE = "x-linked-size" 36 | 37 | REPO_ID_SEPARATOR = "--" 38 | # ^ this substring is not allowed in repo_ids on hf.co 39 | # and is the canonical one we use for serialization of repo ids elsewhere. 40 | 41 | REPO_TYPE_DATASET = "dataset" 42 | REPO_TYPE_SPACE = "space" 43 | REPO_TYPE_MODEL = "model" 44 | REPO_TYPES = [nil, REPO_TYPE_MODEL, REPO_TYPE_DATASET, REPO_TYPE_SPACE] 45 | 46 | REPO_TYPES_URL_PREFIXES = { 47 | REPO_TYPE_DATASET => "datasets/", 48 | REPO_TYPE_SPACE => "spaces/" 49 | } 50 | 51 | # default cache 52 | DEFAULT_HOME = File.join(ENV.fetch("HOME"), ".cache") 53 | HF_HOME = 54 | File.expand_path( 55 | ENV.fetch( 56 | "HF_HOME", 57 | File.join(ENV.fetch("XDG_CACHE_HOME", DEFAULT_HOME), "huggingface") 58 | ) 59 | ) 60 | 61 | # New env variables 62 | HF_HUB_CACHE = ENV["HF_HUB_CACHE"] || File.join(HF_HOME, "hub") 63 | 64 | HF_HUB_OFFLINE = _is_true(ENV["HF_HUB_OFFLINE"] || ENV["TRANSFORMERS_OFFLINE"]) 65 | 66 | # Disable sending the cached token by default is all HTTP requests to the Hub 67 | HF_HUB_DISABLE_IMPLICIT_TOKEN = _is_true(ENV["HF_HUB_DISABLE_IMPLICIT_TOKEN"]) 68 | 69 | HF_HUB_ENABLE_HF_TRANSFER = _is_true(ENV["HF_HUB_ENABLE_HF_TRANSFER"]) 70 | end 71 | end 72 | -------------------------------------------------------------------------------- /lib/transformers/hf_hub/errors.rb: -------------------------------------------------------------------------------- 1 | module Transformers 2 | module HfHub 3 | class Error < StandardError; end 4 | 5 | # Raised if local token is required but not found. 6 | class LocalTokenNotFoundError < Error; end 7 | 8 | # Raised when a request is made but `HF_HUB_OFFLINE=1` is set as environment variable. 
9 | class OfflineModeIsEnabled < Error; end 10 | end 11 | end 12 | -------------------------------------------------------------------------------- /lib/transformers/hf_hub/utils/_errors.rb: -------------------------------------------------------------------------------- 1 | module Transformers 2 | module HfHub 3 | class HfHubHTTPError < Error 4 | def initialize(message, response = nil) 5 | super(message) 6 | end 7 | end 8 | 9 | class RepositoryNotFoundError < HfHubHTTPError; end 10 | 11 | class GatedRepoError < RepositoryNotFoundError; end 12 | 13 | class DisabledRepoError < HfHubHTTPError; end 14 | 15 | class RevisionNotFoundError < HfHubHTTPError; end 16 | 17 | class EntryNotFoundError < HfHubHTTPError; end 18 | 19 | class LocalEntryNotFoundError < EntryNotFoundError; end 20 | 21 | class BadRequestError < HfHubHTTPError; end 22 | 23 | class << self 24 | def hf_raise_for_status(response, endpoint_name: nil) 25 | begin 26 | response.value unless response.is_a?(Net::HTTPRedirection) 27 | rescue => e 28 | error_code = response["X-Error-Code"] 29 | error_message = response["X-Error-Message"] 30 | 31 | if error_code == "RevisionNotFound" 32 | message = "#{response.code} Client Error." + "\n\n" + "Revision Not Found for url: #{response.uri}." 33 | raise RevisionNotFoundError.new(message, response) 34 | 35 | elsif error_code == "EntryNotFound" 36 | message = "#{response.code} Client Error." + "\n\n" + "Entry Not Found for url: #{response.uri}." 37 | raise EntryNotFoundError.new(message, response) 38 | 39 | elsif error_code == "GatedRepo" 40 | message = ( 41 | "#{response.code} Client Error." + "\n\n" + "Cannot access gated repo for url #{response.uri}." 42 | ) 43 | raise GatedRepoError.new(message, response) 44 | 45 | elsif error_message == "Access to this resource is disabled." 46 | message = ( 47 | "#{response.code} Client Error." + 48 | "\n\n" + 49 | "Cannot access repository for url #{response.uri}." + 50 | "\n" + 51 | "Access to this resource is disabled." 52 | ) 53 | raise DisabledRepoError.new(message, response) 54 | 55 | elsif error_code == "RepoNotFound" 56 | # 401 is misleading as it is returned for: 57 | # - private and gated repos if user is not authenticated 58 | # - missing repos 59 | # => for now, we process them as `RepoNotFound` anyway. 60 | # See https://gist.github.com/Wauplin/46c27ad266b15998ce56a6603796f0b9 61 | message = ( 62 | "#{response.code} Client Error." + 63 | "\n\n" + 64 | "Repository Not Found for url: #{response.uri}." + 65 | "\nPlease make sure you specified the correct `repo_id` and" + 66 | " `repo_type`.\nIf you are trying to access a private or gated repo," + 67 | " make sure you are authenticated." 68 | ) 69 | raise RepositoryNotFoundError.new(message, response) 70 | 71 | elsif response.code.to_i == 400 72 | message = ( 73 | !endpoint_name.nil? ? "\n\nBad request for #{endpoint_name} endpoint:" : "\n\nBad request:" 74 | ) 75 | raise BadRequestError.new(message, response) 76 | 77 | elsif response.code.to_i == 403 78 | message = ( 79 | "\n\n#{response.code} Forbidden: #{error_message}." + 80 | "\nCannot access content at: #{response.uri}." + 81 | "\nIf you are trying to create or update content, " + 82 | "make sure you have a token with the `write` role."
83 | ) 84 | raise HfHubHTTPError.new(message, response) 85 | end 86 | 87 | # Convert `HTTPError` into a `HfHubHTTPError` to display request information 88 | # as well (request id and/or server error message) 89 | raise HfHubHTTPError.new(e.to_s, response) 90 | end 91 | end 92 | end 93 | end 94 | end 95 | -------------------------------------------------------------------------------- /lib/transformers/hf_hub/utils/_headers.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2022-present, the HuggingFace Inc. team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | module Transformers 16 | module HfHub 17 | class << self 18 | def build_hf_headers( 19 | token: nil, 20 | is_write_action: false, 21 | library_name: nil, 22 | library_version: nil, 23 | user_agent: nil, 24 | headers: nil 25 | ) 26 | # Get auth token to send 27 | token_to_send = get_token_to_send(token) 28 | _validate_token_to_send(token_to_send, is_write_action) 29 | 30 | # Combine headers 31 | hf_headers = { 32 | "user-agent" => _http_user_agent( 33 | library_name: library_name, 34 | library_version: library_version, 35 | user_agent: user_agent 36 | ) 37 | } 38 | if !token_to_send.nil? 39 | hf_headers["authorization"] = "Bearer #{token_to_send}" 40 | end 41 | if headers 42 | hf_headers.merge!(headers) 43 | end 44 | hf_headers 45 | end 46 | 47 | def get_token_to_send(token) 48 | # Case token is explicitly provided 49 | if token.is_a?(String) 50 | return token 51 | end 52 | 53 | # Case token is explicitly forbidden 54 | if token == false 55 | return nil 56 | end 57 | 58 | # Token is not provided: we get it from local cache 59 | cached_token = nil # get_token 60 | 61 | # Case token is explicitly required 62 | if token == true 63 | if cached_token.nil? 64 | raise LocalTokenNotFoundError, 65 | "Token is required (`token: true`), but no token found. You" + 66 | " need to provide a token or be logged in to Hugging Face with" + 67 | " `huggingface-cli login` or `huggingface_hub.login`. See" + 68 | " https://huggingface.co/settings/tokens." 69 | end 70 | return cached_token 71 | end 72 | 73 | # Case implicit use of the token is forbidden by env variable 74 | if HF_HUB_DISABLE_IMPLICIT_TOKEN 75 | return nil 76 | end 77 | 78 | # Otherwise: we use the cached token as the user has not explicitly forbidden it 79 | cached_token 80 | end 81 | 82 | def _validate_token_to_send(token, is_write_action) 83 | if is_write_action 84 | if token.nil? 85 | raise ArgumentError, 86 | "Token is required (write-access action) but no token found. You need" + 87 | " to provide a token or be logged in to Hugging Face with" + 88 | " `huggingface-cli login` or `huggingface_hub.login`. See" + 89 | " https://huggingface.co/settings/tokens." 90 | end 91 | end 92 | end 93 | 94 | def _http_user_agent( 95 | library_name: nil, 96 | library_version: nil, 97 | user_agent: nil 98 | ) 99 | if !library_name.nil? 
100 | ua = "#{library_name}/#{library_version}" 101 | else 102 | ua = "unknown/None" 103 | end 104 | ua += "; ruby/#{RUBY_VERSION.to_f}" 105 | ua 106 | end 107 | end 108 | end 109 | end 110 | -------------------------------------------------------------------------------- /lib/transformers/image_processing_base.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Inc. team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | module Transformers 16 | class ImageProcessingMixin 17 | def self.from_pretrained( 18 | pretrained_model_name_or_path, 19 | cache_dir: nil, 20 | force_download: false, 21 | local_files_only: false, 22 | token: nil, 23 | revision: "main", 24 | **kwargs 25 | ) 26 | kwargs[:cache_dir] = cache_dir 27 | kwargs[:force_download] = force_download 28 | kwargs[:local_files_only] = local_files_only 29 | kwargs[:revision] = revision 30 | 31 | if !token.nil? 32 | kwargs[:token] = token 33 | end 34 | 35 | image_processor_dict, kwargs = get_image_processor_dict(pretrained_model_name_or_path, **kwargs) 36 | 37 | from_dict(image_processor_dict, **kwargs) 38 | end 39 | 40 | def self.get_image_processor_dict( 41 | pretrained_model_name_or_path, **kwargs 42 | ) 43 | cache_dir = kwargs.delete(:cache_dir) 44 | force_download = kwargs.delete(:force_download) { false } 45 | resume_download = kwargs.delete(:resume_download) 46 | proxies = kwargs.delete(:proxies) 47 | token = kwargs.delete(:token) 48 | _use_auth_token = kwargs.delete(:use_auth_token) 49 | local_files_only = kwargs.delete(:local_files_only) { false } 50 | revision = kwargs.delete(:revision) 51 | subfolder = kwargs.delete(:subfolder) { "" } 52 | 53 | from_pipeline = kwargs.delete(:_from_pipeline) 54 | from_auto_class = kwargs.delete(:_from_auto) { false } 55 | 56 | user_agent = {file_type: "image processor", from_auto_class: from_auto_class} 57 | if !from_pipeline.nil? 
58 | user_agent[:using_pipeline] = from_pipeline 59 | end 60 | 61 | if Utils::Hub.is_offline_mode && !local_files_only 62 | Transformers.logger.info("Offline mode: forcing local_files_only: true") 63 | local_files_only = true 64 | end 65 | 66 | pretrained_model_name_or_path = pretrained_model_name_or_path.to_s 67 | is_local = Dir.exist?(pretrained_model_name_or_path) 68 | if Dir.exist?(pretrained_model_name_or_path) 69 | image_processor_file = File.join(pretrained_model_name_or_path, IMAGE_PROCESSOR_NAME) 70 | end 71 | if File.exist?(pretrained_model_name_or_path) 72 | resolved_image_processor_file = pretrained_model_name_or_path 73 | is_local = true 74 | elsif Utils::Hub.is_remote_url(pretrained_model_name_or_path) 75 | raise Todo 76 | else 77 | image_processor_file = IMAGE_PROCESSOR_NAME 78 | begin 79 | # Load from local folder or from cache or download from model Hub and cache 80 | resolved_image_processor_file = Utils::Hub.cached_file( 81 | pretrained_model_name_or_path, 82 | image_processor_file, 83 | cache_dir: cache_dir, 84 | force_download: force_download, 85 | proxies: proxies, 86 | resume_download: resume_download, 87 | local_files_only: local_files_only, 88 | token: token, 89 | user_agent: user_agent, 90 | revision: revision, 91 | subfolder: subfolder 92 | ) 93 | rescue EnvironmentError 94 | # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to 95 | # the original exception. 96 | raise 97 | rescue 98 | # For any other exception, we throw a generic error. 99 | raise EnvironmentError, 100 | "Can't load image processor for '#{pretrained_model_name_or_path}'. If you were trying to load" + 101 | " it from 'https://huggingface.co/models', make sure you don't have a local directory with the" + 102 | " same name. Otherwise, make sure '#{pretrained_model_name_or_path}' is the correct path to a" + 103 | " directory containing a #{IMAGE_PROCESSOR_NAME} file" 104 | end 105 | end 106 | 107 | begin 108 | image_processor_dict = JSON.load_file(resolved_image_processor_file).transform_keys(&:to_sym) 109 | rescue JSON::ParserError 110 | raise EnvironmentError, 111 | "It looks like the config file at '#{resolved_image_processor_file}' is not a valid JSON file." 112 | end 113 | 114 | if is_local 115 | Transformers.logger.info("loading configuration file #{resolved_image_processor_file}") 116 | else 117 | Transformers.logger.info( 118 | "loading configuration file #{image_processor_file} from cache at #{resolved_image_processor_file}" 119 | ) 120 | end 121 | 122 | if !is_local 123 | if image_processor_dict.include?("auto_map") 124 | raise Todo 125 | end 126 | if image_processor_dict.include?("custom_pipelines") 127 | raise Todo 128 | end 129 | end 130 | [image_processor_dict, kwargs] 131 | end 132 | 133 | def self.from_dict(image_processor_dict, **kwargs) 134 | image_processor_dict = image_processor_dict.dup 135 | return_unused_kwargs = kwargs.delete(:return_unused_kwargs) { false } 136 | 137 | # The `size` parameter is a dict and was previously an int or tuple in feature extractors. 138 | # We set `size` here directly to the `image_processor_dict` so that it is converted to the appropriate 139 | # dict within the image processor and isn't overwritten if `size` is passed in as a kwarg. 
140 | if kwargs.include?(:size) && image_processor_dict.include?(:size) 141 | image_processor_dict[:size] = kwargs.delete(:size) 142 | end 143 | if kwargs.include?(:crop_size) && image_processor_dict.include?(:crop_size) 144 | image_processor_dict[:crop_size] = kwargs.delete(:crop_size) 145 | end 146 | 147 | image_processor = new(**image_processor_dict) 148 | 149 | # Update image_processor with kwargs if needed 150 | to_remove = [] 151 | kwargs.each do |key, value| 152 | if image_processor.instance_variable_defined?("@#{key}") 153 | image_processor.instance_variable_set("@#{key}", value) 154 | to_remove << key 155 | end 156 | end 157 | to_remove.each do |key| 158 | kwargs.delete(key) 159 | end 160 | 161 | Transformers.logger.info("Image processor #{image_processor}") 162 | if return_unused_kwargs 163 | [image_processor, kwargs] 164 | else 165 | image_processor 166 | end 167 | end 168 | end 169 | end 170 | -------------------------------------------------------------------------------- /lib/transformers/image_processing_utils.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Inc. team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | module Transformers 16 | class BaseImageProcessor < ImageProcessingMixin 17 | def initialize(**kwargs) 18 | super(**kwargs) 19 | end 20 | 21 | def call(images, **kwargs) 22 | preprocess(images, **kwargs) 23 | end 24 | 25 | def preprocess(images, **kwargs) 26 | raise NotImplementedError, "Each image processor must implement its own preprocess method" 27 | end 28 | 29 | def rescale( 30 | image, 31 | scale, 32 | data_format: nil, 33 | input_data_format: nil, 34 | **kwargs 35 | ) 36 | ImageTransforms.rescale(image, scale, data_format: data_format, input_data_format: input_data_format, **kwargs) 37 | end 38 | 39 | def normalize( 40 | image, 41 | mean, 42 | std, 43 | data_format: nil, 44 | input_data_format: nil, 45 | **kwargs 46 | ) 47 | ImageTransforms.normalize( 48 | image, mean, std, data_format: data_format, input_data_format: input_data_format, **kwargs 49 | ) 50 | end 51 | end 52 | 53 | module ImageProcessingUtils 54 | def self.get_size_dict(size) 55 | if !size.is_a?(Hash) 56 | size_dict = {height: size, width: size} 57 | else 58 | size_dict = size 59 | end 60 | size_dict 61 | end 62 | end 63 | end 64 | -------------------------------------------------------------------------------- /lib/transformers/image_transforms.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Inc. team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | module Transformers 16 | module ImageTransforms 17 | def self.to_channel_dimension_format( 18 | image, 19 | channel_dim, 20 | input_channel_dim: nil 21 | ) 22 | if !image.is_a?(Numo::NArray) 23 | raise ArgumentError, "Input image must be of type Numo::NArray, got #{image.class.name}" 24 | end 25 | 26 | if input_channel_dim.nil? 27 | input_channel_dim = infer_channel_dimension_format(image) 28 | end 29 | 30 | target_channel_dim = ChannelDimension.new(channel_dim).to_s 31 | if input_channel_dim == target_channel_dim 32 | return image 33 | end 34 | 35 | if target_channel_dim == ChannelDimension::FIRST 36 | image = image.transpose(2, 0, 1) 37 | elsif target_channel_dim == ChannelDimension::LAST 38 | image = image.transpose(1, 2, 0) 39 | else 40 | raise ArgumentError, "Unsupported channel dimension format: #{channel_dim}" 41 | end 42 | 43 | image 44 | end 45 | 46 | def self.rescale( 47 | image, 48 | scale, 49 | data_format: nil, 50 | dtype: Numo::SFloat, 51 | input_data_format: nil 52 | ) 53 | if !image.is_a?(Numo::NArray) 54 | raise ArgumentError, "Input image must be of type Numo::NArray, got #{image.class.name}" 55 | end 56 | 57 | rescaled_image = image * scale 58 | if !data_format.nil? 59 | rescaled_image = to_channel_dimension_format(rescaled_image, data_format, input_data_format) 60 | end 61 | 62 | rescaled_image = rescaled_image.cast_to(dtype) 63 | 64 | rescaled_image 65 | end 66 | 67 | def self.resize( 68 | image, 69 | size, 70 | resample: nil, 71 | reducing_gap: nil, 72 | data_format: nil, 73 | return_numpy: true, 74 | input_data_format: nil 75 | ) 76 | resample = !resample.nil? ? resample : nil # PILImageResampling.BILINEAR 77 | 78 | if size.length != 2 79 | raise ArgumentError, "size must have 2 elements" 80 | end 81 | 82 | # For all transformations, we want to keep the same data format as the input image unless otherwise specified. 83 | # The resized image from PIL will always have channels last, so find the input format first. 84 | if input_data_format.nil? 85 | input_data_format = ImageUtils.infer_channel_dimension_format(image) 86 | end 87 | data_format = data_format.nil? ? input_data_format : data_format 88 | 89 | # To maintain backwards compatibility with the resizing done in previous image feature extractors, we use 90 | # the pillow library to resize the image and then convert back to numpy 91 | do_rescale = false 92 | if !image.is_a?(Vips::Image) 93 | do_rescale = _rescale_for_pil_conversion(image) 94 | image = to_pil_image(image, do_rescale: do_rescale, input_data_format: input_data_format) 95 | end 96 | height, width = size 97 | # TODO support resample 98 | resized_image = image.thumbnail_image(width, height: height, size: :force) 99 | 100 | if return_numpy 101 | resized_image = ImageUtils.to_numo_array(resized_image) 102 | # If the input image channel dimension was of size 1, then it is dropped when converting to a PIL image 103 | # so we need to add it back if necessary. 104 | resized_image = resized_image.ndim == 2 ? 
resized_image.expand_dims(-1) : resized_image 105 | # The image is always in channels last format after converting from a PIL image 106 | resized_image = to_channel_dimension_format( 107 | resized_image, data_format, input_channel_dim: ChannelDimension::LAST 108 | ) 109 | # If an image was rescaled to be in the range [0, 255] before converting to a PIL image, then we need to 110 | # rescale it back to the original range. 111 | resized_image = do_rescale ? rescale(resized_image, 1 / 255.0) : resized_image 112 | end 113 | resized_image 114 | end 115 | 116 | def self.normalize( 117 | image, 118 | mean, 119 | std, 120 | data_format: nil, 121 | input_data_format: nil 122 | ) 123 | if !image.is_a?(Numo::NArray) 124 | raise ArgumentError, "image must be a numpy array" 125 | end 126 | 127 | if input_data_format.nil? 128 | input_data_format = infer_channel_dimension_format(image) 129 | end 130 | 131 | channel_axis = ImageUtils.get_channel_dimension_axis(image, input_data_format: input_data_format) 132 | num_channels = image.shape[channel_axis] 133 | 134 | # We cast to float32 to avoid errors that can occur when subtracting uint8 values. 135 | # We preserve the original dtype if it is a float type to prevent upcasting float16. 136 | if !image.is_a?(Numo::SFloat) && !image.is_a?(Numo::DFloat) 137 | image = image.cast_to(Numo::SFloat) 138 | end 139 | 140 | if mean.is_a?(Enumerable) 141 | if mean.length != num_channels 142 | raise ArgumentError, "mean must have #{num_channels} elements if it is an iterable, got #{mean.length}" 143 | end 144 | else 145 | mean = [mean] * num_channels 146 | end 147 | mean = Numo::DFloat.cast(mean) 148 | 149 | if std.is_a?(Enumerable) 150 | if std.length != num_channels 151 | raise ArgumentError, "std must have #{num_channels} elements if it is an iterable, got #{std.length}" 152 | end 153 | else 154 | std = [std] * num_channels 155 | end 156 | std = Numo::DFloat.cast(std) 157 | 158 | if input_data_format == ChannelDimension::LAST 159 | image = (image - mean) / std 160 | else 161 | image = ((image.transpose - mean) / std).transpose 162 | end 163 | 164 | image = !data_format.nil? ? to_channel_dimension_format(image, data_format, input_data_format) : image 165 | image 166 | end 167 | 168 | def self.to_pil_image( 169 | image, 170 | do_rescale: nil, 171 | input_data_format: nil 172 | ) 173 | if image.is_a?(Vips::Image) 174 | return image 175 | end 176 | 177 | # Convert all tensors to numo arrays before converting to Vips image 178 | if !image.is_a?(Numo::NArray) 179 | raise ArgumentError, "Input image type not supported: #{image.class.name}" 180 | end 181 | 182 | # If the channel has been moved to first dim, we put it back at the end. 183 | image = to_channel_dimension_format(image, ChannelDimension::LAST, input_channel_dim: input_data_format) 184 | 185 | # If there is a single channel, we squeeze it, as otherwise PIL can't handle it. 186 | # image = image.shape[-1] == 1 ? image.squeeze(-1) : image 187 | 188 | # Rescale the image to be between 0 and 255 if needed. 189 | do_rescale = do_rescale.nil? ? 
_rescale_for_pil_conversion(image) : do_rescale 190 | 191 | if do_rescale 192 | image = rescale(image, 255) 193 | end 194 | 195 | image = image.cast_to(Numo::UInt8) 196 | Vips::Image.new_from_memory(image.to_binary, image.shape[1], image.shape[0], image.shape[2], :uchar) 197 | end 198 | 199 | def self._rescale_for_pil_conversion(image) 200 | if image.is_a?(Numo::UInt8) 201 | do_rescale = false 202 | else 203 | raise Todo 204 | end 205 | do_rescale 206 | end 207 | end 208 | end 209 | -------------------------------------------------------------------------------- /lib/transformers/image_utils.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The HuggingFace Inc. team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | module Transformers 16 | class ChannelDimension < ExplicitEnum 17 | FIRST = "channels_first" 18 | LAST = "channels_last" 19 | end 20 | 21 | module ImageUtils 22 | def self.load_image(image, timeout: nil) 23 | Utils.requires_backends(__method__, ["vision"]) 24 | if image.is_a?(URI) 25 | require "open-uri" 26 | 27 | image = Vips::Image.new_from_buffer(image.read(open_timeout: timeout, read_timeout: timeout), "") 28 | elsif image.is_a?(String) && File.exist?(image) 29 | image = Vips::Image.new_from_file(image) 30 | elsif image.is_a?(Vips::Image) 31 | image = image 32 | else 33 | raise ArgumentError, "Incorrect format used for image" 34 | end 35 | image 36 | end 37 | 38 | def self.validate_preprocess_arguments( 39 | do_rescale: nil, 40 | rescale_factor: nil, 41 | do_normalize: nil, 42 | image_mean: nil, 43 | image_std: nil, 44 | do_pad: nil, 45 | size_divisibility: nil, 46 | do_center_crop: nil, 47 | crop_size: nil, 48 | do_resize: nil, 49 | size: nil, 50 | resample: nil 51 | ) 52 | if do_rescale && rescale_factor.nil? 53 | raise ArgumentError, "`rescale_factor` must be specified if `do_rescale` is `true`." 54 | end 55 | 56 | if do_pad && size_divisibility.nil? 57 | # Here, size_divisor might be passed as the value of size 58 | raise ArgumentError, "Depending on the model, `size_divisibility`, `size_divisor`, `pad_size` or `size` must be specified if `do_pad` is `true`." 59 | end 60 | 61 | if do_normalize && (image_mean.nil? || image_std.nil?) 62 | raise ArgumentError, "`image_mean` and `image_std` must both be specified if `do_normalize` is `true`." 63 | end 64 | 65 | if do_center_crop && crop_size.nil? 66 | raise ArgumentError, "`crop_size` must be specified if `do_center_crop` is `true`." 67 | end 68 | 69 | if do_resize && (size.nil? || resample.nil?) 70 | raise ArgumentError, "`size` and `resample` must be specified if `do_resize` is `true`." 71 | end 72 | end 73 | 74 | def self.make_list_of_images(images, expected_ndims: 3) 75 | # TODO improve 76 | images.is_a?(Array) ? 
images : [images] 77 | end 78 | 79 | def self.to_numo_array(img) 80 | Numo::UInt8.from_binary(img.write_to_memory, [img.height, img.width, img.bands]) 81 | end 82 | 83 | def self.infer_channel_dimension_format( 84 | image, num_channels: nil 85 | ) 86 | num_channels = !num_channels.nil? ? num_channels : [1, 3] 87 | num_channels = num_channels.is_a?(Integer) ? [num_channels] : num_channels 88 | 89 | if image.ndim == 3 90 | first_dim, last_dim = 0, 2 91 | elsif image.ndim == 4 92 | first_dim, last_dim = 1, 3 93 | else 94 | raise ArgumentError, "Unsupported number of image dimensions: #{image.ndim}" 95 | end 96 | 97 | if num_channels.include?(image.shape[first_dim]) && num_channels.include?(image.shape[last_dim]) 98 | Transformers.logger.warn( 99 | "The channel dimension is ambiguous. Got image shape #{image.shape}. Assuming channels are the first dimension." 100 | ) 101 | return ChannelDimension::FIRST 102 | elsif num_channels.include?(image.shape[first_dim]) 103 | return ChannelDimension::FIRST 104 | elsif num_channels.include?(image.shape[last_dim]) 105 | return ChannelDimension::LAST 106 | end 107 | raise ArgumentError, "Unable to infer channel dimension format" 108 | end 109 | 110 | def self.get_channel_dimension_axis( 111 | image, input_data_format: nil 112 | ) 113 | if input_data_format.nil? 114 | input_data_format = infer_channel_dimension_format(image) 115 | end 116 | if input_data_format == ChannelDimension::FIRST 117 | return image.ndim - 3 118 | elsif input_data_format == ChannelDimension::LAST 119 | return image.ndim - 1 120 | end 121 | raise ArgumentError, "Unsupported data format: #{input_data_format}" 122 | end 123 | 124 | def self.is_vips_image(img) 125 | Utils.is_vision_available && img.is_a?(Vips::Image) 126 | end 127 | 128 | def self.is_valid_image(img) 129 | is_vips_image(img) || is_numo_array(img) || is_torch_tensor(img) 130 | end 131 | 132 | def self.valid_images(imgs) 133 | # If we have an list of images, make sure every image is valid 134 | if imgs.is_a?(Array) 135 | imgs.each do |img| 136 | if !valid_images(img) 137 | return false 138 | end 139 | end 140 | # If not a list of tuple, we have been given a single image or batched tensor of images 141 | elsif !is_valid_image(imgs) 142 | return false 143 | end 144 | true 145 | end 146 | 147 | def self.is_scaled_image(image) 148 | if image.is_a?(Numo::UInt8) 149 | return false 150 | end 151 | 152 | # It's possible the image has pixel values in [0, 255] but is of floating type 153 | image.min >= 0 && image.max <= 1 154 | end 155 | 156 | def self.validate_kwargs(valid_processor_keys:, captured_kwargs:) 157 | unused_keys = Set.new(captured_kwargs).difference(Set.new(valid_processor_keys)) 158 | if unused_keys.any? 159 | unused_key_str = unused_keys.join(", ") 160 | # TODO raise a warning here instead of simply logging? 161 | Transformers.logger.warn("Unused or unrecognized kwargs: #{unused_key_str}.") 162 | end 163 | end 164 | end 165 | end 166 | -------------------------------------------------------------------------------- /lib/transformers/modeling_outputs.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | module Transformers 16 | class BaseModelOutput < ModelOutput 17 | attribute :last_hidden_state 18 | attribute :hidden_states 19 | attribute :attentions 20 | end 21 | 22 | class BaseModelOutputWithPooling < ModelOutput 23 | attribute :last_hidden_state 24 | attribute :pooler_output 25 | attribute :hidden_states 26 | attribute :attentions 27 | end 28 | 29 | class BaseModelOutputWithPoolingAndCrossAttentions < ModelOutput 30 | attribute :last_hidden_state 31 | attribute :pooler_output 32 | attribute :hidden_states 33 | attribute :past_key_values 34 | attribute :attentions 35 | attribute :cross_attentions 36 | end 37 | 38 | class BaseModelOutputWithPastAndCrossAttentions < ModelOutput 39 | attribute :last_hidden_state 40 | attribute :past_key_values 41 | attribute :hidden_states 42 | attribute :attentions 43 | attribute :cross_attentions 44 | end 45 | 46 | class MaskedLMOutput < ModelOutput 47 | attribute :loss 48 | attribute :logits 49 | attribute :hidden_states 50 | attribute :attentions 51 | end 52 | 53 | class SequenceClassifierOutput < ModelOutput 54 | attribute :loss 55 | attribute :logits 56 | attribute :hidden_states 57 | attribute :attentions 58 | end 59 | 60 | class TokenClassifierOutput < ModelOutput 61 | attribute :loss 62 | attribute :logits 63 | attribute :hidden_states 64 | attribute :attentions 65 | end 66 | 67 | class QuestionAnsweringModelOutput < ModelOutput 68 | attribute :loss 69 | attribute :start_logits 70 | attribute :end_logits 71 | attribute :hidden_states 72 | attribute :attentions 73 | end 74 | 75 | class ImageClassifierOutput < ModelOutput 76 | attribute :loss 77 | attribute :logits 78 | attribute :hidden_states 79 | attribute :attentions 80 | end 81 | end 82 | -------------------------------------------------------------------------------- /lib/transformers/models/auto/auto_factory.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The HuggingFace Inc. team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
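# Usage sketch (illustrative only; the model id below is an assumption):
# BaseAutoModelClass#from_pretrained first resolves the model's config via
# AutoConfig, then looks the config class up in the lazy model mapping and
# delegates to the concrete class's from_pretrained.
#
#   model = Transformers::AutoModel.from_pretrained("bert-base-uncased")
#   model.class # => Transformers::Bert::BertModel, resolved through MODEL_MAPPING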
14 | 15 | module Transformers 16 | class BaseAutoModelClass 17 | extend ClassAttribute 18 | 19 | class_attribute :_model_mapping 20 | 21 | class << self 22 | def from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) 23 | config = kwargs.delete(:config) 24 | trust_remote_code = kwargs.delete(:trust_remote_code) 25 | hub_kwargs_names = [ 26 | :cache_dir, 27 | :force_download, 28 | :local_files_only, 29 | :proxies, 30 | :resume_download, 31 | :revision, 32 | :subfolder, 33 | :use_auth_token, 34 | :token 35 | ] 36 | hub_kwargs = hub_kwargs_names.select { |k| kwargs.key?(k) }.to_h { |name| [name, kwargs.delete(name)] } 37 | code_revision = kwargs.delete(:code_revision) 38 | commit_hash = kwargs.delete(:_commit_hash) 39 | 40 | if commit_hash.nil? 41 | if !config.is_a?(PretrainedConfig) 42 | # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible 43 | resolved_config_file = Utils::Hub.cached_file( 44 | pretrained_model_name_or_path, 45 | CONFIG_NAME, 46 | _raise_exceptions_for_gated_repo: false, 47 | _raise_exceptions_for_missing_entries: false, 48 | _raise_exceptions_for_connection_errors: false, 49 | **hub_kwargs 50 | ) 51 | commit_hash = Utils::Hub.extract_commit_hash(resolved_config_file, commit_hash) 52 | else 53 | raise Todo 54 | end 55 | end 56 | 57 | if !config.is_a?(PretrainedConfig) 58 | config, kwargs = 59 | AutoConfig.from_pretrained( 60 | pretrained_model_name_or_path, 61 | return_unused_kwargs: true, 62 | trust_remote_code: trust_remote_code, 63 | code_revision: code_revision, 64 | _commit_hash: commit_hash, 65 | **hub_kwargs, 66 | **kwargs 67 | ) 68 | end 69 | 70 | model_class = _get_model_class(config, _model_mapping) 71 | model_class.from_pretrained( 72 | pretrained_model_name_or_path, *model_args, config: config, **hub_kwargs, **kwargs 73 | ) 74 | end 75 | 76 | private 77 | 78 | def _get_model_class(config, model_mapping) 79 | supported_models = model_mapping[config.class.name.split("::").last] 80 | if !supported_models.is_a?(Array) 81 | return supported_models 82 | end 83 | 84 | raise Todo 85 | end 86 | end 87 | end 88 | 89 | class LazyAutoMapping 90 | def initialize(config_mapping, model_mapping) 91 | @config_mapping = config_mapping 92 | @reverse_config_mapping = config_mapping.invert 93 | @model_mapping = model_mapping 94 | @modules = {} 95 | end 96 | 97 | def [](key) 98 | model_type = @reverse_config_mapping[key] 99 | if @model_mapping[model_type] 100 | model_name = @model_mapping[model_type] 101 | return _load_attr_from_module(model_type, model_name) 102 | end 103 | 104 | raise KeyError, key 105 | end 106 | 107 | def include?(key) 108 | self[key] 109 | true 110 | rescue KeyError 111 | false 112 | end 113 | 114 | private 115 | 116 | def _load_attr_from_module(model_type, attr) 117 | module_name = model_type_to_module_name(model_type) 118 | if !@modules.include?(module_name) 119 | @modules[module_name] = Transformers.const_get(module_name.split("-").map(&:capitalize).join) 120 | end 121 | getattribute_from_module(@modules[module_name], attr) 122 | end 123 | 124 | def getattribute_from_module(mod, attr) 125 | if attr.nil? 
126 | nil 127 | elsif attr.is_a?(Array) 128 | attr.map { |a| mod.const_get(a) } 129 | else 130 | mod.const_get(attr) 131 | end 132 | end 133 | 134 | def model_type_to_module_name(key) 135 | key 136 | end 137 | end 138 | end 139 | -------------------------------------------------------------------------------- /lib/transformers/models/auto/configuration_auto.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The HuggingFace Inc. team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | module Transformers 16 | CONFIG_MAPPING_NAMES = { 17 | "bert" => "BertConfig", 18 | "deberta-v2" => "DebertaV2Config", 19 | "distilbert" => "DistilBertConfig", 20 | "mpnet" => "MPNetConfig", 21 | "vit" => "ViTConfig", 22 | "xlm-roberta" => "XLMRobertaConfig" 23 | } 24 | 25 | class LazyConfigMapping 26 | def initialize(mapping) 27 | @mapping = mapping 28 | @extra_content = {} 29 | @modules = {} 30 | end 31 | 32 | def [](key) 33 | value = @mapping.fetch(key) 34 | module_name = model_type_to_module_name(key) 35 | if !@modules.include?(module_name) 36 | @modules[module_name] = Transformers.const_get(module_name.split("-").map(&:capitalize).join) 37 | end 38 | @modules[module_name].const_get(value) 39 | end 40 | 41 | def model_type_to_module_name(key) 42 | key 43 | end 44 | end 45 | 46 | CONFIG_MAPPING = LazyConfigMapping.new(CONFIG_MAPPING_NAMES) 47 | 48 | class AutoConfig 49 | def self.from_pretrained(pretrained_model_name_or_path, **kwargs) 50 | kwargs[:_from_auto] = true 51 | kwargs[:name_or_path] = pretrained_model_name_or_path 52 | _trust_remote_code = kwargs.delete(:trust_remote_code) 53 | _code_revision = kwargs.delete(:code_revision) 54 | 55 | config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs) 56 | if config_dict[:model_type] 57 | config_class = CONFIG_MAPPING[config_dict[:model_type]] 58 | config_class.from_dict(config_dict, **unused_kwargs) 59 | else 60 | raise Todo 61 | end 62 | end 63 | end 64 | end 65 | -------------------------------------------------------------------------------- /lib/transformers/models/auto/feature_extraction_auto.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The HuggingFace Inc. team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | module Transformers 16 | FEATURE_EXTRACTOR_MAPPING_NAMES = { 17 | } 18 | 19 | FEATURE_EXTRACTOR_MAPPING = LazyAutoMapping.new(CONFIG_MAPPING_NAMES, FEATURE_EXTRACTOR_MAPPING_NAMES) 20 | end 21 | -------------------------------------------------------------------------------- /lib/transformers/models/auto/image_processing_auto.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Inc. team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | module Transformers 16 | IMAGE_PROCESSOR_MAPPING_NAMES = { 17 | "vit" => ["ViTImageProcessor"] 18 | } 19 | 20 | IMAGE_PROCESSOR_MAPPING = LazyAutoMapping.new(CONFIG_MAPPING_NAMES, IMAGE_PROCESSOR_MAPPING_NAMES) 21 | 22 | class AutoImageProcessor 23 | def self.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) 24 | config = kwargs.delete(:config) 25 | use_fast = kwargs.delete(:use_fast) 26 | trust_remote_code = kwargs.delete(:trust_remote_code) 27 | kwargs[:_from_auto] = true 28 | 29 | config_dict, _ = ImageProcessingMixin.get_image_processor_dict(pretrained_model_name_or_path, **kwargs) 30 | image_processor_class = config_dict[:image_processor_type] 31 | image_processor_auto_map = nil 32 | if (config_dict[:auto_map] || {}).include?("AutoImageProcessor") 33 | image_processor_auto_map = config_dict[:auto_map]["AutoImageProcessor"] 34 | end 35 | 36 | # If we still don't have the image processor class, check if we're loading from a previous feature extractor config 37 | # and if so, infer the image processor class from there. 38 | if image_processor_class.nil? && image_processor_auto_map.nil? 39 | feature_extractor_class = config_dict.delete(:feature_extractor_type) 40 | if !feature_extractor_class.nil? 41 | image_processor_class = feature_extractor_class.sub("FeatureExtractor", "ImageProcessor") 42 | end 43 | if (config_dict[:auto_map] || {}).include?("AutoFeatureExtractor") 44 | feature_extractor_auto_map = config_dict[:auto_map]["AutoFeatureExtractor"] 45 | image_processor_auto_map = feature_extractor_auto_map.sub("FeatureExtractor", "ImageProcessor") 46 | end 47 | end 48 | 49 | # If we don't find the image processor class in the image processor config, let's try the model config. 50 | if image_processor_class.nil? && image_processor_auto_map.nil? 51 | if !config.is_a?(PretrainedConfig) 52 | config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) 53 | end 54 | # It could be in `config.image_processor_type`` 55 | image_processor_class = config.instance_variable_get(:@image_processor_type) 56 | end 57 | 58 | if !image_processor_class.nil? 59 | raise Todo 60 | end 61 | 62 | has_remote_code = !image_processor_auto_map.nil? 63 | has_local_code = !image_processor_class.nil? 
|| IMAGE_PROCESSOR_MAPPING.include?(config.class.name.split("::").last) 64 | trust_remote_code = DynamicModuleUtils.resolve_trust_remote_code( 65 | trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code 66 | ) 67 | 68 | if !image_processor_auto_map.nil? && !image_processor_auto_map.is_a?(Array) 69 | raise Todo 70 | end 71 | 72 | if has_remote_code && trust_remote_code 73 | raise Todo 74 | elsif !image_processor_class.nil? 75 | return image_processor_class.from_dict(config_dict, **kwargs) 76 | # Last try: we use the IMAGE_PROCESSOR_MAPPING. 77 | elsif IMAGE_PROCESSOR_MAPPING.include?(config.class.name.split("::").last) 78 | image_processor_tuple = IMAGE_PROCESSOR_MAPPING[config.class.name.split("::").last] 79 | 80 | image_processor_class_py, image_processor_class_fast = image_processor_tuple 81 | 82 | if !use_fast && !image_processor_class_fast.nil? 83 | _warning_fast_image_processor_available(image_processor_class_fast) 84 | end 85 | 86 | if image_processor_class_fast && (use_fast || image_processor_class_py.nil?) 87 | return image_processor_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) 88 | else 89 | if !image_processor_class_py.nil? 90 | return image_processor_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) 91 | else 92 | raise ArgumentError, 93 | "This image processor cannot be instantiated. Please make sure you have `Pillow` installed." 94 | end 95 | end 96 | end 97 | 98 | raise ArgumentError, 99 | "Unrecognized image processor in #{pretrained_model_name_or_path}. Should have a " + 100 | "`image_processor_type` key in its #{IMAGE_PROCESSOR_NAME} of #{CONFIG_NAME}, or one of the following " + 101 | "`model_type` keys in its #{CONFIG_NAME}: #{IMAGE_PROCESSOR_MAPPING_NAMES.keys.join(", ")}" 102 | end 103 | end 104 | end 105 | -------------------------------------------------------------------------------- /lib/transformers/models/auto/modeling_auto.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The HuggingFace Inc. team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
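# Usage sketch for the task-specific auto classes defined below (the model id
# is an assumption for illustration):
#
#   model = Transformers::AutoModelForSequenceClassification.from_pretrained(
#     "distilbert-base-uncased-finetuned-sst-2-english"
#   )
#   model.class # => Transformers::Distilbert::DistilBertForSequenceClassification,
#               #    via MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING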
14 | 15 | module Transformers 16 | MODEL_MAPPING_NAMES = { 17 | "bert" => "BertModel", 18 | "deberta-v2" => "DebertaV2Model", 19 | "distilbert" => "DistilBertModel", 20 | "mpnet" => "MPNetModel", 21 | "vit" => "ViTModel", 22 | "xlm-roberta" => "XLMRobertaModel" 23 | } 24 | 25 | MODEL_FOR_MASKED_LM_MAPPING_NAMES = { 26 | "bert" => "BertForMaskedLM", 27 | "mpnet" => "MPNetForMaskedLM" 28 | } 29 | 30 | MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = { 31 | "bert" => "BertForSequenceClassification", 32 | "deberta-v2" => "DebertaV2ForSequenceClassification", 33 | "distilbert" => "DistilBertForSequenceClassification", 34 | "xlm-roberta" => "XLMRobertaForSequenceClassification" 35 | } 36 | 37 | MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = { 38 | "distilbert" => "DistilBertForQuestionAnswering" 39 | } 40 | 41 | MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = { 42 | "vit" => "ViTForImageClassification" 43 | } 44 | 45 | MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = { 46 | "bert" => "BertForTokenClassification" 47 | } 48 | 49 | MODEL_MAPPING = LazyAutoMapping.new(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES) 50 | MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = LazyAutoMapping.new( 51 | CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES 52 | ) 53 | MODEL_FOR_MASKED_LM_MAPPING = LazyAutoMapping.new(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES) 54 | MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = LazyAutoMapping.new( 55 | CONFIG_MAPPING_NAMES, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES 56 | ) 57 | MODEL_FOR_QUESTION_ANSWERING_MAPPING = LazyAutoMapping.new( 58 | CONFIG_MAPPING_NAMES, MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES 59 | ) 60 | MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = LazyAutoMapping.new( 61 | CONFIG_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES 62 | ) 63 | 64 | class AutoModel < BaseAutoModelClass 65 | self._model_mapping = MODEL_MAPPING 66 | end 67 | 68 | class AutoModelForMaskedLM < BaseAutoModelClass 69 | self._model_mapping = MODEL_FOR_MASKED_LM_MAPPING 70 | end 71 | 72 | class AutoModelForSequenceClassification < BaseAutoModelClass 73 | self._model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING 74 | end 75 | 76 | class AutoModelForQuestionAnswering < BaseAutoModelClass 77 | self._model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING 78 | end 79 | 80 | class AutoModelForTokenClassification < BaseAutoModelClass 81 | self._model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING 82 | end 83 | 84 | class AutoModelForImageClassification < BaseAutoModelClass 85 | self._model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING 86 | end 87 | end 88 | -------------------------------------------------------------------------------- /lib/transformers/models/auto/tokenization_auto.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The HuggingFace Inc. team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
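# Usage sketch (model id assumed): AutoTokenizer prefers the fast tokenizer
# when use_fast is true (the default) and a "...Fast" class is registered in
# TOKENIZER_MAPPING_NAMES, otherwise it falls back to the slow class.
#
#   tokenizer = Transformers::AutoTokenizer.from_pretrained("bert-base-uncased")
#   tokenizer.class # => Transformers::Bert::BertTokenizerFast
#
#   slow = Transformers::AutoTokenizer.from_pretrained("bert-base-uncased", use_fast: false)
#   slow.class # => Transformers::Bert::BertTokenizer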
14 | 15 | module Transformers 16 | TOKENIZER_MAPPING_NAMES = { 17 | "bert" => ["BertTokenizer", "BertTokenizerFast"], 18 | "deberta-v2" => ["DebertaV2TokenizerFast"], 19 | "distilbert" => ["DistilBertTokenizer", "DistilBertTokenizerFast"], 20 | "mpnet" => ["MPNetTokenizerFast"], 21 | "xlm-roberta" => ["XLMRobertaTokenizerFast"] 22 | } 23 | 24 | TOKENIZER_MAPPING = LazyAutoMapping.new(CONFIG_MAPPING_NAMES, TOKENIZER_MAPPING_NAMES) 25 | 26 | class AutoTokenizer 27 | class << self 28 | def from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) 29 | config = kwargs.delete(:config) 30 | kwargs[:_from_auto] = true 31 | 32 | use_fast = kwargs.delete(:use_fast) { true } 33 | tokenizer_type = kwargs.delete(:tokenizer_type) { nil } 34 | trust_remote_code = kwargs.delete(:trust_remote_code) 35 | 36 | if !tokenizer_type.nil? 37 | raise Todo 38 | end 39 | 40 | tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs) 41 | if tokenizer_config.include?("_commit_hash") 42 | kwargs[:_commit_hash] = tokenizer_config["_commit_hash"] 43 | end 44 | config_tokenizer_class = tokenizer_config["tokenizer_class"] 45 | _tokenizer_auto_map = nil 46 | if tokenizer_config["auto_map"] 47 | raise Todo 48 | end 49 | 50 | # If that did not work, let's try to use the config. 51 | if config_tokenizer_class.nil? 52 | if !config.is_a?(PretrainedConfig) 53 | config = AutoConfig.from_pretrained( 54 | pretrained_model_name_or_path, trust_remote_code: trust_remote_code, **kwargs 55 | ) 56 | config_tokenizer_class = config.tokenizer_class 57 | # if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map: 58 | # tokenizer_auto_map = config.auto_map["AutoTokenizer"] 59 | end 60 | end 61 | 62 | if !config_tokenizer_class.nil? 63 | tokenizer_class = nil 64 | if use_fast && !config_tokenizer_class.end_with?("Fast") 65 | tokenizer_class_candidate = "#{config_tokenizer_class}Fast" 66 | tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate) 67 | end 68 | if tokenizer_class.nil? 69 | tokenizer_class_candidate = config_tokenizer_class 70 | tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate) 71 | end 72 | if tokenizer_class.nil? 73 | raise ArgumentError, "Tokenizer class #{tokenizer_class_candidate} does not exist or is not currently imported." 74 | end 75 | return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) 76 | end 77 | 78 | model_type = config_class_to_model_type(config.class.name.split("::").last) 79 | if !model_type.nil? 80 | tokenizer_class_py, tokenizer_class_fast = TOKENIZER_MAPPING[config.class.name.split("::").last] 81 | if tokenizer_class_fast && (use_fast || tokenizer_class_py.nil?) 82 | return tokenizer_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) 83 | else 84 | if !tokenizer_class_py.nil? 85 | return tokenizer_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) 86 | else 87 | raise ArgumentError, "This tokenizer cannot be instantiated. Please make sure you have `sentencepiece` installed in order to use this tokenizer." 88 | end 89 | end 90 | end 91 | 92 | raise ArgumentError, "Unrecognized configuration class #{config.class.name} to build an AutoTokenizer." 
93 | end 94 | 95 | private 96 | 97 | def tokenizer_class_from_name(class_name) 98 | if class_name == "PreTrainedTokenizerFast" 99 | return PreTrainedTokenizerFast 100 | end 101 | 102 | TOKENIZER_MAPPING_NAMES.each do |module_name, tokenizers| 103 | if tokenizers.include?(class_name) 104 | cls = Transformers.const_get(module_name.split("-").map(&:capitalize).join).const_get(class_name) 105 | raise Error, "Invalid tokenizer class: #{class_name}" unless cls < PreTrainedTokenizer || cls < PreTrainedTokenizerFast 106 | return cls 107 | end 108 | end 109 | 110 | raise Todo 111 | end 112 | 113 | def get_tokenizer_config( 114 | pretrained_model_name_or_path, 115 | cache_dir: nil, 116 | force_download: false, 117 | resume_download: false, 118 | proxies: nil, 119 | token: nil, 120 | revision: nil, 121 | local_files_only: false, 122 | subfolder: "", 123 | **kwargs 124 | ) 125 | commit_hash = kwargs[:_commit_hash] 126 | resolved_config_file = Utils::Hub.cached_file( 127 | pretrained_model_name_or_path, 128 | TOKENIZER_CONFIG_FILE, 129 | cache_dir: cache_dir, 130 | force_download: force_download, 131 | resume_download: resume_download, 132 | proxies: proxies, 133 | token: token, 134 | revision: revision, 135 | local_files_only: local_files_only, 136 | subfolder: subfolder, 137 | _raise_exceptions_for_gated_repo: false, 138 | _raise_exceptions_for_missing_entries: false, 139 | _raise_exceptions_for_connection_errors: false, 140 | _commit_hash: commit_hash 141 | ) 142 | if resolved_config_file.nil? 143 | Transformers.logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.") 144 | return {} 145 | end 146 | commit_hash = Utils::Hub.extract_commit_hash(resolved_config_file, commit_hash) 147 | 148 | result = JSON.load_file(resolved_config_file) 149 | result["_commit_hash"] = commit_hash 150 | result 151 | end 152 | 153 | def config_class_to_model_type(config) 154 | CONFIG_MAPPING_NAMES.each do |key, cls| 155 | if cls == config 156 | return key 157 | end 158 | end 159 | nil 160 | end 161 | end 162 | end 163 | end 164 | -------------------------------------------------------------------------------- /lib/transformers/models/bert/configuration_bert.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
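# Usage sketch: the config can be built directly with keyword overrides; any
# argument not given falls back to the defaults in the constructor below.
#
#   config = Transformers::Bert::BertConfig.new(hidden_size: 384, num_hidden_layers: 6)
#   config.hidden_size        # => 384
#   config.num_hidden_layers  # => 6
#   config.vocab_size         # => 30522 (default)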
15 | 16 | module Transformers 17 | module Bert 18 | class BertConfig < PretrainedConfig 19 | self.model_type = "bert" 20 | 21 | attr_reader :vocab_size, :hidden_size, :num_hidden_layers, :num_attention_heads, 22 | :intermediate_size, :hidden_act, :hidden_dropout_prob, :attention_probs_dropout_prob, 23 | :max_position_embeddings, :type_vocab_size, :initializer_range, :layer_norm_eps, 24 | :position_embedding_type, :use_cache, :classifier_dropout 25 | 26 | def initialize( 27 | vocab_size: 30522, 28 | hidden_size: 768, 29 | num_hidden_layers: 12, 30 | num_attention_heads: 12, 31 | intermediate_size: 3072, 32 | hidden_act: "gelu", 33 | hidden_dropout_prob: 0.1, 34 | attention_probs_dropout_prob: 0.1, 35 | max_position_embeddings: 512, 36 | type_vocab_size: 2, 37 | initializer_range: 0.02, 38 | layer_norm_eps: 1e-12, 39 | pad_token_id: 0, 40 | position_embedding_type: "absolute", 41 | use_cache: true, 42 | classifier_dropout: nil, 43 | **kwargs 44 | ) 45 | super(pad_token_id: pad_token_id, **kwargs) 46 | 47 | @vocab_size = vocab_size 48 | @hidden_size = hidden_size 49 | @num_hidden_layers = num_hidden_layers 50 | @num_attention_heads = num_attention_heads 51 | @hidden_act = hidden_act 52 | @intermediate_size = intermediate_size 53 | @hidden_dropout_prob = hidden_dropout_prob 54 | @attention_probs_dropout_prob = attention_probs_dropout_prob 55 | @max_position_embeddings = max_position_embeddings 56 | @type_vocab_size = type_vocab_size 57 | @initializer_range = initializer_range 58 | @layer_norm_eps = layer_norm_eps 59 | @position_embedding_type = position_embedding_type 60 | @use_cache = use_cache 61 | @classifier_dropout = classifier_dropout 62 | end 63 | end 64 | end 65 | end 66 | -------------------------------------------------------------------------------- /lib/transformers/models/bert/tokenization_bert.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | module Transformers 16 | module Bert 17 | class BertTokenizer < PreTrainedTokenizer 18 | class BasicTokenizer 19 | attr_reader :do_lower_case, :tokenize_chinese_chars, :never_split, :strip_accents, :do_split_on_punc 20 | 21 | def initialize( 22 | do_lower_case: true, 23 | never_split: nil, 24 | tokenize_chinese_chars: true, 25 | strip_accents: nil, 26 | do_split_on_punc: true 27 | ) 28 | if never_split.nil? 
29 | never_split = [] 30 | end 31 | @do_lower_case = do_lower_case 32 | @never_split = Set.new(never_split) 33 | @tokenize_chinese_chars = tokenize_chinese_chars 34 | @strip_accents = strip_accents 35 | @do_split_on_punc = do_split_on_punc 36 | end 37 | end 38 | 39 | class WordpieceTokenizer 40 | def initialize(vocab:, unk_token:, max_input_chars_per_word: 100) 41 | @vocab = vocab 42 | @unk_token = unk_token 43 | @max_input_chars_per_word = max_input_chars_per_word 44 | end 45 | end 46 | 47 | attr_reader :vocab, :basic_tokenizer 48 | 49 | def initialize( 50 | vocab_file:, 51 | do_lower_case: true, 52 | do_basic_tokenize: true, 53 | never_split: nil, 54 | unk_token: "[UNK]", 55 | sep_token: "[SEP]", 56 | pad_token: "[PAD]", 57 | cls_token: "[CLS]", 58 | mask_token: "[MASK]", 59 | tokenize_chinese_chars: true, 60 | strip_accents: nil, 61 | **kwargs 62 | ) 63 | if !File.exist?(vocab_file) 64 | raise ArgumentError, 65 | "Can't find a vocabulary file at path '#{vocab_file}'. To load the vocabulary from a Google pretrained" + 66 | " model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`" 67 | end 68 | @vocab = load_vocab(vocab_file) 69 | @ids_to_tokens = @vocab.invert 70 | @do_basic_tokenize = do_basic_tokenize 71 | if do_basic_tokenize 72 | @basic_tokenizer = 73 | BasicTokenizer.new( 74 | do_lower_case: do_lower_case, 75 | never_split: never_split, 76 | tokenize_chinese_chars: tokenize_chinese_chars, 77 | strip_accents: strip_accents 78 | ) 79 | end 80 | 81 | @wordpiece_tokenizer = WordpieceTokenizer.new(vocab: @vocab, unk_token: unk_token.to_s) 82 | 83 | super( 84 | do_lower_case: do_lower_case, 85 | do_basic_tokenize: do_basic_tokenize, 86 | never_split: never_split, 87 | unk_token: unk_token, 88 | sep_token: sep_token, 89 | pad_token: pad_token, 90 | cls_token: cls_token, 91 | mask_token: mask_token, 92 | tokenize_chinese_chars: tokenize_chinese_chars, 93 | strip_accents: strip_accents, 94 | **kwargs 95 | ) 96 | end 97 | 98 | def _convert_token_to_id(token) 99 | @vocab.fetch(token, @vocab.fetch(@unk_token)) 100 | end 101 | 102 | private 103 | 104 | def load_vocab(vocab_file) 105 | vocab = {} 106 | tokens = File.readlines(vocab_file) 107 | tokens.each_with_index do |token, index| 108 | token = token.chomp 109 | vocab[token] = index 110 | end 111 | vocab 112 | end 113 | end 114 | end 115 | end 116 | -------------------------------------------------------------------------------- /lib/transformers/models/bert/tokenization_bert_fast.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
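# Usage sketch (model id assumed): the fast tokenizer can also be loaded by
# class rather than through AutoTokenizer; from_pretrained is inherited from
# PreTrainedTokenizerFast.
#
#   tokenizer = Transformers::Bert::BertTokenizerFast.from_pretrained("bert-base-uncased")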
14 | 15 | module Transformers 16 | module Bert 17 | class BertTokenizerFast < PreTrainedTokenizerFast 18 | VOCAB_FILES_NAMES = {vocab_file: "vocab.txt", tokenizer_file: "tokenizer.json"} 19 | 20 | self.vocab_files_names = VOCAB_FILES_NAMES 21 | self.slow_tokenizer_class = BertTokenizer 22 | 23 | def initialize( 24 | vocab_file: nil, 25 | tokenizer_file: nil, 26 | do_lower_case: true, 27 | unk_token: "[UNK]", 28 | sep_token: "[SEP]", 29 | pad_token: "[PAD]", 30 | cls_token: "[CLS]", 31 | mask_token: "[MASK]", 32 | tokenize_chinese_chars: true, 33 | strip_accents: nil, 34 | **kwargs 35 | ) 36 | super( 37 | vocab_file, 38 | tokenizer_file: tokenizer_file, 39 | do_lower_case: do_lower_case, 40 | unk_token: unk_token, 41 | sep_token: sep_token, 42 | pad_token: pad_token, 43 | cls_token: cls_token, 44 | mask_token: mask_token, 45 | tokenize_chinese_chars: tokenize_chinese_chars, 46 | strip_accents: strip_accents, 47 | **kwargs 48 | ) 49 | end 50 | end 51 | end 52 | end 53 | -------------------------------------------------------------------------------- /lib/transformers/models/deberta_v2/configuration_deberta_v2.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2020, Microsoft and the HuggingFace Inc. team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
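# Usage sketch: for backwards compatibility, pos_att_type may be given as a
# pipe-separated string and is normalized to an array in the constructor below.
#
#   config = Transformers::DebertaV2::DebertaV2Config.new(pos_att_type: "p2c|c2p")
#   config.pos_att_type  # => ["p2c", "c2p"]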
14 | 15 | module Transformers 16 | module DebertaV2 17 | class DebertaV2Config < PretrainedConfig 18 | self.model_type = "deberta-v2" 19 | 20 | attr_reader :vocab_size, :hidden_size, :num_hidden_layers, :num_attention_heads, 21 | :intermediate_size, :hidden_act, :hidden_dropout_prob, :attention_probs_dropout_prob, 22 | :max_position_embeddings, :type_vocab_size, :initializer_range, :layer_norm_eps, 23 | :relative_attention, :max_relative_positions, :pad_token_id, :position_biased_input, 24 | :pos_att_type, :pooler_dropout, :pooler_hidden_act, :pooler_hidden_size 25 | 26 | def initialize( 27 | vocab_size: 128100, 28 | hidden_size: 1536, 29 | num_hidden_layers: 24, 30 | num_attention_heads: 24, 31 | intermediate_size: 6144, 32 | hidden_act: "gelu", 33 | hidden_dropout_prob: 0.1, 34 | attention_probs_dropout_prob: 0.1, 35 | max_position_embeddings: 512, 36 | type_vocab_size: 0, 37 | initializer_range: 0.02, 38 | layer_norm_eps: 1e-07, 39 | relative_attention: false, 40 | max_relative_positions: -1, 41 | pad_token_id: 0, 42 | position_biased_input: true, 43 | pos_att_type: nil, 44 | pooler_dropout: 0, 45 | pooler_hidden_act: "gelu", 46 | **kwargs 47 | ) 48 | super(**kwargs) 49 | 50 | @hidden_size = hidden_size 51 | @num_hidden_layers = num_hidden_layers 52 | @num_attention_heads = num_attention_heads 53 | @intermediate_size = intermediate_size 54 | @hidden_act = hidden_act 55 | @hidden_dropout_prob = hidden_dropout_prob 56 | @attention_probs_dropout_prob = attention_probs_dropout_prob 57 | @max_position_embeddings = max_position_embeddings 58 | @type_vocab_size = type_vocab_size 59 | @initializer_range = initializer_range 60 | @relative_attention = relative_attention 61 | @max_relative_positions = max_relative_positions 62 | @pad_token_id = pad_token_id 63 | @position_biased_input = position_biased_input 64 | 65 | # Backwards compatibility 66 | if pos_att_type.is_a?(String) 67 | pos_att_type = pos_att_type.downcase.split("|").map { |x| x.strip } 68 | end 69 | 70 | @pos_att_type = pos_att_type 71 | @vocab_size = vocab_size 72 | @layer_norm_eps = layer_norm_eps 73 | 74 | @pooler_hidden_size = kwargs[:pooler_hidden_size] || hidden_size 75 | @pooler_dropout = pooler_dropout 76 | @pooler_hidden_act = pooler_hidden_act 77 | end 78 | end 79 | end 80 | end 81 | -------------------------------------------------------------------------------- /lib/transformers/models/deberta_v2/tokenization_deberta_v2_fast.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Microsoft and the HuggingFace Inc. team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
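# Usage sketch, assuming tok is a DebertaV2TokenizerFast loaded with
# from_pretrained (model id omitted) and cls/sep stand for its special token ids:
#
#   tok.build_inputs_with_special_tokens([10, 11])
#   # => [cls, 10, 11, sep]
#   tok.build_inputs_with_special_tokens([10, 11], token_ids_1: [20, 21])
#   # => [cls, 10, 11, sep, 20, 21, sep]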
14 | 15 | module Transformers 16 | module DebertaV2 17 | class DebertaV2TokenizerFast < PreTrainedTokenizerFast 18 | VOCAB_FILES_NAMES = {vocab_file: "spm.model", tokenizer_file: "tokenizer.json"} 19 | 20 | self.vocab_files_names = VOCAB_FILES_NAMES 21 | # self.slow_tokenizer_class = DebertaV2Tokenizer 22 | 23 | def initialize( 24 | vocab_file: nil, 25 | tokenizer_file: nil, 26 | do_lower_case: false, 27 | split_by_punct: false, 28 | bos_token: "[CLS]", 29 | eos_token: "[SEP]", 30 | unk_token: "[UNK]", 31 | sep_token: "[SEP]", 32 | pad_token: "[PAD]", 33 | cls_token: "[CLS]", 34 | mask_token: "[MASK]", 35 | **kwargs 36 | ) 37 | super(vocab_file, tokenizer_file: tokenizer_file, do_lower_case: do_lower_case, bos_token: bos_token, eos_token: eos_token, unk_token: unk_token, sep_token: sep_token, pad_token: pad_token, cls_token: cls_token, mask_token: mask_token, split_by_punct: split_by_punct, **kwargs) 38 | 39 | @do_lower_case = do_lower_case 40 | @split_by_punct = split_by_punct 41 | @vocab_file = vocab_file 42 | end 43 | 44 | def can_save_slow_tokenizer 45 | @vocab_file ? File.exist?(@vocab_file) : false 46 | end 47 | 48 | def build_inputs_with_special_tokens(token_ids_0, token_ids_1: nil) 49 | if token_ids_1.nil? 50 | return [@cls_token_id] + token_ids_0 + [@sep_token_id] 51 | end 52 | cls = [@cls_token_id] 53 | sep = [@sep_token_id] 54 | cls + token_ids_0 + sep + token_ids_1 + sep 55 | end 56 | 57 | def get_special_tokens_mask(token_ids_0, token_ids_1: nil, already_has_special_tokens: false) 58 | if already_has_special_tokens 59 | return super(token_ids_0: token_ids_0, token_ids_1: token_ids_1, already_has_special_tokens: true) 60 | end 61 | 62 | if !token_ids_1.nil? 63 | return [1] + ([0] * token_ids_0.length) + [1] + ([0] * token_ids_1.length) + [1] 64 | end 65 | [1] + ([0] * token_ids_0.length) + [1] 66 | end 67 | 68 | def create_token_type_ids_from_sequences(token_ids_0, token_ids_1: nil) 69 | sep = [@sep_token_id] 70 | cls = [@cls_token_id] 71 | if token_ids_1.nil? 72 | return (cls + token_ids_0 + sep).length * [0] 73 | end 74 | ((cls + token_ids_0 + sep).length * [0]) + ((token_ids_1 + sep).length * [1]) 75 | end 76 | end 77 | end 78 | end 79 | -------------------------------------------------------------------------------- /lib/transformers/models/distilbert/configuration_distilbert.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
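# Usage sketch: DistilBERT uses its own attribute names (dim, n_layers,
# n_heads); attribute_map below aliases the common names (hidden_size, etc.)
# to them, assuming PretrainedConfig resolves the map as in the Python library.
#
#   config = Transformers::Distilbert::DistilBertConfig.new
#   config.dim         # => 768
#   config.n_layers    # => 6
#   config.hidden_dim  # => 3072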
14 | 15 | module Transformers 16 | module Distilbert 17 | class DistilBertConfig < PretrainedConfig 18 | self.model_type = "distilbert" 19 | self.attribute_map = { 20 | hidden_size: "dim", 21 | num_attention_heads: "n_heads", 22 | num_hidden_layers: "n_layers" 23 | } 24 | 25 | attr_reader :vocab_size, :max_position_embeddings, :sinusoidal_pos_embds, :n_layers, :n_heads, 26 | :dim, :hidden_dim, :dropout, :attention_dropout, :activation, :initializer_range, :qa_dropout, 27 | :seq_classif_dropout, :pad_token_id 28 | 29 | def initialize( 30 | vocab_size: 30522, 31 | max_position_embeddings: 512, 32 | sinusoidal_pos_embds: false, 33 | n_layers: 6, 34 | n_heads: 12, 35 | dim: 768, 36 | hidden_dim: 4 * 768, 37 | dropout: 0.1, 38 | attention_dropout: 0.1, 39 | activation: "gelu", 40 | initializer_range: 0.02, 41 | qa_dropout: 0.1, 42 | seq_classif_dropout: 0.2, 43 | pad_token_id: 0, 44 | **kwargs 45 | ) 46 | @vocab_size = vocab_size 47 | @max_position_embeddings = max_position_embeddings 48 | @sinusoidal_pos_embds = sinusoidal_pos_embds 49 | @n_layers = n_layers 50 | @n_heads = n_heads 51 | @dim = dim 52 | @hidden_dim = hidden_dim 53 | @dropout = dropout 54 | @attention_dropout = attention_dropout 55 | @activation = activation 56 | @initializer_range = initializer_range 57 | @qa_dropout = qa_dropout 58 | @seq_classif_dropout = seq_classif_dropout 59 | super(**kwargs, pad_token_id: pad_token_id) 60 | end 61 | end 62 | end 63 | end 64 | -------------------------------------------------------------------------------- /lib/transformers/models/distilbert/tokenization_distilbert.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The HuggingFace Inc. team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | module Transformers 16 | module Distilbert 17 | class DistilBertTokenizer < PreTrainedTokenizer 18 | VOCAB_FILES_NAMES = {vocab_file: "vocab.txt"} 19 | 20 | self.vocab_files_names = VOCAB_FILES_NAMES 21 | self.model_input_names = ["input_ids", "attention_mask"] 22 | 23 | class BasicTokenizer 24 | attr_reader :do_lower_case, :tokenize_chinese_chars, :strip_accents 25 | 26 | def initialize( 27 | do_lower_case: true, 28 | never_split: nil, 29 | tokenize_chinese_chars: true, 30 | strip_accents: nil, 31 | do_split_on_punc: true 32 | ) 33 | if never_split.nil? 
34 | never_split = [] 35 | end 36 | @do_lower_case = do_lower_case 37 | @never_split = Set.new(never_split) 38 | @tokenize_chinese_chars = tokenize_chinese_chars 39 | @strip_accents = strip_accents 40 | @do_split_on_punc = do_split_on_punc 41 | end 42 | end 43 | 44 | class WordpieceTokenizer 45 | def initialize(vocab:, unk_token:, max_input_chars_per_word: 100) 46 | @vocab = vocab 47 | @unk_token = unk_token 48 | @max_input_chars_per_word = max_input_chars_per_word 49 | end 50 | end 51 | 52 | attr_reader :vocab, :basic_tokenizer 53 | 54 | def initialize( 55 | vocab_file:, 56 | do_lower_case: true, 57 | do_basic_tokenize: true, 58 | never_split: nil, 59 | unk_token: "[UNK]", 60 | sep_token: "[SEP]", 61 | pad_token: "[PAD]", 62 | cls_token: "[CLS]", 63 | mask_token: "[MASK]", 64 | tokenize_chinese_chars: true, 65 | strip_accents: nil, 66 | **kwargs 67 | ) 68 | @vocab = load_vocab(vocab_file) 69 | @ids_to_tokens = @vocab.invert 70 | @do_basic_tokenize = do_basic_tokenize 71 | if do_basic_tokenize 72 | @basic_tokenizer = 73 | BasicTokenizer.new( 74 | do_lower_case: do_lower_case, 75 | never_split: never_split, 76 | tokenize_chinese_chars: tokenize_chinese_chars, 77 | strip_accents: strip_accents 78 | ) 79 | end 80 | @wordpiece_tokenizer = WordpieceTokenizer.new(vocab: @vocab, unk_token: unk_token.to_s) 81 | 82 | super( 83 | do_lower_case: do_lower_case, 84 | do_basic_tokenize: do_basic_tokenize, 85 | never_split: never_split, 86 | unk_token: unk_token, 87 | sep_token: sep_token, 88 | pad_token: pad_token, 89 | cls_token: cls_token, 90 | mask_token: mask_token, 91 | tokenize_chinese_chars: tokenize_chinese_chars, 92 | strip_accents: strip_accents, 93 | **kwargs 94 | ) 95 | end 96 | 97 | def _convert_token_to_id(token) 98 | @vocab.fetch(token, @vocab.fetch(@unk_token)) 99 | end 100 | 101 | private 102 | 103 | def load_vocab(vocab_file) 104 | vocab = {} 105 | tokens = File.readlines(vocab_file) 106 | tokens.each_with_index do |token, index| 107 | token = token.chomp 108 | vocab[token] = index 109 | end 110 | vocab 111 | end 112 | end 113 | end 114 | end 115 | -------------------------------------------------------------------------------- /lib/transformers/models/distilbert/tokenization_distilbert_fast.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The HuggingFace Inc. team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | module Transformers 16 | module Distilbert 17 | class DistilBertTokenizerFast < PreTrainedTokenizerFast 18 | VOCAB_FILES_NAMES = {vocab_file: "vocab.txt", tokenizer_file: "tokenizer.json"} 19 | 20 | self.vocab_files_names = VOCAB_FILES_NAMES 21 | self.model_input_names = ["input_ids", "attention_mask"] 22 | self.slow_tokenizer_class = DistilBertTokenizer 23 | 24 | def initialize( 25 | vocab_file: nil, 26 | tokenizer_file: nil, 27 | do_lower_case: true, 28 | unk_token: "[UNK]", 29 | sep_token: "[SEP]", 30 | pad_token: "[PAD]", 31 | cls_token: "[CLS]", 32 | mask_token: "[MASK]", 33 | tokenize_chinese_chars: true, 34 | strip_accents: nil, 35 | **kwargs 36 | ) 37 | super( 38 | vocab_file, 39 | tokenizer_file: tokenizer_file, 40 | do_lower_case: do_lower_case, 41 | unk_token: unk_token, 42 | sep_token: sep_token, 43 | pad_token: pad_token, 44 | cls_token: cls_token, 45 | mask_token: mask_token, 46 | tokenize_chinese_chars: tokenize_chinese_chars, 47 | strip_accents: strip_accents, 48 | **kwargs 49 | ) 50 | 51 | if @backend_tokenizer 52 | raise Todo 53 | end 54 | 55 | @do_lower_case = do_lower_case 56 | end 57 | 58 | def build_inputs_with_special_tokens(token_ids_0, token_ids_1 = nil) 59 | raise Todo 60 | end 61 | 62 | def create_token_type_ids_from_sequences(token_ids_0, token_ids_1 = nil) 63 | raise Todo 64 | end 65 | 66 | def save_vocabulary(save_directory, filename_prefix: nil) 67 | raise Todo 68 | end 69 | end 70 | end 71 | end 72 | -------------------------------------------------------------------------------- /lib/transformers/models/mpnet/configuration_mpnet.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The HuggingFace Inc. team, Microsoft Corporation. 2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
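# Usage sketch: defaults follow the MPNet release (note the larger vocabulary
# and the relative-attention bucket count).
#
#   config = Transformers::Mpnet::MPNetConfig.new
#   config.vocab_size                      # => 30527
#   config.relative_attention_num_buckets  # => 32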
15 | 16 | module Transformers 17 | module Mpnet 18 | class MPNetConfig < PretrainedConfig 19 | self.model_type = "mpnet" 20 | 21 | attr_reader :vocab_size, :hidden_size, :num_hidden_layers, :num_attention_heads, 22 | :intermediate_size, :hidden_act, :hidden_dropout_prob, :attention_probs_dropout_prob, 23 | :max_position_embeddings, :initializer_range, :layer_norm_eps, :relative_attention_num_buckets, 24 | :pad_token_id, :bos_token_id, :eos_token_id 25 | 26 | def initialize( 27 | vocab_size: 30527, 28 | hidden_size: 768, 29 | num_hidden_layers: 12, 30 | num_attention_heads: 12, 31 | intermediate_size: 3072, 32 | hidden_act: "gelu", 33 | hidden_dropout_prob: 0.1, 34 | attention_probs_dropout_prob: 0.1, 35 | max_position_embeddings: 512, 36 | initializer_range: 0.02, 37 | layer_norm_eps: 1e-12, 38 | relative_attention_num_buckets: 32, 39 | pad_token_id: 1, 40 | bos_token_id: 0, 41 | eos_token_id: 2, 42 | **kwargs 43 | ) 44 | super(pad_token_id: pad_token_id, bos_token_id: bos_token_id, eos_token_id: eos_token_id, **kwargs) 45 | 46 | @vocab_size = vocab_size 47 | @hidden_size = hidden_size 48 | @num_hidden_layers = num_hidden_layers 49 | @num_attention_heads = num_attention_heads 50 | @hidden_act = hidden_act 51 | @intermediate_size = intermediate_size 52 | @hidden_dropout_prob = hidden_dropout_prob 53 | @attention_probs_dropout_prob = attention_probs_dropout_prob 54 | @max_position_embeddings = max_position_embeddings 55 | @initializer_range = initializer_range 56 | @layer_norm_eps = layer_norm_eps 57 | @relative_attention_num_buckets = relative_attention_num_buckets 58 | end 59 | end 60 | end 61 | end 62 | -------------------------------------------------------------------------------- /lib/transformers/models/mpnet/tokenization_mpnet_fast.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The HuggingFace Inc. team, Microsoft Corporation. 2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | module Transformers 17 | module Mpnet 18 | class MPNetTokenizerFast < PreTrainedTokenizerFast 19 | VOCAB_FILES_NAMES = {vocab_file: "vocab.txt", tokenizer_file: "tokenizer.json"} 20 | 21 | self.vocab_files_names = VOCAB_FILES_NAMES 22 | # self.slow_tokenizer_class = MPNetTokenizer 23 | self.model_input_names = ["input_ids", "attention_mask"] 24 | 25 | def initialize( 26 | vocab_file: nil, 27 | tokenizer_file: nil, 28 | do_lower_case: true, 29 | bos_token: "", 30 | eos_token: "", 31 | sep_token: "", 32 | cls_token: "", 33 | unk_token: "[UNK]", 34 | pad_token: "", 35 | mask_token: "", 36 | tokenize_chinese_chars: true, 37 | strip_accents: nil, 38 | **kwargs 39 | ) 40 | bos_token = bos_token.is_a?(String) ? Tokenizers::AddedToken.new(bos_token, lstrip: false, rstrip: false) : bos_token 41 | eos_token = eos_token.is_a?(String) ? 
Tokenizers::AddedToken.new(eos_token, lstrip: false, rstrip: false) : eos_token 42 | sep_token = sep_token.is_a?(String) ? Tokenizers::AddedToken.new(sep_token, lstrip: false, rstrip: false) : sep_token 43 | cls_token = cls_token.is_a?(String) ? Tokenizers::AddedToken.new(cls_token, lstrip: false, rstrip: false) : cls_token 44 | unk_token = unk_token.is_a?(String) ? Tokenizers::AddedToken.new(unk_token, lstrip: false, rstrip: false) : unk_token 45 | pad_token = pad_token.is_a?(String) ? Tokenizers::AddedToken.new(pad_token, lstrip: false, rstrip: false) : pad_token 46 | 47 | # Mask token behave like a normal word, i.e. include the space before it 48 | mask_token = mask_token.is_a?(String) ? Tokenizers::AddedToken.new(mask_token, lstrip: true, rstrip: false) : mask_token 49 | 50 | super(vocab_file, tokenizer_file: tokenizer_file, do_lower_case: do_lower_case, bos_token: bos_token, eos_token: eos_token, sep_token: sep_token, cls_token: cls_token, unk_token: unk_token, pad_token: pad_token, mask_token: mask_token, tokenize_chinese_chars: tokenize_chinese_chars, strip_accents: strip_accents, **kwargs) 51 | 52 | # TODO support 53 | # pre_tok_state = JSON.parse(backend_tokenizer.normalizer.__getstate__) 54 | # if (pre_tok_state["lowercase"] || do_lower_case) != do_lower_case || (pre_tok_state["strip_accents"] || strip_accents) != strip_accents 55 | # pre_tok_class = getattr(normalizers, pre_tok_state.delete("type")) 56 | # pre_tok_state["lowercase"] = do_lower_case 57 | # pre_tok_state["strip_accents"] = strip_accents 58 | # @normalizer = pre_tok_class(**pre_tok_state) 59 | # end 60 | 61 | @do_lower_case = do_lower_case 62 | end 63 | 64 | def mask_token 65 | if @mask_token.nil? 66 | if @verbose 67 | Transformers.logger.error("Using mask_token, but it is not set yet.") 68 | end 69 | return nil 70 | end 71 | @mask_token.to_s 72 | end 73 | 74 | def mask_token=(value) 75 | # Mask token behave like a normal word, i.e. include the space before it 76 | # So we set lstrip to True 77 | value = value.is_a?(String) ? Tokenizers::AddedToken.new(value, lstrip: true, rstrip: false) : value 78 | @mask_token = value 79 | end 80 | 81 | def build_inputs_with_special_tokens(token_ids_0, token_ids_1: nil) 82 | output = [@bos_token_id] + token_ids_0 + [@eos_token_id] 83 | if token_ids_1.nil? 84 | return output 85 | end 86 | 87 | output + [@eos_token_id] + token_ids_1 + [@eos_token_id] 88 | end 89 | 90 | def create_token_type_ids_from_sequences(token_ids_0, token_ids_1: nil) 91 | sep = [@sep_token_id] 92 | cls = [@cls_token_id] 93 | 94 | if token_ids_1.nil? 95 | return (cls + token_ids_0 + sep).length * [0] 96 | end 97 | (cls + token_ids_0 + sep + sep + token_ids_1 + sep).length * [0] 98 | end 99 | 100 | def save_vocabulary(save_directory, filename_prefix: nil) 101 | files = @tokenizer.model.save(save_directory, name: filename_prefix) 102 | Array(files) 103 | end 104 | end 105 | end 106 | end 107 | -------------------------------------------------------------------------------- /lib/transformers/models/vit/configuration_vit.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google AI and The HuggingFace Inc. team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | module Transformers 16 | module Vit 17 | class ViTConfig < PretrainedConfig 18 | self.model_type = "vit" 19 | 20 | attr_reader :hidden_size, :num_hidden_layers, :num_attention_heads, :intermediate_size, 21 | :hidden_act, :hidden_dropout_prob, :attention_probs_dropout_prob, :initializer_range, 22 | :layer_norm_eps, :image_size, :patch_size, :num_channels, :qkv_bias, :encoder_stride 23 | 24 | def initialize( 25 | hidden_size: 768, 26 | num_hidden_layers: 12, 27 | num_attention_heads: 12, 28 | intermediate_size: 3072, 29 | hidden_act: "gelu", 30 | hidden_dropout_prob: 0.0, 31 | attention_probs_dropout_prob: 0.0, 32 | initializer_range: 0.02, 33 | layer_norm_eps: 1e-12, 34 | image_size: 224, 35 | patch_size: 16, 36 | num_channels: 3, 37 | qkv_bias: true, 38 | encoder_stride: 16, 39 | **kwargs 40 | ) 41 | super(**kwargs) 42 | 43 | @hidden_size = hidden_size 44 | @num_hidden_layers = num_hidden_layers 45 | @num_attention_heads = num_attention_heads 46 | @intermediate_size = intermediate_size 47 | @hidden_act = hidden_act 48 | @hidden_dropout_prob = hidden_dropout_prob 49 | @attention_probs_dropout_prob = attention_probs_dropout_prob 50 | @initializer_range = initializer_range 51 | @layer_norm_eps = layer_norm_eps 52 | @image_size = image_size 53 | @patch_size = patch_size 54 | @num_channels = num_channels 55 | @qkv_bias = qkv_bias 56 | @encoder_stride = encoder_stride 57 | end 58 | end 59 | end 60 | end 61 | -------------------------------------------------------------------------------- /lib/transformers/models/vit/image_processing_vit.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Inc. team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | module Transformers 16 | module Vit 17 | class ViTImageProcessor < BaseImageProcessor 18 | def initialize( 19 | do_resize: true, 20 | size: nil, 21 | resample: :bilinear, 22 | do_rescale: true, 23 | rescale_factor: 1 / 255.0, 24 | do_normalize: true, 25 | image_mean: nil, 26 | image_std: nil, 27 | **kwargs 28 | ) 29 | super(**kwargs) 30 | size = !size.nil? ? size : {height: 224, width: 224} 31 | size = ImageProcessingUtils.get_size_dict(size) 32 | @do_resize = do_resize 33 | @do_rescale = do_rescale 34 | @do_normalize = do_normalize 35 | @size = size 36 | @resample = resample 37 | @rescale_factor = rescale_factor 38 | @image_mean = !image_mean.nil? ? image_mean : IMAGENET_STANDARD_MEAN 39 | @image_std = !image_std.nil? ? 
image_std : IMAGENET_STANDARD_STD 40 | @valid_processor_keys = [ 41 | :images, 42 | :do_resize, 43 | :size, 44 | :resample, 45 | :do_rescale, 46 | :rescale_factor, 47 | :do_normalize, 48 | :image_mean, 49 | :image_std, 50 | :return_tensors, 51 | :data_format, 52 | :input_data_format 53 | ] 54 | end 55 | 56 | def resize( 57 | image, 58 | size, 59 | resample: :bilinear, 60 | data_format: nil, 61 | input_data_format: nil, 62 | **kwargs 63 | ) 64 | size = ImageProcessingUtils.get_size_dict(size) 65 | if !size.include?(:height) || !size.include?(:width) 66 | raise ArgumentError, "The `size` dictionary must contain the keys `height` and `width`. Got #{size.keys}" 67 | end 68 | output_size = [size[:height], size[:width]] 69 | ImageTransforms.resize( 70 | image, 71 | output_size, 72 | resample: resample, 73 | data_format: data_format, 74 | input_data_format: input_data_format, 75 | **kwargs 76 | ) 77 | end 78 | 79 | def preprocess( 80 | images, 81 | do_resize: nil, 82 | size: nil, 83 | resample: nil, 84 | do_rescale: nil, 85 | rescale_factor: nil, 86 | do_normalize: nil, 87 | image_mean: nil, 88 | image_std: nil, 89 | return_tensors: nil, 90 | data_format: ChannelDimension::FIRST, 91 | input_data_format: nil, 92 | **kwargs 93 | ) 94 | do_resize = !do_resize.nil? ? do_resize : @do_resize 95 | do_rescale = !do_rescale.nil? ? do_rescale : @do_rescale 96 | do_normalize = !do_normalize.nil? ? do_normalize : @do_normalize 97 | resample = !resample.nil? ? resample : @resample 98 | rescale_factor = !rescale_factor.nil? ? rescale_factor : @rescale_factor 99 | image_mean = !image_mean.nil? ? image_mean : @image_mean 100 | image_std = !image_std.nil? ? image_std : @image_std 101 | 102 | size = !size.nil? ? size : @size 103 | size_dict = ImageProcessingUtils.get_size_dict(size) 104 | 105 | images = ImageUtils.make_list_of_images(images) 106 | 107 | ImageUtils.validate_kwargs(captured_kwargs: kwargs.keys, valid_processor_keys: @valid_processor_keys) 108 | 109 | if !ImageUtils.valid_images(images) 110 | raise ArgumentError, 111 | "Invalid image type. Must be of type Vips::Image, Numo::NArray, or Torch::Tensor." 112 | end 113 | ImageUtils.validate_preprocess_arguments( 114 | do_rescale: do_rescale, 115 | rescale_factor: rescale_factor, 116 | do_normalize: do_normalize, 117 | image_mean: image_mean, 118 | image_std: image_std, 119 | do_resize: do_resize, 120 | size: size, 121 | resample: resample 122 | ) 123 | 124 | # All transformations expect numo arrays. 125 | images = images.map { |image| ImageUtils.to_numo_array(image) } 126 | 127 | if ImageUtils.is_scaled_image(images[0]) && do_rescale 128 | Transformers.logger.warn( 129 | "It looks like you are trying to rescale already rescaled images. If the input" + 130 | " images have pixel values between 0 and 1, set `do_rescale: false` to avoid rescaling them again." 131 | ) 132 | end 133 | 134 | if input_data_format.nil? 135 | # We assume that all images have the same channel dimension format. 
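            # (channels-first vs channels-last); the format inferred from the first
            # image is reused for every image in the batch.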
136 | input_data_format = ImageUtils.infer_channel_dimension_format(images[0]) 137 | end 138 | 139 | if do_resize 140 | images = 141 | images.map do |image| 142 | resize(image, size_dict, resample: resample, input_data_format: input_data_format) 143 | end 144 | end 145 | 146 | if do_rescale 147 | images = 148 | images.map do |image| 149 | rescale(image, rescale_factor, input_data_format: input_data_format) 150 | end 151 | end 152 | 153 | if do_normalize 154 | images = 155 | images.map do |image| 156 | normalize(image, image_mean, image_std, input_data_format: input_data_format) 157 | end 158 | end 159 | 160 | images = 161 | images.map do |image| 162 | ImageTransforms.to_channel_dimension_format(image, data_format, input_channel_dim: input_data_format) 163 | end 164 | 165 | data = {pixel_values: images} 166 | BatchFeature.new(data: data, tensor_type: return_tensors) 167 | end 168 | end 169 | end 170 | end 171 | -------------------------------------------------------------------------------- /lib/transformers/models/xlm_roberta/configuration_xlm_roberta.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
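# A minimal usage sketch based on the initializer below (token ids follow the
# RoBERTa convention: pad = 1, bos = 0, eos = 2):
#
#   config = Transformers::XlmRoberta::XLMRobertaConfig.new(classifier_dropout: 0.1)
#   config.vocab_size         # => 30522
#   config.classifier_dropout # => 0.1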
15 | 16 | module Transformers 17 | module XlmRoberta 18 | class XLMRobertaConfig < PretrainedConfig 19 | self.model_type = "xlm-roberta" 20 | 21 | attr_reader :vocab_size, :hidden_size, :num_hidden_layers, :num_attention_heads, 22 | :intermediate_size, :hidden_act, :hidden_dropout_prob, :attention_probs_dropout_prob, 23 | :max_position_embeddings, :type_vocab_size, :initializer_range, :layer_norm_eps, 24 | :pad_token_id, :bos_token_id, :eos_token_id, :position_embedding_type, :use_cache, 25 | :classifier_dropout 26 | 27 | def initialize( 28 | vocab_size: 30522, 29 | hidden_size: 768, 30 | num_hidden_layers: 12, 31 | num_attention_heads: 12, 32 | intermediate_size: 3072, 33 | hidden_act: "gelu", 34 | hidden_dropout_prob: 0.1, 35 | attention_probs_dropout_prob: 0.1, 36 | max_position_embeddings: 512, 37 | type_vocab_size: 2, 38 | initializer_range: 0.02, 39 | layer_norm_eps: 1e-12, 40 | pad_token_id: 1, 41 | bos_token_id: 0, 42 | eos_token_id: 2, 43 | position_embedding_type: "absolute", 44 | use_cache: true, 45 | classifier_dropout: nil, 46 | **kwargs 47 | ) 48 | super(pad_token_id: pad_token_id, bos_token_id: bos_token_id, eos_token_id: eos_token_id, **kwargs) 49 | 50 | @vocab_size = vocab_size 51 | @hidden_size = hidden_size 52 | @num_hidden_layers = num_hidden_layers 53 | @num_attention_heads = num_attention_heads 54 | @hidden_act = hidden_act 55 | @intermediate_size = intermediate_size 56 | @hidden_dropout_prob = hidden_dropout_prob 57 | @attention_probs_dropout_prob = attention_probs_dropout_prob 58 | @max_position_embeddings = max_position_embeddings 59 | @type_vocab_size = type_vocab_size 60 | @initializer_range = initializer_range 61 | @layer_norm_eps = layer_norm_eps 62 | @position_embedding_type = position_embedding_type 63 | @use_cache = use_cache 64 | @classifier_dropout = classifier_dropout 65 | end 66 | end 67 | end 68 | end 69 | -------------------------------------------------------------------------------- /lib/transformers/models/xlm_roberta/tokenization_xlm_roberta_fast.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License 14 | 15 | module Transformers 16 | module XlmRoberta 17 | class XLMRobertaTokenizerFast < PreTrainedTokenizerFast 18 | VOCAB_FILES_NAMES = {vocab_file: "sentencepiece.bpe.model", tokenizer_file: "tokenizer.json"} 19 | 20 | self.vocab_files_names = VOCAB_FILES_NAMES 21 | self.model_input_names = ["input_ids", "attention_mask"] 22 | # self.slow_tokenizer_class = XLMRobertaTokenizer 23 | 24 | def initialize( 25 | vocab_file: nil, 26 | tokenizer_file: nil, 27 | bos_token: "", 28 | eos_token: "", 29 | sep_token: "", 30 | cls_token: "", 31 | unk_token: "", 32 | pad_token: "", 33 | mask_token: "", 34 | **kwargs 35 | ) 36 | # Mask token behave like a normal word, i.e. 
include the space before it 37 | mask_token = mask_token.is_a?(String) ? Tokenizers::AddedToken.new(mask_token, lstrip: true, rstrip: false) : mask_token 38 | 39 | super(vocab_file, tokenizer_file: tokenizer_file, bos_token: bos_token, eos_token: eos_token, sep_token: sep_token, cls_token: cls_token, unk_token: unk_token, pad_token: pad_token, mask_token: mask_token, **kwargs) 40 | 41 | @vocab_file = vocab_file 42 | end 43 | 44 | def can_save_slow_tokenizer 45 | @vocab_file ? File.exist?(@vocab_file) : false 46 | end 47 | 48 | def build_inputs_with_special_tokens(token_ids_0, token_ids_1: nil) 49 | if token_ids_1.nil? 50 | return [@cls_token_id] + token_ids_0 + [@sep_token_id] 51 | end 52 | cls = [@cls_token_id] 53 | sep = [@sep_token_id] 54 | cls + token_ids_0 + sep + sep + token_ids_1 + sep 55 | end 56 | 57 | def create_token_type_ids_from_sequences(token_ids_0, token_ids_1: nil) 58 | sep = [@sep_token_id] 59 | cls = [@cls_token_id] 60 | 61 | if token_ids_1.nil? 62 | return (cls + token_ids_0 + sep).length * [0] 63 | end 64 | (cls + token_ids_0 + sep + sep + token_ids_1 + sep).length * [0] 65 | end 66 | end 67 | end 68 | end 69 | -------------------------------------------------------------------------------- /lib/transformers/pipelines/embedding.rb: -------------------------------------------------------------------------------- 1 | module Transformers 2 | class EmbeddingPipeline < Pipeline 3 | def _sanitize_parameters(**kwargs) 4 | [{}, {}, kwargs] 5 | end 6 | 7 | def preprocess(inputs) 8 | @tokenizer.(inputs, return_tensors: @framework) 9 | end 10 | 11 | def _forward(model_inputs) 12 | { 13 | last_hidden_state: @model.(**model_inputs)[0], 14 | attention_mask: model_inputs[:attention_mask] 15 | } 16 | end 17 | 18 | def postprocess(model_outputs, pooling: "mean", normalize: true) 19 | output = model_outputs[:last_hidden_state] 20 | 21 | case pooling 22 | when "none" 23 | # do nothing 24 | when "mean" 25 | output = mean_pooling(output, model_outputs[:attention_mask]) 26 | when "cls" 27 | output = output[0.., 0] 28 | else 29 | raise Error, "Pooling method '#{pooling}' not supported." 30 | end 31 | 32 | if normalize 33 | output = Torch::NN::Functional.normalize(output, p: 2, dim: 1) 34 | end 35 | 36 | output[0].to_a 37 | end 38 | 39 | private 40 | 41 | def mean_pooling(output, attention_mask) 42 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(output.size).float 43 | Torch.sum(output * input_mask_expanded, 1) / Torch.clamp(input_mask_expanded.sum(1), min: 1e-9) 44 | end 45 | end 46 | end 47 | -------------------------------------------------------------------------------- /lib/transformers/pipelines/feature_extraction.rb: -------------------------------------------------------------------------------- 1 | module Transformers 2 | class FeatureExtractionPipeline < Pipeline 3 | def _sanitize_parameters(truncation: nil, tokenize_kwargs: nil, return_tensors: nil, **kwargs) 4 | if tokenize_kwargs.nil? 5 | tokenize_kwargs = {} 6 | end 7 | 8 | if !truncation.nil? 9 | if tokenize_kwargs.include?(:truncation) 10 | raise ArgumentError, 11 | "truncation parameter defined twice (given as keyword argument as well as in tokenize_kwargs)" 12 | end 13 | tokenize_kwargs[:truncation] = truncation 14 | end 15 | 16 | preprocess_params = tokenize_kwargs 17 | 18 | postprocess_params = {} 19 | if !return_tensors.nil? 
20 | postprocess_params[:return_tensors] = return_tensors 21 | end 22 | 23 | [preprocess_params, {}, postprocess_params] 24 | end 25 | 26 | def preprocess(inputs, **tokenize_kwargs) 27 | model_inputs = @tokenizer.(inputs, return_tensors: @framework, **tokenize_kwargs) 28 | model_inputs 29 | end 30 | 31 | def _forward(model_inputs) 32 | model_outputs = @model.(**model_inputs) 33 | model_outputs 34 | end 35 | 36 | def postprocess(model_outputs, return_tensors: false) 37 | # [0] is the first available tensor, logits or last_hidden_state. 38 | if return_tensors 39 | model_outputs[0] 40 | elsif @framework == "pt" 41 | model_outputs[0].to_a 42 | elsif @framework == "tf" 43 | raise Todo 44 | end 45 | end 46 | end 47 | end 48 | -------------------------------------------------------------------------------- /lib/transformers/pipelines/image_classification.rb: -------------------------------------------------------------------------------- 1 | module Transformers 2 | class ClassificationFunction < ExplicitEnum 3 | SIGMOID = "sigmoid" 4 | SOFTMAX = "softmax" 5 | NONE = "none" 6 | end 7 | 8 | class ImageClassificationPipeline < Pipeline 9 | extend ClassAttribute 10 | 11 | class_attribute :function_to_apply, ClassificationFunction::NONE 12 | 13 | def initialize(*args, **kwargs) 14 | super(*args, **kwargs) 15 | Utils.requires_backends(self, "vision") 16 | check_model_type(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES) 17 | end 18 | 19 | def _sanitize_parameters(top_k: nil, function_to_apply: nil, timeout: nil) 20 | preprocess_params = {} 21 | if !timeout.nil? 22 | preprocess_params[:timeout] = timeout 23 | end 24 | postprocess_params = {} 25 | if !top_k.nil? 26 | postprocess_params[:top_k] = top_k 27 | end 28 | if function_to_apply.is_a?(String) 29 | function_to_apply = ClassificationFunction.new(function_to_apply.downcase).to_s 30 | end 31 | if !function_to_apply.nil? 32 | postprocess_params[:function_to_apply] = function_to_apply 33 | end 34 | [preprocess_params, {}, postprocess_params] 35 | end 36 | 37 | def preprocess(image, timeout: nil) 38 | image = ImageUtils.load_image(image, timeout: timeout) 39 | model_inputs = @image_processor.(image, return_tensors: @framework) 40 | if @framework == "pt" 41 | # TODO 42 | # model_inputs = model_inputs.to(torch_dtype) 43 | end 44 | model_inputs 45 | end 46 | 47 | def _forward(model_inputs) 48 | model_outputs = @model.(**model_inputs) 49 | model_outputs 50 | end 51 | 52 | def postprocess(model_outputs, function_to_apply: nil, top_k: 5) 53 | if function_to_apply.nil? 54 | if @model.config.problem_type == "multi_label_classification" || @model.config.num_labels == 1 55 | function_to_apply = ClassificationFunction::SIGMOID 56 | elsif @model.config.problem_type == "single_label_classification" || @model.config.num_labels > 1 57 | function_to_apply = ClassificationFunction::SOFTMAX 58 | elsif @model.config.instance_variable_defined?(:@function_to_apply) && function_to_apply.nil? 
59 | function_to_apply = @model.config.function_to_apply 60 | else 61 | function_to_apply = ClassificationFunction::NONE 62 | end 63 | end 64 | 65 | if top_k > @model.config.num_labels 66 | top_k = @model.config.num_labels 67 | end 68 | 69 | outputs = model_outputs[:logits][0] 70 | if @framework == "pt" && [Torch.bfloat16, Torch.float16].include?(outputs.dtype) 71 | outputs = outputs.to(Torch.float32).numo 72 | else 73 | outputs = outputs.numo 74 | end 75 | 76 | if function_to_apply == ClassificationFunction::SIGMOID 77 | scores = sigmoid(outputs) 78 | elsif function_to_apply == ClassificationFunction::SOFTMAX 79 | scores = softmax(outputs) 80 | elsif function_to_apply == ClassificationFunction::NONE 81 | scores = outputs 82 | else 83 | raise ArgumentError, "Unrecognized `function_to_apply` argument: #{function_to_apply}" 84 | end 85 | 86 | dict_scores = 87 | scores.to_a.map.with_index do |score, i| 88 | {label: @model.config.id2label[i], score: score} 89 | end 90 | dict_scores.sort_by! { |x| -x[:score] } 91 | if !top_k.nil? 92 | dict_scores = dict_scores[...top_k] 93 | end 94 | 95 | dict_scores 96 | end 97 | 98 | private 99 | 100 | def sigmoid(_outputs) 101 | 1.0 / (1.0 + Numo::NMath.exp(-_outputs)) 102 | end 103 | 104 | def softmax(_outputs) 105 | maxes = _outputs.max(axis: -1, keepdims: true) 106 | shifted_exp = Numo::NMath.exp(_outputs - maxes) 107 | shifted_exp / shifted_exp.sum(axis: -1, keepdims: true) 108 | end 109 | end 110 | end 111 | -------------------------------------------------------------------------------- /lib/transformers/pipelines/image_feature_extraction.rb: -------------------------------------------------------------------------------- 1 | module Transformers 2 | class ImageFeatureExtractionPipeline < Pipeline 3 | def _sanitize_parameters(image_processor_kwargs: nil, return_tensors: nil, pool: nil, **kwargs) 4 | preprocess_params = image_processor_kwargs.nil? ? {} : image_processor_kwargs 5 | 6 | postprocess_params = {} 7 | if !pool.nil? 8 | postprocess_params[:pool] = pool 9 | end 10 | if !return_tensors.nil? 11 | postprocess_params[:return_tensors] = return_tensors 12 | end 13 | 14 | if kwargs.include?(:timeout) 15 | preprocess_params[:timeout] = kwargs[:timeout] 16 | end 17 | 18 | [preprocess_params, {}, postprocess_params] 19 | end 20 | 21 | def preprocess(image, timeout: nil, **image_processor_kwargs) 22 | image = ImageUtils.load_image(image, timeout: timeout) 23 | model_inputs = @image_processor.(image, return_tensors: @framework, **image_processor_kwargs) 24 | if @framework == "pt" 25 | # TODO 26 | # model_inputs = model_inputs.to(torch_dtype) 27 | end 28 | model_inputs 29 | end 30 | 31 | def _forward(model_inputs) 32 | model_outputs = @model.(**model_inputs) 33 | model_outputs 34 | end 35 | 36 | def postprocess(model_outputs, pool: nil, return_tensors: false) 37 | pool = !pool.nil? ? pool : false 38 | 39 | if pool 40 | raise Todo 41 | else 42 | # [0] is the first available tensor, logits or last_hidden_state. 
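      # For a ViT-style backbone this is typically the last hidden state of shape
      # [batch, sequence, hidden]; it is returned as a tensor when `return_tensors`
      # is true and converted to a plain Ruby array below otherwise.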
43 | outputs = model_outputs[0] 44 | end 45 | 46 | if return_tensors 47 | return outputs 48 | end 49 | if @framework == "pt" 50 | outputs.to_a 51 | else 52 | raise Todo 53 | end 54 | end 55 | end 56 | end 57 | -------------------------------------------------------------------------------- /lib/transformers/pipelines/pt_utils.rb: -------------------------------------------------------------------------------- 1 | module Transformers 2 | class PipelineDataset < Torch::Utils::Data::Dataset 3 | def initialize(dataset, process, params) 4 | @dataset = dataset 5 | @process = process 6 | @params = params 7 | end 8 | 9 | def size 10 | @dataset.size 11 | end 12 | 13 | def [](i) 14 | item = @dataset[i] 15 | processed = @process.(item, **@params) 16 | processed 17 | end 18 | end 19 | 20 | class PipelineIterator < Torch::Utils::Data::IterableDataset 21 | def initialize(loader, infer, params, loader_batch_size: nil) 22 | @loader = loader 23 | @infer = infer 24 | @params = params 25 | if loader_batch_size == 1 26 | # Let's spare some time by deactivating altogether 27 | loader_batch_size = nil 28 | end 29 | @loader_batch_size = loader_batch_size 30 | 31 | # Internal bookkeeping 32 | @loader_batch_index = nil 33 | @loader_batch_data = nil 34 | end 35 | 36 | def size 37 | @loader.size 38 | end 39 | 40 | def [](i) 41 | @infer.(@loader[i], **@params) 42 | end 43 | 44 | def each 45 | @iterator = @loader 46 | 47 | @iterator.each do |item| 48 | processed = @infer.(item, **@params) 49 | yield processed 50 | end 51 | end 52 | end 53 | end 54 | -------------------------------------------------------------------------------- /lib/transformers/pipelines/reranking.rb: -------------------------------------------------------------------------------- 1 | module Transformers 2 | class RerankingPipeline < Pipeline 3 | def _sanitize_parameters(**kwargs) 4 | [{}, {}, kwargs] 5 | end 6 | 7 | def preprocess(inputs) 8 | @tokenizer.( 9 | [inputs[:query]] * inputs[:documents].length, 10 | text_pair: inputs[:documents], 11 | return_tensors: @framework, 12 | padding: true 13 | ) 14 | end 15 | 16 | def _forward(model_inputs) 17 | model_outputs = @model.(**model_inputs) 18 | model_outputs 19 | end 20 | 21 | def call(query, documents) 22 | super({query: query, documents: documents}) 23 | end 24 | 25 | def postprocess(model_outputs) 26 | model_outputs[0] 27 | .sigmoid 28 | .squeeze 29 | .to_a 30 | .map.with_index { |s, i| {index: i, score: s} } 31 | .sort_by { |v| -v[:score] } 32 | end 33 | end 34 | end 35 | -------------------------------------------------------------------------------- /lib/transformers/pipelines/text_classification.rb: -------------------------------------------------------------------------------- 1 | module Transformers 2 | class TextClassificationPipeline < Pipeline 3 | def initialize(*args, **kwargs) 4 | super 5 | 6 | check_model_type(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES) 7 | end 8 | 9 | private 10 | 11 | def _sanitize_parameters(return_all_scores: nil, function_to_apply: nil, top_k: "", **tokenizer_kwargs) 12 | # Using "" as default argument because we're going to use `top_k=None` in user code to declare 13 | # "No top_k" 14 | preprocess_params = tokenizer_kwargs 15 | 16 | postprocess_params = {} 17 | if @model.config.respond_to?(:return_all_scores) && return_all_scores.nil? 18 | return_all_scores = @model.config.return_all_scores 19 | end 20 | 21 | if top_k.is_a?(Integer) || top_k.nil? 
22 | postprocess_params[:top_k] = top_k 23 | postprocess_params[:_legacy] = false 24 | elsif !return_all_scores.nil? 25 | warn( 26 | "`return_all_scores` is now deprecated, if want a similar functionality use `top_k: nil` instead of" + 27 | " `return_all_scores: true` or `top_k: 1` instead of `return_all_scores: false`.", 28 | ) 29 | if return_all_scores 30 | postprocess_params[:top_k] = nil 31 | else 32 | postprocess_params[:top_k] = 1 33 | end 34 | end 35 | 36 | if function_to_apply.is_a?(String) 37 | function_to_apply = ClassificationFunction.new(function_to_apply.downcase).to_s 38 | end 39 | 40 | if !function_to_apply.nil? 41 | postprocess_params[:function_to_apply] = function_to_apply 42 | end 43 | [preprocess_params, {}, postprocess_params] 44 | end 45 | 46 | def preprocess(inputs, **tokenizer_kwargs) 47 | return_tensors = @framework 48 | if inputs.is_a?(Hash) 49 | return @tokenizer.(**inputs, return_tensors: return_tensors, **tokenizer_kwargs) 50 | elsif inputs.is_a?(Array) && inputs.length == 1 && inputs[0].is_a?(Array) && inputs[0].length == 2 51 | # It used to be valid to use a list of list of list for text pairs, keeping this path for BC 52 | return @tokenizer.( 53 | inputs[0][0], text_pair: inputs[0][1], return_tensors: return_tensors, **tokenizer_kwargs 54 | ) 55 | elsif inputs.is_a?(Array) 56 | # This is likely an invalid usage of the pipeline attempting to pass text pairs. 57 | raise ArgumentError, 58 | "The pipeline received invalid inputs, if you are trying to send text pairs, you can try to send a" + 59 | ' dictionary `{"text": "My text", "text_pair": "My pair"}` in order to send a text pair.' 60 | end 61 | @tokenizer.(inputs, return_tensors: return_tensors, **tokenizer_kwargs) 62 | end 63 | 64 | def _forward(model_inputs) 65 | @model.(**model_inputs) 66 | end 67 | 68 | def postprocess(model_outputs, function_to_apply: nil, top_k: 1, _legacy: true) 69 | if function_to_apply.nil? 70 | if @model.config.problem_type == "multi_label_classification" || @model.config.num_labels == 1 71 | function_to_apply = ClassificationFunction::SIGMOID 72 | elsif @model.config.problem_type == "single_label_classification" || @model.config.num_labels > 1 73 | function_to_apply = ClassificationFunction::SOFTMAX 74 | elsif @model.config.instance_variable_defined?(:@function_to_apply) && function_to_apply.nil? 75 | function_to_apply = @model.config.function_to_apply 76 | else 77 | function_to_apply = ClassificationFunction::NONE 78 | end 79 | end 80 | 81 | outputs = model_outputs["logits"][0] 82 | outputs = outputs.numo 83 | 84 | if function_to_apply == ClassificationFunction::SIGMOID 85 | scores = sigmoid(outputs) 86 | elsif function_to_apply == ClassificationFunction::SOFTMAX 87 | scores = softmax(outputs) 88 | elsif function_to_apply == ClassificationFunction::NONE 89 | scores = outputs 90 | else 91 | raise ArgumentError, "Unrecognized `function_to_apply` argument: #{function_to_apply}" 92 | end 93 | 94 | if top_k == 1 && _legacy 95 | return {label: @model.config.id2label[scores.argmax], score: scores.max} 96 | end 97 | 98 | dict_scores = 99 | scores.to_a.map.with_index do |score, i| 100 | {label: @model.config.id2label[i], score: score} 101 | end 102 | if !_legacy 103 | dict_scores.sort_by! { |x| -x[:score] } 104 | if !top_k.nil? 
105 | dict_scores = dict_scores.first(top_k) 106 | end 107 | end 108 | dict_scores 109 | end 110 | 111 | private 112 | 113 | def sigmoid(_outputs) 114 | 1.0 / (1.0 + Numo::NMath.exp(-_outputs)) 115 | end 116 | 117 | def softmax(_outputs) 118 | maxes = _outputs.max(axis: -1, keepdims: true) 119 | shifted_exp = Numo::NMath.exp(_outputs - maxes) 120 | shifted_exp / shifted_exp.sum(axis: -1, keepdims: true) 121 | end 122 | end 123 | end 124 | -------------------------------------------------------------------------------- /lib/transformers/ruby_utils.rb: -------------------------------------------------------------------------------- 1 | module Transformers 2 | module ClassAttribute 3 | def class_attribute(name, default = nil) 4 | singleton_class.attr_writer name 5 | var = "@#{name}" 6 | instance_variable_set(var, default) 7 | singleton_class.define_method(name) do 8 | # ancestors includes current module 9 | ancestors.find { |c| c.instance_variable_defined?(var) }.instance_variable_get(var) 10 | end 11 | define_method(name) do 12 | self.class.send(name) 13 | end 14 | end 15 | end 16 | 17 | module Copy 18 | def self.deepcopy(value, memo = {}) 19 | key = value.object_id 20 | if !memo.key?(key) 21 | copy = value.dup 22 | memo[key] = copy 23 | if value.is_a?(Hash) 24 | copy.transform_keys! { |k| deepcopy(k, memo) } 25 | copy.transform_values! { |v| deepcopy(v, memo) } 26 | elsif value.is_a?(Array) 27 | copy.map! { |v| deepcopy(v, memo) } 28 | end 29 | end 30 | memo[key] 31 | end 32 | end 33 | end 34 | -------------------------------------------------------------------------------- /lib/transformers/sentence_transformer.rb: -------------------------------------------------------------------------------- 1 | module Transformers 2 | # TODO remove in 0.2.0 3 | class SentenceTransformer 4 | def initialize(model_id) 5 | @model_id = model_id 6 | @model = Transformers.pipeline("embedding", model_id) 7 | end 8 | 9 | def encode(sentences) 10 | # TODO check modules.json 11 | if [ 12 | "sentence-transformers/all-MiniLM-L6-v2", 13 | "sentence-transformers/multi-qa-MiniLM-L6-cos-v1" 14 | ].include?(@model_id) 15 | @model.(sentences) 16 | else 17 | @model.(sentences, pooling: "cls", normalize: false) 18 | end 19 | end 20 | end 21 | end 22 | -------------------------------------------------------------------------------- /lib/transformers/tokenization_utils.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Inc. team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | module Transformers 16 | class PreTrainedTokenizer < PreTrainedTokenizerBase 17 | def initialize(**kwargs) 18 | # 2. init `_added_tokens_decoder` if child class did not 19 | if !instance_variable_defined?(:@added_tokens_decoder) 20 | @added_tokens_decoder = {} 21 | end 22 | 23 | # 3. 
if a `added_tokens_decoder` is passed, we are loading from a saved tokenizer, we overwrite 24 | @added_tokens_decoder.merge!(kwargs.delete(:added_tokens_decoder) { {} }) 25 | @added_tokens_encoder = @added_tokens_decoder.to_h { |k, v| [k.content, v] } 26 | 27 | # 4 init the parent class 28 | super(**kwargs) 29 | end 30 | 31 | def is_fast 32 | false 33 | end 34 | 35 | def vocab_size 36 | raise NotImplementedError 37 | end 38 | 39 | def tokenize(text, **kwargs) 40 | raise Todo 41 | end 42 | 43 | def _encode_plus( 44 | text:, 45 | text_pair: nil, 46 | add_special_tokens: true, 47 | padding_strategy: PaddingStrategy::DO_NOT_PAD, 48 | truncation_strategy: TruncationStrategy::DO_NOT_TRUNCATE, 49 | max_length: nil, 50 | stride: 0, 51 | is_split_into_words: false, 52 | pad_to_multiple_of: nil, 53 | return_tensors: nil, 54 | return_token_type_ids: nil, 55 | return_attention_mask: nil, 56 | return_overflowing_tokens: false, 57 | return_special_tokens_mask: false, 58 | return_offsets_mapping: false, 59 | return_length: false, 60 | verbose: true, 61 | **kwargs 62 | ) 63 | get_input_ids = lambda do |text| 64 | if text.is_a?(String) 65 | tokens = tokenize(text, **kwargs) 66 | convert_tokens_to_ids(tokens) 67 | elsif text.is_a?(Array) && text.length > 0 && text[0].is_a?(String) 68 | if is_split_into_words 69 | raise Todo 70 | else 71 | convert_tokens_to_ids(text) 72 | end 73 | elsif text.is_a?(Array) && text.length > 0 && text[0].is_a?(Integer) 74 | text 75 | else 76 | if is_split_into_words 77 | raise ArgumentError, 78 | "Input #{text} is not valid. Should be a string or a list/tuple of strings when" + 79 | " `is_split_into_words=True`." 80 | else 81 | raise ArgumentError, 82 | "Input #{text} is not valid. Should be a string, a list/tuple of strings or a list/tuple of" + 83 | " integers." 84 | end 85 | end 86 | end 87 | 88 | if return_offsets_mapping 89 | raise RuntimeError, 90 | "return_offset_mapping is not available when using Ruby tokenizers. " + 91 | "To use this feature, change your tokenizer to one deriving from " + 92 | "Transformers::PreTrainedTokenizerFast. " + 93 | "More information on available tokenizers at " + 94 | "https://github.com/huggingface/transformers/pull/2674" 95 | end 96 | 97 | first_ids = get_input_ids.(text) 98 | second_ids = !text_pair.nil? ? get_input_ids.(text_pair) : nil 99 | 100 | prepare_for_model( 101 | first_ids, 102 | pair_ids: second_ids, 103 | add_special_tokens: add_special_tokens, 104 | padding: padding_strategy, 105 | truncation: truncation_strategy, 106 | max_length: max_length, 107 | stride: stride, 108 | pad_to_multiple_of: pad_to_multiple_of, 109 | return_tensors: return_tensors, 110 | prepend_batch_axis: true, 111 | return_attention_mask: return_attention_mask, 112 | return_token_type_ids: return_token_type_ids, 113 | return_overflowing_tokens: return_overflowing_tokens, 114 | return_special_tokens_mask: return_special_tokens_mask, 115 | return_length: return_length, 116 | verbose: verbose 117 | ) 118 | end 119 | 120 | def convert_tokens_to_ids(tokens) 121 | if tokens.nil? 122 | return nil 123 | end 124 | 125 | if tokens.is_a?(String) 126 | return _convert_token_to_id_with_added_voc(tokens) 127 | end 128 | 129 | ids = [] 130 | tokens.each do |token| 131 | ids << _convert_token_to_id_with_added_voc(token) 132 | end 133 | ids 134 | end 135 | 136 | def _convert_token_to_id_with_added_voc(token) 137 | if token.nil? 
138 | return nil 139 | end 140 | 141 | if @added_tokens_encoder.include?(token) 142 | return @added_tokens_encoder[token] 143 | end 144 | _convert_token_to_id(token) 145 | end 146 | 147 | def _convert_token_to_id(token) 148 | raise NotImplementedError 149 | end 150 | end 151 | end 152 | -------------------------------------------------------------------------------- /lib/transformers/torch_utils.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | module Transformers 16 | module TorchUtils 17 | def self.apply_chunking_to_forward(forward_fn, chunk_size, chunk_dim, *input_tensors) 18 | if chunk_size > 0 19 | raise Todo 20 | end 21 | 22 | forward_fn.(*input_tensors) 23 | end 24 | end 25 | end 26 | -------------------------------------------------------------------------------- /lib/transformers/utils/_init.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The HuggingFace Inc. team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | module Transformers 16 | WEIGHTS_NAME = "pytorch_model.bin" 17 | WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json" 18 | TF2_WEIGHTS_NAME = "tf_model.h5" 19 | TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json" 20 | TF_WEIGHTS_NAME = "model.ckpt" 21 | FLAX_WEIGHTS_NAME = "flax_model.msgpack" 22 | FLAX_WEIGHTS_INDEX_NAME = "flax_model.msgpack.index.json" 23 | SAFE_WEIGHTS_NAME = "model.safetensors" 24 | SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json" 25 | CONFIG_NAME = "config.json" 26 | FEATURE_EXTRACTOR_NAME = "preprocessor_config.json" 27 | IMAGE_PROCESSOR_NAME = FEATURE_EXTRACTOR_NAME 28 | PROCESSOR_NAME = "processor_config.json" 29 | GENERATION_CONFIG_NAME = "generation_config.json" 30 | MODEL_CARD_NAME = "modelcard.json" 31 | end 32 | -------------------------------------------------------------------------------- /lib/transformers/utils/generic.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | module Transformers 16 | class ModelOutput 17 | def self.attributes 18 | @attributes ||= [] 19 | end 20 | 21 | def self.attribute(attribute) 22 | attributes << attribute.to_sym 23 | 24 | define_method(attribute) do 25 | self[attribute] 26 | end 27 | end 28 | 29 | def initialize(**kwargs) 30 | @data = kwargs 31 | end 32 | 33 | def [](k) 34 | if k.is_a?(String) || k.is_a?(Symbol) 35 | @data[k.to_sym] 36 | else 37 | to_tuple[k] 38 | end 39 | end 40 | 41 | def to_tuple 42 | self.class.attributes.map { |k| @data[k] }.compact 43 | end 44 | end 45 | 46 | class ExplicitEnum 47 | def initialize(value) 48 | expected = self.class.constants.map { |k| self.class.const_get(k) } 49 | unless expected.include?(value) 50 | raise ArgumentError, "#{value} is not a valid #{self.class.name}, please select one of #{expected.inspect}" 51 | end 52 | @value = value 53 | end 54 | 55 | def to_s 56 | @value 57 | end 58 | end 59 | 60 | class PaddingStrategy < ExplicitEnum 61 | LONGEST = "longest" 62 | MAX_LENGTH = "max_length" 63 | DO_NOT_PAD = "do_not_pad" 64 | end 65 | 66 | class TensorType < ExplicitEnum 67 | PYTORCH = "pt" 68 | TENSORFLOW = "tf" 69 | NUMPY = "np" 70 | JAX = "jax" 71 | MLX = "mlx" 72 | end 73 | 74 | module Utils 75 | def self.infer_framework(model_class) 76 | if model_class < Torch::NN::Module 77 | "pt" 78 | else 79 | raise TypeError, "Could not infer framework from class #{model_class}." 80 | end 81 | end 82 | 83 | def self._is_numo(x) 84 | x.is_a?(Numo::NArray) 85 | end 86 | 87 | def self.is_numo_array(x) 88 | _is_numo(x) 89 | end 90 | 91 | def self._is_torch(x) 92 | x.is_a?(Torch::Tensor) 93 | end 94 | 95 | def self.is_torch_tensor(x) 96 | _is_torch(x) 97 | end 98 | 99 | def self._is_torch_device(x) 100 | x.is_a?(Torch::Device) 101 | end 102 | 103 | def self.is_torch_device(x) 104 | _is_torch_device(x) 105 | end 106 | end 107 | end 108 | -------------------------------------------------------------------------------- /lib/transformers/utils/hub.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
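# A minimal usage sketch, assuming a public model repo (the repo id below is
# illustrative): `cached_file` resolves a file through the local Hugging Face
# cache, downloading it from the Hub when needed, and `has_file` checks for a
# file without downloading it.
#
#   path = Transformers::Utils::Hub.cached_file("bert-base-uncased", "config.json")
#   Transformers::Utils::Hub.has_file("bert-base-uncased", "model.safetensors") # => true or false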
14 | 15 | module Transformers 16 | module Utils 17 | module Hub 18 | IS_OFFLINE_MODE = HfHub::HF_HUB_OFFLINE 19 | 20 | PYTORCH_PRETRAINED_BERT_CACHE = ENV.fetch("PYTORCH_PRETRAINED_BERT_CACHE", HfHub::HF_HUB_CACHE) 21 | PYTORCH_TRANSFORMERS_CACHE = ENV.fetch("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE) 22 | TRANSFORMERS_CACHE = ENV.fetch("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE) 23 | 24 | DEFAULT_ENDPOINT = "https://huggingface.co" 25 | HUGGINGFACE_CO_RESOLVE_ENDPOINT = ENV.fetch("HF_ENDPOINT", DEFAULT_ENDPOINT) 26 | 27 | class << self 28 | def is_offline_mode 29 | IS_OFFLINE_MODE 30 | end 31 | 32 | def is_remote_url(url_or_filename) 33 | url_or_filename.is_a?(URI) 34 | end 35 | 36 | def http_user_agent(user_agent = nil) 37 | ua = "transformers.rb/#{Transformers::VERSION}; ruby/#{RUBY_VERSION.to_f}" 38 | if user_agent.is_a?(Hash) 39 | ua += "; " + user_agent.map { |k, v| "#{k}/#{v}" }.join("; ") 40 | elsif user_agent.is_a?(String) 41 | ua += "; " + user_agent 42 | end 43 | ua 44 | end 45 | 46 | def extract_commit_hash(resolved_file, commit_hash) 47 | if resolved_file.nil? || !commit_hash.nil? 48 | return commit_hash 49 | end 50 | search = /snapshots\/([^\/]+)/.match(resolved_file) 51 | if search.nil? 52 | return nil 53 | end 54 | commit_hash = search[1] 55 | HfHub::REGEX_COMMIT_HASH.match(commit_hash) ? commit_hash : nil 56 | end 57 | 58 | def cached_file( 59 | path_or_repo_id, 60 | filename, 61 | cache_dir: nil, 62 | force_download: false, 63 | resume_download: false, 64 | proxies: nil, 65 | token: nil, 66 | revision: nil, 67 | local_files_only: false, 68 | subfolder: "", 69 | repo_type: nil, 70 | user_agent: nil, 71 | _raise_exceptions_for_gated_repo: true, 72 | _raise_exceptions_for_missing_entries: true, 73 | _raise_exceptions_for_connection_errors: true, 74 | _commit_hash: nil, 75 | **deprecated_kwargs 76 | ) 77 | if is_offline_mode && !local_files_only 78 | Transformers.logger.info "Offline mode: forcing local_files_only: true" 79 | local_files_only = true 80 | end 81 | if subfolder.nil? 82 | subfolder = "" 83 | end 84 | 85 | path_or_repo_id = path_or_repo_id.to_s 86 | full_filename = File.join(subfolder, filename) 87 | if Dir.exist?(path_or_repo_id) 88 | raise Todo 89 | end 90 | 91 | if cache_dir.nil? 92 | cache_dir = TRANSFORMERS_CACHE 93 | end 94 | if cache_dir.is_a?(Pathname) 95 | cache_dir = cache_dir.to_s 96 | end 97 | 98 | if !_commit_hash.nil? && !force_download 99 | # If the file is cached under that commit hash, we return it directly. 100 | resolved_file = 101 | HfHub.try_to_load_from_cache( 102 | path_or_repo_id, full_filename, cache_dir: cache_dir, revision: _commit_hash, repo_type: repo_type 103 | ) 104 | if !resolved_file.nil? 105 | if resolved_file != HfHub::CACHED_NO_EXIST 106 | return resolved_file 107 | elsif !_raise_exceptions_for_missing_entries 108 | return nil 109 | else 110 | raise EnvironmentError, "Could not locate #{full_filename} inside #{path_or_repo_id}." 111 | end 112 | end 113 | end 114 | 115 | user_agent = http_user_agent(user_agent) 116 | 117 | resolved_file = nil 118 | begin 119 | resolved_file = 120 | HfHub.hf_hub_download( 121 | path_or_repo_id, 122 | filename, 123 | subfolder: subfolder.length == 0 ? 
nil : subfolder, 124 | repo_type: repo_type, 125 | revision: revision, 126 | cache_dir: cache_dir, 127 | user_agent: user_agent, 128 | force_download: force_download, 129 | proxies: proxies, 130 | resume_download: resume_download, 131 | token: token, 132 | local_files_only: local_files_only 133 | ) 134 | rescue => e 135 | raise e if _raise_exceptions_for_missing_entries 136 | end 137 | resolved_file 138 | end 139 | 140 | def has_file( 141 | path_or_repo, 142 | filename, 143 | revision: nil, 144 | proxies: nil, 145 | token: nil, 146 | local_files_only: false, 147 | cache_dir: nil, 148 | repo_type: nil, 149 | **deprecated_kwargs 150 | ) 151 | # If path to local directory, check if the file exists 152 | if Dir.exist?(path_or_repo) 153 | return File.exist?(File.join(path_or_repo, filename)) 154 | end 155 | 156 | # Else it's a repo => let's check if the file exists in local cache or on the Hub 157 | 158 | # Check if file exists in cache 159 | # This information might be outdated so it's best to also make a HEAD call (if allowed). 160 | cached_path = HfHub.try_to_load_from_cache( 161 | path_or_repo, 162 | filename, 163 | revision: revision, 164 | repo_type: repo_type, 165 | cache_dir: cache_dir 166 | ) 167 | has_file_in_cache = cached_path.is_a?(String) 168 | 169 | # If local_files_only, don't try the HEAD call 170 | if local_files_only 171 | return has_file_in_cache 172 | end 173 | 174 | # Check if the file exists 175 | begin 176 | HfHub._request_wrapper( 177 | "HEAD", 178 | HfHub.hf_hub_url(path_or_repo, filename, revision: revision, repo_type: repo_type), 179 | headers: HfHub.build_hf_headers(token: token, user_agent: http_user_agent), 180 | allow_redirects: false, 181 | proxies: proxies, 182 | timeout: 10 183 | ) 184 | true 185 | rescue HfHub::OfflineModeIsEnabled 186 | has_file_in_cache 187 | rescue HfHub::GatedRepoError => e 188 | Transformers.logger.error(e) 189 | raise EnvironmentError, 190 | "#{path_or_repo} is a gated repository. Make sure to request access at " + 191 | "https://huggingface.co/#{path_or_repo} and pass a token having permission to this repo either by " + 192 | "logging in with `huggingface-cli login` or by passing `token=`." 193 | rescue HfHub::RepositoryNotFoundError => e 194 | Transformers.logger.error(e) 195 | raise EnvironmentError, 196 | "#{path_or_repo} is not a local folder or a valid repository name on 'https://hf.co'." 197 | rescue HfHub::RevisionNotFoundError => e 198 | Transformers.logger.error(e) 199 | raise EnvironmentError, 200 | "#{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this " + 201 | "model name. Check the model page at 'https://huggingface.co/#{path_or_repo}' for available revisions." 202 | rescue HfHub::EntryNotFoundError 203 | false # File does not exist 204 | end 205 | end 206 | end 207 | end 208 | end 209 | end 210 | -------------------------------------------------------------------------------- /lib/transformers/utils/import_utils.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | module Transformers 16 | module Utils 17 | ENV_VARS_TRUE_VALUES = ["1", "ON", "YES", "TRUE"] 18 | 19 | def self.requires_backends(obj, backends) 20 | if !backends.is_a?(Array) 21 | backends = [backends] 22 | end 23 | 24 | name = obj.is_a?(Symbol) ? obj : obj.class.name 25 | 26 | checks = backends.map { |backend| BACKENDS_MAPPING.fetch(backend) } 27 | failed = checks.filter_map { |available, msg| format(msg, name) if !available.call } 28 | if failed.any? 29 | raise Error, failed.join("") 30 | end 31 | end 32 | 33 | def self.is_vision_available 34 | defined?(Vips) 35 | end 36 | 37 | VISION_IMPORT_ERROR = <<~MSG 38 | %s requires the `ruby-vips` gem 39 | MSG 40 | 41 | BACKENDS_MAPPING = { 42 | "vision" => [singleton_method(:is_vision_available), VISION_IMPORT_ERROR] 43 | } 44 | end 45 | end 46 | -------------------------------------------------------------------------------- /lib/transformers/utils/logging.rb: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Optuna, Hugging Face 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
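# A minimal usage sketch: verbosity can be set through the TRANSFORMERS_VERBOSITY
# environment variable ("debug", "info", "warning", "error", "critical") before
# the library is loaded, or by adjusting the logger directly at runtime:
#
#   Transformers.logger.level = Logger::DEBUG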
14 | 15 | module Transformers 16 | class << self 17 | attr_accessor :logger 18 | end 19 | 20 | # TODO add detail 21 | LOG_LEVELS = { 22 | "debug" => Logger::DEBUG, 23 | "info" => Logger::INFO, 24 | "warning" => Logger::WARN, 25 | "error" => Logger::ERROR, 26 | "critical" => Logger::FATAL 27 | } 28 | 29 | DEFAULT_LOG_LEVEL = Logger::WARN 30 | 31 | def self._get_default_logging_level 32 | env_level_str = ENV["TRANSFORMERS_VERBOSITY"] 33 | if env_level_str 34 | if LOG_LEVELS.include?(env_level_str) 35 | return LOG_LEVELS[env_level_str] 36 | else 37 | warn( 38 | "Unknown option TRANSFORMERS_VERBOSITY=#{env_level_str}, " + 39 | "has to be one of: #{LOG_LEVELS.keys.join(", ")}" 40 | ) 41 | end 42 | end 43 | DEFAULT_LOG_LEVEL 44 | end 45 | 46 | self.logger = begin 47 | logger = Logger.new(STDERR) 48 | logger.level = _get_default_logging_level 49 | logger.formatter = proc { |severity, datetime, progname, msg| "#{msg}\n" } 50 | logger 51 | end 52 | end 53 | -------------------------------------------------------------------------------- /lib/transformers/version.rb: -------------------------------------------------------------------------------- 1 | module Transformers 2 | VERSION = "0.1.6" 3 | end 4 | -------------------------------------------------------------------------------- /licenses/LICENSE-huggingface-hub.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /licenses/LICENSE-sentence-transformers.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2019 Nils Reimers 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /licenses/NOTICE-sentence-transformers.txt: -------------------------------------------------------------------------------- 1 | ------------------------------------------------------------------------------- 2 | Copyright 2019 3 | Ubiquitous Knowledge Processing (UKP) Lab 4 | Technische Universität Darmstadt 5 | ------------------------------------------------------------------------------- -------------------------------------------------------------------------------- /test/model_test.rb: -------------------------------------------------------------------------------- 1 | require_relative "test_helper" 2 | 3 | class ModelTest < Minitest::Test 4 | def setup 5 | skip if ci? 6 | end 7 | 8 | # https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2 9 | def test_all_mini_lm 10 | sentences = ["This is an example sentence", "Each sentence is converted"] 11 | 12 | model = Transformers.pipeline("embedding", "sentence-transformers/all-MiniLM-L6-v2") 13 | embeddings = model.(sentences) 14 | 15 | assert_elements_in_delta [0.067657, 0.063496, 0.048713], embeddings[0][..2] 16 | assert_elements_in_delta [0.086439, 0.10276, 0.0053946], embeddings[1][..2] 17 | end 18 | 19 | # https://huggingface.co/sentence-transformers/multi-qa-MiniLM-L6-cos-v1 20 | def test_multi_qa_minilm 21 | query = "How many people live in London?" 
22 | docs = ["Around 9 Million people live in London", "London is known for its financial district"] 23 | 24 | model = Transformers.pipeline("embedding", "sentence-transformers/multi-qa-MiniLM-L6-cos-v1") 25 | query_embedding = model.(query) 26 | doc_embeddings = model.(docs) 27 | scores = doc_embeddings.map { |e| e.zip(query_embedding).sum { |d, q| d * q } } 28 | doc_score_pairs = docs.zip(scores).sort_by { |d, s| -s } 29 | 30 | assert_equal "Around 9 Million people live in London", doc_score_pairs[0][0] 31 | assert_in_delta 0.9156, doc_score_pairs[0][1] 32 | assert_equal "London is known for its financial district", doc_score_pairs[1][0] 33 | assert_in_delta 0.4948, doc_score_pairs[1][1] 34 | end 35 | 36 | # https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L6-v2 37 | def test_paraphrase_minilm 38 | sentences = ["This is an example sentence", "Each sentence is converted"] 39 | 40 | model = Transformers.pipeline("embedding", "sentence-transformers/paraphrase-MiniLM-L6-v2") 41 | embeddings = model.(sentences, normalize: false) 42 | 43 | assert_elements_in_delta [0.067359, 0.783935, 0.270018], embeddings[0][..2] 44 | assert_elements_in_delta [0.122117, 0.670228, 0.317166], embeddings[1][..2] 45 | end 46 | 47 | # https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1 48 | def test_mxbai_embed 49 | query_prefix = "Represent this sentence for searching relevant passages: " 50 | 51 | input = [ 52 | query_prefix + "puppy", 53 | "The dog is barking", 54 | "The cat is purring" 55 | ] 56 | 57 | model = Transformers.pipeline("embedding", "mixedbread-ai/mxbai-embed-large-v1") 58 | embeddings = model.(input, pooling: "cls", normalize: false) 59 | 60 | assert_elements_in_delta [-0.00624076, 0.12864432, 0.5248165], embeddings[0][..2] 61 | assert_elements_in_delta [-0.61227727, 1.4060247, -0.04079155], embeddings[-1][..2] 62 | end 63 | 64 | # https://huggingface.co/thenlper/gte-small 65 | def test_gte_small 66 | sentences = ["That is a happy person", "That is a very happy person"] 67 | 68 | model = Transformers.pipeline("embedding", "thenlper/gte-small") 69 | embeddings = model.(sentences) 70 | 71 | assert_elements_in_delta [-0.05316979, 0.01044252, 0.06194701], embeddings[0][..2] 72 | assert_elements_in_delta [-0.05246907, 0.03752426, 0.07344585], embeddings[-1][..2] 73 | end 74 | 75 | # https://huggingface.co/intfloat/e5-base-v2 76 | def test_e5_base 77 | doc_prefix = "passage: " 78 | query_prefix = "query: " 79 | 80 | input = [ 81 | doc_prefix + "Ruby is a programming language created by Matz", 82 | query_prefix + "Ruby creator" 83 | ] 84 | 85 | model = Transformers.pipeline("embedding", "intfloat/e5-base-v2") 86 | embeddings = model.(input) 87 | 88 | assert_elements_in_delta [-0.00596662, -0.03730119, -0.0703470], embeddings[0][..2] 89 | assert_elements_in_delta [0.00298353, -0.04421991, -0.0591884], embeddings[-1][..2] 90 | end 91 | 92 | # https://huggingface.co/BAAI/bge-base-en-v1.5 93 | def test_bge_base 94 | query_prefix = "Represent this sentence for searching relevant passages: " 95 | 96 | input = [ 97 | "The dog is barking", 98 | "The cat is purring", 99 | query_prefix + "puppy" 100 | ] 101 | 102 | model = Transformers.pipeline("embedding", "BAAI/bge-base-en-v1.5") 103 | embeddings = model.(input) 104 | 105 | assert_elements_in_delta [-0.07482512, -0.0770234, 0.03398684], embeddings[1][..2] 106 | assert_elements_in_delta [0.00029264, -0.0619305, -0.06199387], embeddings[-1][..2] 107 | end 108 | 109 | # https://huggingface.co/Snowflake/snowflake-arctic-embed-m-v1.5 110 | def 
test_snowflake_arctic_embed 111 | query_prefix = "Represent this sentence for searching relevant passages: " 112 | 113 | input = [ 114 | "The dog is barking", 115 | "The cat is purring", 116 | query_prefix + "puppy" 117 | ] 118 | 119 | model = Transformers.pipeline("embedding", "Snowflake/snowflake-arctic-embed-m-v1.5") 120 | embeddings = model.(input, pooling: "cls") 121 | 122 | assert_elements_in_delta [0.03239886, 0.0009998, 0.08401278], embeddings[0][..2] 123 | assert_elements_in_delta [-0.02530634, -0.02715422, 0.01218867], embeddings[-1][..2] 124 | end 125 | 126 | # https://huggingface.co/sentence-transformers/all-mpnet-base-v2 127 | def test_all_mpnet 128 | sentences = ["This is an example sentence", "Each sentence is converted"] 129 | 130 | model = Transformers.pipeline("embedding", "sentence-transformers/all-mpnet-base-v2") 131 | embeddings = model.(sentences) 132 | 133 | assert_elements_in_delta [0.02250263, -0.07829167, -0.02303071], embeddings[0][..2] 134 | assert_elements_in_delta [0.04170236, 0.00109747, -0.01553415], embeddings[1][..2] 135 | end 136 | 137 | # https://huggingface.co/opensearch-project/opensearch-neural-sparse-encoding-v1 138 | def test_opensearch 139 | docs = ["The dog is barking", "The cat is purring", "The bear is growling"] 140 | 141 | model_id = "opensearch-project/opensearch-neural-sparse-encoding-v1" 142 | model = Transformers::AutoModelForMaskedLM.from_pretrained(model_id) 143 | tokenizer = Transformers::AutoTokenizer.from_pretrained(model_id) 144 | special_token_ids = tokenizer.special_tokens_map.map { |_, token| tokenizer.vocab[token] } 145 | 146 | feature = tokenizer.(docs, padding: true, truncation: true, return_tensors: "pt", return_token_type_ids: false) 147 | output = model.(**feature)[0] 148 | 149 | values, _ = Torch.max(output * feature[:attention_mask].unsqueeze(-1), dim: 1) 150 | values = Torch.log(1 + Torch.relu(values)) 151 | values[0.., special_token_ids] = 0 152 | embeddings = values.to_a 153 | 154 | assert_equal 74, embeddings[0].count { |v| v != 0 } 155 | assert_equal 77, embeddings[1].count { |v| v != 0 } 156 | assert_equal 102, embeddings[2].count { |v| v != 0 } 157 | end 158 | 159 | # https://huggingface.co/mixedbread-ai/mxbai-rerank-base-v1 160 | def test_mxbai_rerank 161 | query = "How many people live in London?" 162 | docs = ["Around 9 Million people live in London", "London is known for its financial district"] 163 | 164 | model = Transformers.pipeline("reranking", "mixedbread-ai/mxbai-rerank-base-v1") 165 | result = model.(query, docs) 166 | 167 | assert_equal 0, result[0][:index] 168 | assert_in_delta 0.984, result[0][:score] 169 | 170 | assert_equal 1, result[1][:index] 171 | assert_in_delta 0.139, result[1][:score] 172 | end 173 | 174 | # https://huggingface.co/BAAI/bge-reranker-base 175 | def test_bge_reranker 176 | query = "How many people live in London?" 
177 | docs = ["Around 9 Million people live in London", "London is known for its financial district"] 178 | 179 | model = Transformers.pipeline("reranking", "BAAI/bge-reranker-base") 180 | result = model.(query, docs) 181 | 182 | assert_equal 0, result[0][:index] 183 | assert_in_delta 0.996, result[0][:score] 184 | 185 | assert_equal 1, result[1][:index] 186 | assert_in_delta 0.000158, result[1][:score], 0.000001 187 | end 188 | end 189 | -------------------------------------------------------------------------------- /test/pipeline_test.rb: -------------------------------------------------------------------------------- 1 | require_relative "test_helper" 2 | 3 | class PipelineTest < Minitest::Test 4 | def test_ner 5 | ner = Transformers.pipeline("ner") 6 | result = ner.("Ruby is a programming language created by Matz") 7 | assert_equal 3, result.size 8 | assert_equal "I-MISC", result[0][:entity] 9 | assert_in_delta 0.96, result[0][:score] 10 | assert_equal 1, result[0][:index] 11 | assert_equal "Ruby", result[0][:word] 12 | assert_equal 0, result[0][:start] 13 | assert_equal 4, result[0][:end] 14 | end 15 | 16 | def test_ner_aggregation_strategy 17 | ner = Transformers.pipeline("ner", aggregation_strategy: "simple") 18 | result = ner.("Ruby is a programming language created by Matz") 19 | assert_equal 2, result.size 20 | 21 | assert_equal "MISC", result[0][:entity_group] 22 | assert_in_delta 0.9608, result[0][:score] 23 | assert_equal "Ruby", result[0][:word] 24 | assert_equal 0, result[0][:start] 25 | assert_equal 4, result[0][:end] 26 | 27 | assert_equal "PER", result[1][:entity_group] 28 | assert_in_delta 0.9496, result[1][:score] 29 | assert_equal "Matz", result[1][:word] 30 | assert_equal 42, result[1][:start] 31 | assert_equal 46, result[1][:end] 32 | end 33 | 34 | def test_sentiment_analysis 35 | classifier = Transformers.pipeline("sentiment-analysis") 36 | result = classifier.("We are very happy to show you the 🤗 Transformers library.") 37 | assert_equal "POSITIVE", result[:label] 38 | assert_in_delta 0.9998, result[:score] 39 | 40 | result = classifier.(["We are very happy to show you the 🤗 Transformers library.", "We hope you don't hate it."]) 41 | assert_equal "POSITIVE", result[0][:label] 42 | assert_in_delta 0.9998, result[0][:score] 43 | assert_equal "NEGATIVE", result[1][:label] 44 | assert_in_delta 0.5309, result[1][:score] 45 | end 46 | 47 | def test_question_answering 48 | qa = Transformers.pipeline("question-answering") 49 | result = qa.(question: "Who invented Ruby?", context: "Ruby is a programming language created by Matz") 50 | assert_in_delta 0.998, result[:score] 51 | assert_equal 42, result[:start] 52 | assert_equal 46, result[:end] 53 | assert_equal "Matz", result[:answer] 54 | 55 | result = qa.("Who invented Ruby?", "Ruby is a programming language created by Matz") 56 | assert_equal "Matz", result[:answer] 57 | end 58 | 59 | def test_feature_extraction 60 | fe = Transformers.pipeline("feature-extraction") 61 | result = fe.("We are very happy to show you the 🤗 Transformers library.") 62 | assert_in_delta 0.454, result[0][0][0] 63 | end 64 | 65 | def test_embedding 66 | sentences = ["This is an example sentence", "Each sentence is converted"] 67 | embed = Transformers.pipeline("embedding") 68 | embeddings = embed.(sentences) 69 | assert_elements_in_delta [0.067657, 0.063496, 0.048713], embeddings[0][..2] 70 | assert_elements_in_delta [0.086439, 0.10276, 0.0053946], embeddings[1][..2] 71 | end 72 | 73 | def test_reranking 74 | query = "How many people live in 
London?" 75 | docs = ["Around 9 Million people live in London", "London is known for its financial district"] 76 | rerank = Transformers.pipeline("reranking") 77 | result = rerank.(query, docs) 78 | assert_equal 2, result.size 79 | assert_equal 0, result[0][:index] 80 | assert_in_delta 0.984, result[0][:score] 81 | assert_equal 1, result[1][:index] 82 | assert_in_delta 0.139, result[1][:score] 83 | end 84 | 85 | def test_image_classification 86 | classifier = Transformers.pipeline("image-classification") 87 | result = classifier.("test/support/pipeline-cat-chonk.jpeg") 88 | assert_equal "lynx, catamount", result[0][:label] 89 | assert_in_delta 0.433, result[0][:score], 0.01 90 | assert_equal "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor", result[1][:label] 91 | assert_in_delta 0.035, result[1][:score], 0.01 92 | end 93 | 94 | def test_image_feature_extraction 95 | fe = Transformers.pipeline("image-feature-extraction") 96 | result = fe.("test/support/pipeline-cat-chonk.jpeg") 97 | assert_in_delta 0.868, result[0][0][0], 0.01 98 | end 99 | 100 | def test_device 101 | skip unless mac? 102 | 103 | sentences = ["This is an example sentence", "Each sentence is converted"] 104 | embed = Transformers.pipeline("embedding", device: "mps") 105 | embeddings = embed.(sentences) 106 | assert_elements_in_delta [0.067657, 0.063496, 0.048713], embeddings[0][..2] 107 | assert_elements_in_delta [0.086439, 0.10276, 0.0053946], embeddings[1][..2] 108 | end 109 | 110 | def test_pipeline_input_works_with_more_than_ten 111 | embedding = Transformers.pipeline("embedding") 112 | 11.times do 113 | result = embedding.("Ruby is a programming language created by Matz") 114 | assert_instance_of(Array, result) 115 | end 116 | end 117 | end 118 | -------------------------------------------------------------------------------- /test/test_helper.rb: -------------------------------------------------------------------------------- 1 | require "bundler/setup" 2 | Bundler.require(:default) 3 | require "minitest/autorun" 4 | 5 | unless ENV["TRANSFORMERS_VERBOSITY"] 6 | Transformers.logger.level = Logger::ERROR 7 | end 8 | 9 | Transformers.fast_init = true 10 | 11 | class Minitest::Test 12 | def assert_elements_in_delta(expected, actual) 13 | assert_equal expected.size, actual.size 14 | expected.zip(actual) do |exp, act| 15 | assert_in_delta exp, act 16 | end 17 | end 18 | 19 | def ci? 20 | ENV["CI"] 21 | end 22 | 23 | def mac? 
24 | RbConfig::CONFIG["host_os"] =~ /darwin/i 25 | end 26 | end 27 | -------------------------------------------------------------------------------- /test/tokenizer_test.rb: -------------------------------------------------------------------------------- 1 | require_relative "test_helper" 2 | 3 | class TokenizerTest < Minitest::Test 4 | def test_auto_tokenizer 5 | model_name = "nlptown/bert-base-multilingual-uncased-sentiment" 6 | tokenizer = Transformers::AutoTokenizer.from_pretrained(model_name) 7 | 8 | encoding = tokenizer.("We are very happy to show you the 🤗 Transformers library.") 9 | assert_equal [101, 11312, 10320, 12495, 19308, 10114, 11391, 10855, 10103, 100, 58263, 13299, 119, 102], encoding[:input_ids] 10 | assert_equal [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], encoding[:attention_mask] 11 | end 12 | end 13 | -------------------------------------------------------------------------------- /transformers-rb.gemspec: -------------------------------------------------------------------------------- 1 | require_relative "lib/transformers/version" 2 | 3 | Gem::Specification.new do |spec| 4 | spec.name = "transformers-rb" 5 | spec.version = Transformers::VERSION 6 | spec.summary = "State-of-the-art transformers for Ruby" 7 | spec.homepage = "https://github.com/ankane/transformers-ruby" 8 | spec.license = "Apache-2.0" 9 | 10 | spec.author = "Andrew Kane" 11 | spec.email = "andrew@ankane.org" 12 | 13 | spec.files = Dir["*.{md,txt}", "{lib,licenses}/**/*"] 14 | spec.require_path = "lib" 15 | 16 | spec.required_ruby_version = ">= 3.1" 17 | 18 | spec.add_dependency "logger" 19 | spec.add_dependency "numo-narray", ">= 0.9.2" 20 | spec.add_dependency "safetensors", ">= 0.1.1" 21 | spec.add_dependency "tokenizers", ">= 0.5.3" 22 | spec.add_dependency "torch-rb", ">= 0.17.1" 23 | end 24 | --------------------------------------------------------------------------------
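Usage sketch: a minimal example tying together the logging setup in lib/transformers/utils/logging.rb and the embedding pipeline exercised in test/model_test.rb above. The model name and the sample values are taken from those tests and are illustrative only, not guaranteed outputs for other models or library versions.

    require "transformers-rb"

    # Reduce log output; the TRANSFORMERS_VERBOSITY environment variable
    # (debug, info, warning, error, critical) is also read at load time by
    # Transformers._get_default_logging_level.
    Transformers.logger.level = Logger::ERROR

    # Embedding pipeline, as in test/model_test.rb (the model choice here is an assumption).
    sentences = ["This is an example sentence", "Each sentence is converted"]
    model = Transformers.pipeline("embedding", "sentence-transformers/all-MiniLM-L6-v2")
    embeddings = model.(sentences)
    p embeddings[0][..2] # the tests expect values near [0.067657, 0.063496, 0.048713]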