├── test ├── support │ ├── data.txt │ └── data.csv ├── test_helper.rb └── data_pipes_test.rb ├── lib ├── torchdata │ ├── version.rb │ └── data_pipes │ │ └── iter │ │ └── util │ │ ├── csv_parser.rb │ │ └── random_splitter.rb └── torchdata.rb ├── Gemfile ├── .gitignore ├── CHANGELOG.md ├── Rakefile ├── torchdata.gemspec ├── .github └── workflows │ └── build.yml ├── README.md └── LICENSE.txt /test/support/data.txt: -------------------------------------------------------------------------------- 1 | hello 2 | -------------------------------------------------------------------------------- /test/support/data.csv: -------------------------------------------------------------------------------- 1 | 1,one 2 | 2,two 3 | 3,three 4 | 4,four 5 | -------------------------------------------------------------------------------- /lib/torchdata/version.rb: -------------------------------------------------------------------------------- 1 | module TorchData 2 | VERSION = "0.0.3" 3 | end 4 | -------------------------------------------------------------------------------- /Gemfile: -------------------------------------------------------------------------------- 1 | source "https://rubygems.org" 2 | 3 | gemspec 4 | 5 | gem "rake" 6 | gem "minitest" 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /.bundle/ 2 | /.yardoc 3 | /_yardoc/ 4 | /coverage/ 5 | /doc/ 6 | /pkg/ 7 | /spec/reports/ 8 | /tmp/ 9 | *.lock 10 | -------------------------------------------------------------------------------- /test/test_helper.rb: -------------------------------------------------------------------------------- 1 | require "bundler/setup" 2 | Bundler.require(:default) 3 | require "minitest/autorun" 4 | require "minitest/pride" 5 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | ## 0.0.3 (2024-12-29) 2 | 3 | - Added support for Ruby 3.4 4 | 5 | ## 0.0.2 (2024-08-02) 6 | 7 | - Dropped support for Ruby < 3.1 8 | 9 | ## 0.0.1 (2023-01-30) 10 | 11 | - First release 12 | -------------------------------------------------------------------------------- /Rakefile: -------------------------------------------------------------------------------- 1 | require "bundler/gem_tasks" 2 | require "rake/testtask" 3 | 4 | Rake::TestTask.new(:test) do |t| 5 | t.libs << "test" 6 | t.pattern = "test/**/*_test.rb" 7 | end 8 | 9 | task default: :test 10 | -------------------------------------------------------------------------------- /lib/torchdata.rb: -------------------------------------------------------------------------------- 1 | # dependencies 2 | require "torch" 3 | 4 | # stdlib 5 | require "csv" 6 | 7 | # modules 8 | require_relative "torchdata/version" 9 | 10 | module TorchData 11 | class Error < StandardError; end 12 | 13 | module DataPipes 14 | module Iter 15 | IterDataPipe = Torch::Utils::Data::DataPipes::IterDataPipe 16 | FileLister = Torch::Utils::Data::DataPipes::Iter::FileLister 17 | FileOpener = Torch::Utils::Data::DataPipes::Iter::FileOpener 18 | end 19 | end 20 | end 21 | 22 | require_relative "torchdata/data_pipes/iter/util/csv_parser" 23 | require_relative "torchdata/data_pipes/iter/util/random_splitter" 24 | -------------------------------------------------------------------------------- /torchdata.gemspec: -------------------------------------------------------------------------------- 1 | require_relative "lib/torchdata/version" 2 | 3 | Gem::Specification.new do |spec| 4 | spec.name = "torchdata" 5 | spec.version = TorchData::VERSION 6 | spec.summary = "Composable data loading for Ruby" 7 | spec.homepage = "https://github.com/ankane/torchdata-ruby" 8 | spec.license = "BSD-3-Clause" 9 | 10 | spec.author = "Andrew Kane" 11 | spec.email = "andrew@ankane.org" 12 | 13 | spec.files = Dir["*.{md,txt}", "{lib}/**/*"] 14 | spec.require_path = "lib" 15 | 16 | spec.required_ruby_version = ">= 3.1" 17 | 18 | spec.add_dependency "csv" 19 | spec.add_dependency "torch-rb", ">= 0.13" 20 | end 21 | -------------------------------------------------------------------------------- /test/data_pipes_test.rb: -------------------------------------------------------------------------------- 1 | require_relative "test_helper" 2 | 3 | class DataPipesTest < Minitest::Test 4 | # https://pytorch.org/data/main/tutorial.html 5 | def test_works 6 | folder = "test/support" 7 | datapipe = TorchData::DataPipes::Iter::FileLister.new([folder]).filter { |filename| filename.end_with?(".csv") } 8 | datapipe = TorchData::DataPipes::Iter::FileOpener.new(datapipe, mode: "rt") 9 | datapipe = datapipe.parse_csv(delimiter: ",") 10 | train, valid = datapipe.random_split(total_length: 4, weights: {train: 0.5, valid: 0.5}, seed: 0) 11 | expected = [ 12 | ["1", "one"], 13 | ["2", "two"], 14 | ["3", "three"], 15 | ["4", "four"] 16 | ] 17 | assert_equal expected, (train.to_a + valid.to_a).sort 18 | assert_equal 2, train.count 19 | assert_equal 2, valid.count 20 | end 21 | end 22 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | on: [push, pull_request] 3 | jobs: 4 | build: 5 | runs-on: ubuntu-latest 6 | env: 7 | BUNDLE_BUILD__TORCH___RB: "--with-torch-dir=/home/runner/libtorch" 8 | LIBTORCH_VERSION: 2.8.0 9 | steps: 10 | - uses: actions/checkout@v4 11 | - uses: actions/cache@v4 12 | with: 13 | path: ~/libtorch 14 | key: libtorch-${{ env.LIBTORCH_VERSION }} 15 | id: cache-libtorch 16 | - name: Download LibTorch 17 | if: steps.cache-libtorch.outputs.cache-hit != 'true' 18 | run: | 19 | cd ~ 20 | wget -q -O libtorch.zip https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-$LIBTORCH_VERSION%2Bcpu.zip 21 | unzip -q libtorch.zip 22 | - uses: ruby/setup-ruby@v1 23 | with: 24 | ruby-version: 3.4 25 | bundler-cache: true 26 | - run: bundle exec rake test 27 | -------------------------------------------------------------------------------- /lib/torchdata/data_pipes/iter/util/csv_parser.rb: -------------------------------------------------------------------------------- 1 | module TorchData 2 | module DataPipes 3 | module Iter 4 | module Util 5 | class CSVParser < IterDataPipe 6 | functional_datapipe :parse_csv 7 | 8 | def initialize(source_datapipe, delimiter: ",") 9 | @source_datapipe = source_datapipe 10 | @helper = PlainTextReaderHelper.new 11 | @fmtparams = {col_sep: delimiter} 12 | end 13 | 14 | def each(&block) 15 | @source_datapipe.each do |path, file| 16 | stream = @helper.skip_lines(file) 17 | stream = @helper.decode(stream) 18 | stream = CSV.parse(stream, **@fmtparams) 19 | stream = @helper.as_tuple(stream) 20 | @helper.return_path(stream, path: path).each(&block) 21 | end 22 | end 23 | end 24 | 25 | class PlainTextReaderHelper 26 | def skip_lines(file) 27 | file 28 | end 29 | 30 | def decode(stream) 31 | stream 32 | end 33 | 34 | def return_path(stream, path: nil) 35 | stream 36 | end 37 | 38 | def as_tuple(stream) 39 | stream 40 | end 41 | end 42 | end 43 | end 44 | end 45 | end 46 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TorchData Ruby 2 | 3 | Composable data loading for Ruby 4 | 5 | [![Build Status](https://github.com/ankane/torchdata-ruby/actions/workflows/build.yml/badge.svg)](https://github.com/ankane/torchdata-ruby/actions) 6 | 7 | ## Installation 8 | 9 | Add this line to your application’s Gemfile: 10 | 11 | ```ruby 12 | gem "torchdata" 13 | ``` 14 | 15 | ## Getting Started 16 | 17 | This library follows the [Python API](https://pytorch.org/data/). Many methods and options are missing at the moment. PRs welcome! 18 | 19 | ```ruby 20 | folder = "path/to/csv/folder" 21 | datapipe = TorchData::DataPipes::Iter::FileLister.new([folder]).filter { |filename| filename.end_with?(".csv") } 22 | datapipe = TorchData::DataPipes::Iter::FileOpener.new(datapipe, mode: "rt") 23 | datapipe = datapipe.parse_csv(delimiter: ",") 24 | train, valid = datapipe.random_split(total_length: 10000, weights: {train: 0.5, valid: 0.5}, seed: 0) 25 | 26 | train.each do |x| 27 | # code 28 | end 29 | 30 | valid.each do |y| 31 | # code 32 | end 33 | ``` 34 | 35 | ## History 36 | 37 | View the [changelog](CHANGELOG.md) 38 | 39 | ## Contributing 40 | 41 | Everyone is encouraged to help improve this project. Here are a few ways you can help: 42 | 43 | - [Report bugs](https://github.com/ankane/torchdata-ruby/issues) 44 | - Fix bugs and [submit pull requests](https://github.com/ankane/torchdata-ruby/pulls) 45 | - Write, clarify, or fix documentation 46 | - Suggest or add new features 47 | 48 | To get started with development: 49 | 50 | ```sh 51 | git clone https://github.com/ankane/torchdata-ruby.git 52 | cd torchdata-ruby 53 | bundle install 54 | bundle exec rake test 55 | ``` 56 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2021-present, Facebook, Inc. 4 | Copyright (c) 2023, Andrew Kane. 5 | All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived from 19 | this software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 25 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 26 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 27 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 28 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 29 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /lib/torchdata/data_pipes/iter/util/random_splitter.rb: -------------------------------------------------------------------------------- 1 | module TorchData 2 | module DataPipes 3 | module Iter 4 | module Util 5 | class RandomSplitter < IterDataPipe 6 | functional_datapipe :random_split 7 | 8 | def self.new(source_datapipe, weights:, seed:, total_length: nil, target: nil) 9 | if total_length.nil? 10 | begin 11 | total_length = source_datapipe.length 12 | rescue NoMethodError 13 | raise TypeError, "RandomSplitter needs `total_length`, but it is unable to infer it from the `source_datapipe`: #{source_datapipe}." 14 | end 15 | end 16 | 17 | container = InternalRandomSplitterIterDataPipe.new(source_datapipe, total_length, weights, seed) 18 | 19 | if target.nil? 20 | weights.map { |k, _| SplitterIterator.new(container, k) } 21 | else 22 | raise "todo" 23 | end 24 | end 25 | end 26 | 27 | class InternalRandomSplitterIterDataPipe < IterDataPipe 28 | attr_reader :source_datapipe 29 | 30 | def initialize(source_datapipe, total_length, weights, seed) 31 | @source_datapipe = source_datapipe 32 | @total_length = total_length 33 | @remaining_length = @total_length 34 | @seed = seed 35 | @keys = weights.keys 36 | @key_to_index = @keys.map.with_index.to_h 37 | @norm_weights = self.class.normalize_weights(@keys.map { |k| weights[k] }, total_length) 38 | @weights = @norm_weights.dup 39 | @rng = Random.new(@seed) 40 | @lengths = [] 41 | end 42 | 43 | def draw 44 | selected_key = choices(@rng, @keys, @weights) 45 | index = @key_to_index[selected_key] 46 | @weights[index] -= 1 47 | @remaining_length -= 1 48 | if @weights[index] < 0 49 | @weights[index] = 0 50 | @weights = self.class.normalize_weights(@weights, @remaining_length) 51 | end 52 | selected_key 53 | end 54 | 55 | def self.normalize_weights(weights, total_length) 56 | total_weight = weights.sum 57 | weights.map { |w| w.to_f * total_length / total_weight } 58 | end 59 | 60 | def reset 61 | @rng = Random.new(@seed) 62 | @weights = @norm_weights.dup 63 | @remaining_length = @total_length 64 | end 65 | 66 | def override_seed(seed) 67 | @seed = seed 68 | self 69 | end 70 | 71 | def get_length(target) 72 | raise "todo" 73 | end 74 | 75 | private 76 | 77 | def choices(rng, keys, weights) 78 | total = weights.sum 79 | x = rng.rand * total 80 | weights.each_with_index do |w, i| 81 | return keys[i] if x < w 82 | x -= w 83 | end 84 | keys[-1] 85 | end 86 | end 87 | 88 | class SplitterIterator < IterDataPipe 89 | def initialize(main_datapipe, target) 90 | @main_datapipe = main_datapipe 91 | @target = target 92 | end 93 | 94 | def each 95 | @main_datapipe.reset 96 | @main_datapipe.source_datapipe.each do |sample| 97 | if @main_datapipe.draw == @target 98 | yield sample 99 | end 100 | end 101 | end 102 | 103 | def override_seed(seed) 104 | @main_datapipe.override_seed(seed) 105 | end 106 | 107 | def length 108 | @main_datapipe.get_length(@target) 109 | end 110 | end 111 | end 112 | end 113 | end 114 | end 115 | --------------------------------------------------------------------------------