├── .Rbuildignore ├── .github │   ├── .gitignore │   └── workflows │       ├── R-CMD-check.yaml │       └── pkgdown.yaml ├── .gitignore ├── DESCRIPTION ├── LICENSE ├── LICENSE.md ├── NAMESPACE ├── NEWS.md ├── R │   ├── bank-marketing.R │   ├── bird-species.R │   ├── cityscapes-pix2pix.R │   ├── dogs-vs-cats.R │   ├── guess-the-correlation.R │   ├── imdb.R │   ├── oxford-flowers-dataset.R │   ├── oxford-pet-dataset.R │   ├── spam-dataset.R │   └── utils.R ├── README.md ├── _pkgdown.yml ├── cran-comments.md ├── man │   ├── bank_marketing_dataset.Rd │   ├── bird_species_dataset.Rd │   ├── cityscapes_pix2pix_dataset.Rd │   ├── dogs_vs_cats_dataset.Rd │   ├── guess_the_correlation_dataset.Rd │   ├── imdb_dataset.Rd │   ├── oxford_flowers102_dataset.Rd │   ├── oxford_pet_dataset.Rd │   └── spam_dataset.Rd ├── tests │   ├── testthat.R │   └── testthat │       ├── helper-torchdatasets.R │       ├── test-bank-marketing.R │       ├── test-bird-species.R │       ├── test-cityscapes-pix2pix.R │       ├── test-dogs-vs-cats.R │       ├── test-guess-the-correlation.R │       ├── test-imdb.R │       ├── test-oxford-flowers-dataset.R │       ├── test-oxford-pet-dataset.R │       └── test-spam-dataset.R └── torchdatasets.Rproj /.Rbuildignore: -------------------------------------------------------------------------------- 1 | ^torchdatasets\.Rproj$ 2 | ^\.Rproj\.user$ 3 | ^\.github$ 4 | ^LICENSE\.md$ 5 | ^_pkgdown\.yml$ 6 | ^docs$ 7 | ^pkgdown$ 8 | ^cran-comments\.md$ 9 | ^gtc$ 10 | ^examples$ 11 | ^CRAN-RELEASE$ 12 | ^imdb 13 | ^CRAN-SUBMISSION$ 14 | -------------------------------------------------------------------------------- /.github/.gitignore: -------------------------------------------------------------------------------- 1 | *.html 2 | -------------------------------------------------------------------------------- /.github/workflows/R-CMD-check.yaml: -------------------------------------------------------------------------------- 1 | # For help debugging build failures open an issue on the RStudio community with the 'github-actions' tag.
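# The check runs on pushes to main, on pull requests, nightly at 01:00 UTC (see the cron trigger below), and on manual dispatch.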
2 | # https://community.rstudio.com/new-topic?category=Package%20development&tags=github-actions 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | schedule: 9 | - cron: "0 1 * * *" 10 | workflow_dispatch: 11 | 12 | name: R-CMD-check 13 | 14 | jobs: 15 | R-CMD-check: 16 | runs-on: ${{ matrix.config.os }} 17 | 18 | name: ${{ matrix.config.os }} (${{ matrix.config.r }}) 19 | 20 | strategy: 21 | fail-fast: false 22 | matrix: 23 | config: 24 | - {os: windows-latest, r: 'release'} 25 | - {os: macOS-latest, r: 'release'} 26 | - {os: ubuntu-20.04, r: 'release'} 27 | 28 | env: 29 | R_REMOTES_NO_ERRORS_FROM_WARNINGS: true 30 | TORCH_INSTALL: 1 31 | TORCH_TEST: 1 32 | 33 | steps: 34 | - uses: actions/checkout@v2 35 | 36 | - uses: r-lib/actions/setup-r@v2 37 | with: 38 | r-version: ${{ matrix.config.r }} 39 | 40 | - uses: r-lib/actions/setup-pandoc@v2 41 | 42 | - name: Query dependencies 43 | run: | 44 | install.packages('remotes') 45 | saveRDS(remotes::dev_package_deps(dependencies = TRUE), ".github/depends.Rds", version = 2) 46 | writeLines(sprintf("R-%i.%i", getRversion()$major, getRversion()$minor), ".github/R-version") 47 | shell: Rscript {0} 48 | 49 | - name: Install magick 50 | if: runner.os == 'Linux' 51 | run: | 52 | sudo apt-get install -y libmagick++-dev 53 | 54 | - name: Install system dependencies 55 | if: runner.os == 'Linux' 56 | run: | 57 | sudo rm -rf /usr/share/dotnet 58 | sudo rm -rf /opt/ghc 59 | while read -r cmd 60 | do 61 | eval sudo $cmd 62 | done < <(Rscript -e 'writeLines(remotes::system_requirements("ubuntu", "20.04"))') 63 | 64 | - name: Install dependencies 65 | run: | 66 | install.packages("processx") 67 | remotes::install_deps(dependencies = TRUE) 68 | remotes::install_cran("rcmdcheck") 69 | shell: Rscript {0} 70 | 71 | - name: Install Kaggle auth 72 | env: 73 | KAGGLE: ${{ secrets.KAGGLE }} 74 | run: | 75 | writeLines(c(Sys.getenv("KAGGLE"), "\n"), "tests/testthat/kaggle.json") 76 | shell: Rscript {0} 77 | 78 | - name: Check 79 | env: 80 | _R_CHECK_CRAN_INCOMING_REMOTE_: false 81 | run: rcmdcheck::rcmdcheck(args = c("--no-manual", "--as-cran"), error_on = "error", check_dir = "check") 82 | shell: Rscript {0} 83 | -------------------------------------------------------------------------------- /.github/workflows/pkgdown.yaml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | branches: 4 | - main 5 | - master 6 | 7 | name: pkgdown 8 | 9 | jobs: 10 | pkgdown: 11 | runs-on: macOS-latest 12 | env: 13 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 14 | TORCH_INSTALL: 1 15 | TORCH_TEST: 1 16 | steps: 17 | - uses: actions/checkout@v2 18 | 19 | - uses: r-lib/actions/setup-r@v2 20 | 21 | - uses: r-lib/actions/setup-pandoc@v2 22 | 23 | - name: Query dependencies 24 | run: | 25 | install.packages('remotes') 26 | saveRDS(remotes::dev_package_deps(dependencies = TRUE), ".github/depends.Rds", version = 2) 27 | writeLines(sprintf("R-%i.%i", getRversion()$major, getRversion()$minor), ".github/R-version") 28 | shell: Rscript {0} 29 | 30 | - name: Cache R packages 31 | uses: actions/cache@v2 32 | with: 33 | path: ${{ env.R_LIBS_USER }} 34 | key: ${{ runner.os }}-${{ hashFiles('.github/R-version') }}-1-${{ hashFiles('.github/depends.Rds') }} 35 | restore-keys: ${{ runner.os }}-${{ hashFiles('.github/R-version') }}-1- 36 | 37 | - name: Install dependencies 38 | run: | 39 | remotes::install_deps(dependencies = TRUE) 40 | install.packages("pkgdown", type = "binary") 41 | shell: Rscript {0} 42 | 43 | - name: Install package 44 | 
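# install the package itself so pkgdown can load it while rendering the reference pages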
run: R CMD INSTALL . 45 | 46 | - name: Deploy package 47 | run: | 48 | git config --local user.email "actions@github.com" 49 | git config --local user.name "GitHub Actions" 50 | Rscript -e 'pkgdown::deploy_to_branch(new_process = FALSE)' 51 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .Rproj.user 2 | .Rhistory 3 | .RData 4 | examples 5 | gtc 6 | docs 7 | examples 8 | gtc 9 | imdb 10 | dogs-vs-cats 11 | -------------------------------------------------------------------------------- /DESCRIPTION: -------------------------------------------------------------------------------- 1 | Package: torchdatasets 2 | Title: Ready to Use Extra Datasets for Torch 3 | Version: 0.3.1.9000 4 | Authors@R: 5 | c( 6 | person(given = "Daniel", 7 | family = "Falbel", 8 | role = c("aut", "cre"), 9 | email = "daniel@rstudio.com" 10 | ), 11 | person(family = "RStudio", role = c("cph")) 12 | ) 13 | Description: Provides datasets in a format that can be easily consumed by torch 'dataloaders'. 14 | Handles data downloading from multiple sources, caching and pre-processing so 15 | users can focus only on their model implementations. 16 | License: MIT + file LICENSE 17 | Encoding: UTF-8 18 | Roxygen: list(markdown = TRUE) 19 | RoxygenNote: 7.3.2 20 | Imports: 21 | torch (>= 0.5.0), 22 | fs, 23 | zip, 24 | pins, 25 | torchvision, 26 | stringr, 27 | withr, 28 | utils 29 | Suggests: 30 | testthat, 31 | readr, 32 | coro, 33 | tokenizers, 34 | R.matlab 35 | URL: https://mlverse.github.io/torchdatasets/, https://github.com/mlverse/torchdatasets 36 | BugReports: https://github.com/mlverse/torchdatasets/issues 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | YEAR: 2020 2 | COPYRIGHT HOLDER: RStudio, PBC 3 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | Copyright (c) 2020 Daniel Falbel 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /NAMESPACE: -------------------------------------------------------------------------------- 1 | # Generated by roxygen2: do not edit by hand 2 | 3 | export(bank_marketing_dataset) 4 | export(bird_species_dataset) 5 | export(cityscapes_pix2pix_dataset) 6 | export(dogs_vs_cats_dataset) 7 | export(guess_the_correlation_dataset) 8 | export(imdb_dataset) 9 | export(oxford_flowers102_dataset) 10 | export(oxford_pet_dataset) 11 | export(spam_dataset) 12 | importFrom(stringr,str_extract) 13 | importFrom(torch,dataset) 14 | importFrom(torchvision,base_loader) -------------------------------------------------------------------------------- /NEWS.md: -------------------------------------------------------------------------------- 1 | # torchdatasets (development version) 2 | 3 | # torchdatasets 0.3.1 4 | 5 | # torchdatasets 0.3.0 6 | 7 | - Added the Oxford Flowers dataset. (#32) 8 | 9 | # torchdatasets 0.2.0 10 | 11 | - Fixed issues with the dogs-vs-cats dataset. (#29) 12 | - Better handling of the bird species dataset. 13 | 14 | # torchdatasets 0.1.0 15 | 16 | * Added the bank marketing dataset (#10, #11, #12, #13, @dkgaraujo) 17 | * Added the cityscapes dataset used in the pix2pix paper. (#14) 18 | * Fixed a bug with a newer version of the bird species dataset. (#17) 19 | * Added the `oxford_pet_dataset()`. (#17) 20 | * Added a `maybe_download` util that is used to unify the download/extract code paths. (#22) 21 | * Added the IMDB dataset. (#23) 22 | 23 | # torchdatasets 0.0.1 24 | 25 | * Added a `NEWS.md` file to track changes to the package. -------------------------------------------------------------------------------- /R/bank-marketing.R: -------------------------------------------------------------------------------- 1 | #' Bank marketing dataset 2 | #' 3 | #' Prepares the Bank marketing dataset available on the UCI Machine Learning repository [here](https://archive.ics.uci.edu/ml/datasets/Bank+Marketing). 4 | #' The data is publicly available for download; there is no need to authenticate. 5 | #' Please cite the data as Moro et al., 2014: 6 | #' S. Moro, P. Cortez and P. Rita. A Data-Driven Approach to Predict the Success of Bank Telemarketing. Decision Support Systems, Elsevier, 62:22-31, June 2014. 7 | #' 8 | #' @param root path to the data location 9 | #' @param split string. Currently only 'train' is supported. 10 | #' @param indexes set of integers for subsampling (e.g. 1:41188) 11 | #' @param download whether to download or not 12 | #' @param with_call_duration whether the call duration should be included as a feature. Including it could lead to target leakage. Default: FALSE. 13 | #' 14 | #' @return A torch dataset that can be consumed with [torch::dataloader()].
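#' Each item is a named list with `x` (a 1d float tensor of features) and `y` (a scalar tensor with the binary response).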
15 | #' @examples 16 | #' if (torch::torch_is_installed() && FALSE) { 17 | #' bank_mkt <- bank_marketing_dataset("./data", download = TRUE) 18 | #' length(bank_mkt) 19 | #' } 20 | #' @export 21 | bank_marketing_dataset <- torch::dataset( 22 | "BankMarketing", 23 | initialize = function(root, split = "train", indexes = NULL, download = FALSE, with_call_duration = FALSE) { 24 | 25 | # download ---------------------------------------------------------- 26 | data_path <- maybe_download( 27 | url = "https://archive.ics.uci.edu/ml/machine-learning-databases/00222/bank-additional.zip", 28 | root = root, 29 | name = "bank-marketing", 30 | download = download, 31 | extract_fun = function(tmp, data_path) { 32 | unzip2(tmp, exdir = data_path) 33 | } 34 | ) 35 | 36 | if(tolower(split) != "train") { 37 | stop("The bank marketing dataset only has a `train` split") 38 | } 39 | 40 | self$.path <- file.path(data_path, "bank-additional") 41 | 42 | dataset <- read.csv2(fs::path(data_path, "bank-additional/bank-additional-full.csv")) 43 | 44 | if (!with_call_duration) 45 | dataset <- dataset[,-which(colnames(dataset)=="duration")] 46 | 47 | # one-hot encode unordered categorical features 48 | 49 | unordered_categorical_features <- c("default", 50 | "job", 51 | "marital", 52 | "housing", 53 | "loan", 54 | "contact", 55 | "month", 56 | "day_of_week", 57 | "poutcome") 58 | for (catvar in unordered_categorical_features) { 59 | tmp_df <- model.matrix(~ 0 + as.data.frame(dataset)[,catvar]) 60 | colnames(tmp_df) <- paste(catvar, levels(as.factor(as.data.frame(dataset)[,catvar])), sep = "_") 61 | dataset <- dataset[,-which(colnames(dataset)==catvar)] 62 | dataset <- cbind(dataset, tmp_df) 63 | } 64 | # integer-encode the only ordered categorical feature, education 65 | 66 | educ_factors <- c("unknown", 67 | "illiterate", 68 | "basic.4y", 69 | "basic.6y", 70 | "basic.9y", 71 | "high.school", 72 | "professional.course", 73 | "university.degree") 74 | educ <- factor(dataset[, "education"], ordered = TRUE, levels = educ_factors) 75 | dataset[, "education"] <- as.numeric(educ) 76 | dataset[, "y"] <- ifelse(dataset[, "y"] == "yes", 1, 0) 77 | 78 | # store the feature matrix and the target vector on the dataset instance 79 | 80 | self$features <- as.matrix(dataset[,-which(colnames(dataset)=="y")]) 81 | 82 | self$target <- dataset[,"y"] 83 | }, 84 | 85 | .getitem = function(index) { 86 | 87 | force(index) 88 | 89 | x <- self$features[index, ] 90 | y <- self$target[index] 91 | 92 | x <- torch::torch_tensor(as.numeric(unlist(x))) 93 | y <- torch::torch_scalar_tensor(y) 94 | 95 | return(list(x = x, y = y)) 96 | }, 97 | 98 | .length = function() { 99 | nrow(self$features) 100 | } 101 | ) -------------------------------------------------------------------------------- /R/bird-species.R: -------------------------------------------------------------------------------- 1 | 2 | #' Bird species dataset 3 | #' 4 | #' 5 | #' Downloads and prepares the 450 bird species dataset found on Kaggle. 6 | #' The dataset description, license, etc. can be found [here](https://www.kaggle.com/datasets/gpiosenka/100-bird-species). 7 | #' 8 | #' 9 | #' @param root path to the data location 10 | #' @param split string. 'train', 'test' or 'valid' 11 | #' @param download whether to download or not 12 | #' @param ... other arguments passed to [torchvision::image_folder_dataset()]. 13 | #' 14 | #' @return A [torch::dataset()] ready to be used with dataloaders.
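#' Items are `(x, y)` pairs as produced by [torchvision::image_folder_dataset()], where `x` is the image and `y` its integer class.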
15 | 16 | #' @examples 17 | #' if (torch::torch_is_installed() && FALSE) { 18 | #' birds <- bird_species_dataset("./data", 19 | #' download = TRUE) 20 | #' length(birds) 21 | #' } 22 | #' @export 23 | bird_species_dataset <- torch::dataset( 24 | inherit = torchvision::image_folder_dataset, 25 | initialize = function(root, split = "train", download = FALSE, ...) { 26 | 27 | url <- "https://torch-cdn.mlverse.org/datasets/bird-species.zip" 28 | data_path <- maybe_download( 29 | root = root, 30 | name = "bird-species", 31 | url = url, 32 | download = download, 33 | extract_fun = function(temp, data_path) { 34 | unzip2(temp, exdir = data_path) 35 | } 36 | ) 37 | 38 | if (!fs::dir_exists(data_path)) 39 | cli::cli_abort("No data found. Please use `download = TRUE`.") 40 | 41 | possible_splits <- c("train", "valid", "test") 42 | if (!split %in% possible_splits) { 43 | cli::cli_abort(c( 44 | "Found split {.val {split}} but expected one of {.or {.val {possible_splits}}}." 45 | )) 46 | } 47 | 48 | p <- fs::path(data_path, split) 49 | super$initialize(root = p, ...) 50 | } 51 | ) 52 | 53 | -------------------------------------------------------------------------------- /R/cityscapes-pix2pix.R: -------------------------------------------------------------------------------- 1 | #' Cityscapes Pix2Pix dataset 2 | #' 3 | #' Downloads and prepares the cityscapes dataset that has been used in the 4 | #' [pix2pix paper](https://arxiv.org/abs/1611.07004). 5 | #' 6 | #' Find more information on the [project website](https://phillipi.github.io/pix2pix/). 7 | #' 8 | #' @inheritParams bird_species_dataset 9 | #' @inheritParams torchvision::image_folder_dataset 10 | #' @param ... Currently unused. 11 | #' 12 | #' @export 13 | cityscapes_pix2pix_dataset <- torch::dataset( 14 | "CityscapesImagePairs", 15 | initialize = function(root, split = "train", download = FALSE, ..., 16 | transform = NULL, target_transform = NULL) { 17 | 18 | url <- "http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/cityscapes.tar.gz" 19 | 20 | data_path <- maybe_download( 21 | root = root, 22 | url = url, 23 | name = "cityscapes-image-pairs", 24 | download = download, 25 | extract_fun = function(f, exdir) { 26 | untar(f, exdir = exdir) 27 | } 28 | ) 29 | 30 | self$split <- split 31 | 32 | path <- fs::path( 33 | data_path, 34 | "cityscapes", 35 | ifelse(self$split == "train", "train", "val") 36 | ) 37 | 38 | self$files <- fs::dir_ls(path, glob = "*.jpg") 39 | self$transform <- if (is.null(transform)) identity else transform 40 | self$target_transform <- if (is.null(target_transform)) identity else target_transform 41 | }, 42 | .getitem = function(i) { 43 | img <- jpeg::readJPEG(self$files[i]) 44 | 45 | list( 46 | input_img = self$transform(img[,257:512,]), 47 | real_img = self$target_transform(img[,1:256,]) 48 | ) 49 | }, 50 | .length = function() { 51 | length(self$files) 52 | } 53 | ) 54 | -------------------------------------------------------------------------------- /R/dogs-vs-cats.R: -------------------------------------------------------------------------------- 1 | #' Dogs vs. Cats dataset 2 | #' 3 | #' Prepares the Dogs vs. Cats dataset available on Kaggle 4 | #' [here](https://www.kaggle.com/c/dogs-vs-cats). 5 | #' 6 | #' @inheritParams guess_the_correlation_dataset 7 | #' @param ... Currently unused. 8 | #' 9 | #' @return A [torch::dataset()] ready to be used with dataloaders.
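#' Items are `(x, y)` lists where `x` is the image array and `y` the integer label: 1 for dog, 2 for cat.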
10 | #' @examples 11 | #' if (torch::torch_is_installed() && FALSE) { 12 | #' dogs_cats <- dogs_vs_cats_dataset("./data", 13 | #' download = TRUE) 14 | #' length(dogs_cats) 15 | #' } 16 | #' 17 | #' @importFrom torchvision base_loader 18 | #' @importFrom stringr str_extract 19 | #' @export 20 | dogs_vs_cats_dataset <- torch::dataset( 21 | classes = c("dog", "cat"), 22 | initialize = function(root, split = "train", download = FALSE, ..., transform = NULL, 23 | target_transform = NULL) { 24 | 25 | self$transform <- transform 26 | self$target_transform <- target_transform 27 | 28 | url <- "https://torch-cdn.mlverse.org/datasets/dogs-vs-cats.zip" 29 | 30 | data_path <- maybe_download( 31 | root = root, 32 | name = "dogs-vs-cats", 33 | url = url, 34 | download = download, 35 | extract_fun = function(temp, data_path) { 36 | unzip2(temp, exdir = data_path) 37 | unzip2(fs::path(data_path, "train.zip"), exdir = data_path) 38 | unzip2(fs::path(data_path, "test1.zip"), exdir = data_path) 39 | fs::file_delete(fs::path(data_path, "train.zip")) 40 | fs::file_delete(fs::path(data_path, "test1.zip")) 41 | } 42 | ) 43 | 44 | if (!fs::dir_exists(data_path)) 45 | cli::cli_abort("No data found. Please use `download = TRUE`.") 46 | 47 | if(split == "train") { 48 | self$images <- fs::dir_ls(fs::path(data_path, "train")) 49 | } else if(split == "test") { 50 | self$images <- fs::dir_ls(fs::path(data_path, "test1")) 51 | } else { 52 | cli::cli_abort(c( 53 | "Only the 'train' and 'test' splits are supported.", 54 | i = "Got {.str {split}}" 55 | )) 56 | } 57 | self$targets <- stringr::str_extract( 58 | fs::path_file(self$images), 59 | "[^.]+(?=\\.)" 60 | ) 61 | self$targets <- match(self$targets, self$classes) 62 | }, 63 | .getitem = function(i) { 64 | x <- base_loader(self$images[i]) 65 | y <- self$targets[i] 66 | 67 | if (!is.null(self$transform)) 68 | x <- self$transform(x) 69 | 70 | if (!is.null(self$target_transform)) 71 | y <- self$target_transform(y) 72 | 73 | list(x, y) 74 | }, 75 | .length = function() { 76 | length(self$images) 77 | } 78 | ) -------------------------------------------------------------------------------- /R/guess-the-correlation.R: -------------------------------------------------------------------------------- 1 | #' Guess The Correlation dataset 2 | #' 3 | #' Prepares the Guess The Correlation dataset available on Kaggle [here](https://www.kaggle.com/c/guess-the-correlation). 4 | #' A copy of this dataset is hosted in a public Google Cloud 5 | #' bucket so you don't need to authenticate. 6 | #' 7 | #' @param root path to the data location 8 | #' @param split string. 'train' or 'submission' 9 | #' @param transform function that takes a torch tensor representing an image and returns another tensor, transformed. 10 | #' @param target_transform function that takes a scalar torch tensor and returns another tensor, transformed. 11 | #' @param indexes set of integers for subsampling (e.g. 1:140000) 12 | #' @param download whether to download or not 13 | #' 14 | #' @return A torch dataset that can be consumed with [torch::dataloader()].
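#' Each item is a named list with `x` (the grayscale plot image tensor), `y` (the scalar correlation) and `id` (the sample identifier).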
15 | #' @examples 16 | #' if (torch::torch_is_installed() && FALSE) { 17 | #' gtc <- guess_the_correlation_dataset("./data", download = TRUE) 18 | #' length(gtc) 19 | #' } 20 | #' @export 21 | guess_the_correlation_dataset <- torch::dataset( 22 | "GuessTheCorrelation", 23 | initialize = function(root, split = "train", transform = NULL, target_transform = NULL, indexes = NULL, download = FALSE) { 24 | 25 | self$transform <- transform 26 | self$target_transform <- target_transform 27 | 28 | # download ---------------------------------------------------------- 29 | data_path <- maybe_download( 30 | root = root, 31 | name = "guess-the-correlation", 32 | url = "https://torch-cdn.mlverse.org/datasets/guess-the-correlation.zip", 33 | download = download, 34 | extract_fun = function(temp, data_path) { 35 | unzip2(temp, exdir = data_path) 36 | unzip2(fs::path(data_path, "train_imgs.zip"), exdir = data_path) 37 | unzip2(fs::path(data_path, "test_imgs.zip"), exdir = data_path) 38 | } 39 | ) 40 | 41 | # response variable ------------------------------------------------- 42 | 43 | if(split == "train") { 44 | self$images <- readr::read_csv(fs::path(data_path, "train.csv"), col_types = c("cn")) 45 | if(!is.null(indexes)) self$images <- self$images[indexes, ] 46 | self$.path <- file.path(data_path, "train_imgs") 47 | } else if(split == "submission") { 48 | self$images <- readr::read_csv(fs::path(data_path, "example_submition.csv"), col_types = c("cn")) 49 | self$images$corr <- NA_real_ 50 | self$.path <- file.path(data_path, "test_imgs") 51 | } 52 | }, 53 | 54 | .getitem = function(index) { 55 | 56 | force(index) 57 | 58 | sample <- self$images[index, ] 59 | id <- sample$id 60 | x <- torchvision::base_loader(file.path(self$.path, paste0(sample$id, ".png"))) 61 | x <- torchvision::transform_to_tensor(x) %>% torchvision::transform_rgb_to_grayscale() 62 | 63 | if (!is.null(self$transform)) 64 | x <- self$transform(x) 65 | 66 | y <- torch::torch_scalar_tensor(sample$corr) 67 | if (!is.null(self$target_transform)) 68 | y <- self$target_transform(y) 69 | 70 | return(list(x = x, y = y, id = id)) 71 | }, 72 | 73 | .length = function() { 74 | nrow(self$images) 75 | } 76 | ) -------------------------------------------------------------------------------- /R/imdb.R: -------------------------------------------------------------------------------- 1 | #' IMDB movie review sentiment classification dataset 2 | #' 3 | #' The format of this dataset is meant to replicate that provided by 4 | #' [Keras](https://keras.io/api/datasets/imdb/). 5 | #' 6 | #' @inheritParams bird_species_dataset 7 | #' @param shuffle whether or not to shuffle the dataset. Defaults to `TRUE` if `split == "train"`. 8 | #' @param num_words Words are ranked by how often they occur (in the training set), 9 | #' and only the `num_words` most frequent words are kept. Any less frequent word 10 | #' will appear as the `oov_char` value in the sequence data. Defaults to `Inf`, 11 | #' so all words are kept. 12 | #' @param skip_top skip the top N most frequently occurring words (which may not be informative). 13 | #' These words will appear as the `oov_char` value in the dataset. Defaults to 0, so 14 | #' no words are skipped. 15 | #' @param maxlen int or `Inf`. Maximum sequence length. Any longer sequence will 16 | #' be truncated. Defaults to `Inf`, which means no truncation. 17 | #' @param start_char The start of a sequence will be marked with this character. 18 | #' Defaults to 2, because 1 is usually the padding character.
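#' When `maxlen` is finite, shorter sequences are left-padded with the value 1.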
19 | #' @param oov_char int. The out-of-vocabulary character. Words that were cut out 20 | #' because of the num_words or skip_top limits will be replaced with this character. 21 | #' @param index_from int. Index actual words with this index and higher. 22 | # 23 | #' 24 | #' @export 25 | imdb_dataset <- torch::dataset( 26 | initialize = function(root, download = FALSE, split = "train", shuffle = (split == "train"), 27 | num_words = Inf, skip_top = 0, maxlen = Inf, 28 | start_char = 2, oov_char = 3, index_from = 4) { 29 | 30 | rlang::check_installed("tokenizers") 31 | 32 | url = "https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz" 33 | data_path <- maybe_download( 34 | url = url, 35 | root = root, 36 | download = download, 37 | name = "imdb", 38 | extract_fun = function(tmp, expath) { 39 | untar(tmp, exdir = expath) 40 | } 41 | ) 42 | self$data_path <- data_path 43 | 44 | if (!split %in% c("train", "test")) 45 | rlang::abort(paste0("Unknown split `", split, "`")) 46 | 47 | texts <- self$read_and_tokenize(split) 48 | response <- texts$response 49 | texts <- texts$texts 50 | 51 | vocabulary <- self$get_vocabulary() 52 | 53 | if (skip_top > 0) 54 | vocabulary <- vocabulary[-seq_len(skip_top)] 55 | 56 | if (num_words < length(vocabulary)) 57 | vocabulary <- vocabulary[seq_len(num_words)] 58 | 59 | if (shuffle) { 60 | new_order <- sample.int(length(texts)) 61 | texts <- texts[new_order] 62 | response <- response[new_order] 63 | } 64 | 65 | self$texts <- texts 66 | self$response <- response 67 | self$vocabulary <- vocabulary 68 | self$start_char <- start_char 69 | self$oov_char <- oov_char 70 | self$maxlen <- maxlen 71 | self$index_from <- index_from 72 | }, 73 | .getitem = function(i) { 74 | words <- self$texts[[i]] 75 | 76 | # word indexes start at 1, but we want them to start from `index_from` 77 | int <- match(words, names(self$vocabulary)) + as.integer(self$index_from - 1) 78 | int[is.na(int)] <- as.integer(self$oov_char) 79 | int <- c(as.integer(self$start_char), int) 80 | 81 | if (is.finite(self$maxlen)) { 82 | if (length(int) >= self$maxlen) { 83 | int <- int[seq_len(self$maxlen)] 84 | } else { 85 | int <- c(rep(1L, self$maxlen - length(int)), int) 86 | } 87 | } 88 | 89 | list( 90 | x = int, 91 | y = self$response[i] 92 | ) 93 | }, 94 | .length = function() { 95 | length(self$texts) 96 | }, 97 | get_vocabulary = function() { 98 | 99 | data_path <- self$data_path 100 | cached <- fs::path(data_path, "aclImdb", "cached-vocab.rds") 101 | if (!fs::file_exists(cached)) { 102 | texts <- self$read_and_tokenize("train")$texts 103 | vocabulary <- texts %>% 104 | unlist() %>% 105 | table() %>% 106 | sort(decreasing = TRUE) 107 | saveRDS(vocabulary, file = cached) 108 | } else { 109 | vocabulary <- readRDS(cached) 110 | } 111 | 112 | vocabulary 113 | }, 114 | read_and_tokenize = function(split) { 115 | 116 | data_path <- self$data_path 117 | cached <- fs::path(data_path, "aclImdb", split, "cached.rds") 118 | 119 | if (!fs::file_exists(cached)) { 120 | 121 | pos <- fs::dir_ls(fs::path(data_path, "aclImdb", split, "pos")) 122 | neg <- fs::dir_ls(fs::path(data_path, "aclImdb", split, "neg")) 123 | 124 | texts <- sapply(c(pos, neg), function(x) readr::read_file(x)) %>% 125 | tokenizers::tokenize_words() 126 | 127 | response <- c( 128 | rep(1, length.out = length(pos)), 129 | rep(0, length.out = length(neg)) 130 | ) 131 | 132 | rlang::inform(paste0("Caching tokenized texts for split: ", split)) 133 | saveRDS( 134 | list(texts = texts, response = response), 135 | file = cached 136 | ) 137 | } 
else { 138 | texts <- readRDS(cached) 139 | 140 | response <- texts$response 141 | texts <- texts$texts 142 | } 143 | 144 | list( 145 | texts = texts, 146 | response = response 147 | ) 148 | } 149 | ) 150 | -------------------------------------------------------------------------------- /R/oxford-flowers-dataset.R: -------------------------------------------------------------------------------- 1 | flower_categories <- c( 2 | "pink primrose", 3 | "hard-leaved pocket orchid", 4 | "canterbury bells", 5 | "sweet pea", 6 | "english marigold", 7 | "tiger lily", 8 | "moon orchid", 9 | "bird of paradise", 10 | "monkshood", 11 | "globe thistle", 12 | "snapdragon", 13 | "colt's foot", 14 | "king protea", 15 | "spear thistle", 16 | "yellow iris", 17 | "globe-flower", 18 | "purple coneflower", 19 | "peruvian lily", 20 | "balloon flower", 21 | "giant white arum lily", 22 | "fire lily", 23 | "pincushion flower", 24 | "fritillary", 25 | "red ginger", 26 | "grape hyacinth", 27 | "corn poppy", 28 | "prince of wales feathers", 29 | "stemless gentian", 30 | "artichoke", 31 | "sweet william", 32 | "carnation", 33 | "garden phlox", 34 | "love in the mist", 35 | "mexican aster", 36 | "alpine sea holly", 37 | "ruby-lipped cattleya", 38 | "cape flower", 39 | "great masterwort", 40 | "siam tulip", 41 | "lenten rose", 42 | "barbeton daisy", 43 | "daffodil", 44 | "sword lily", 45 | "poinsettia", 46 | "bolero deep blue", 47 | "wallflower", 48 | "marigold", 49 | "buttercup", 50 | "oxeye daisy", 51 | "common dandelion", 52 | "petunia", 53 | "wild pansy", 54 | "primula", 55 | "sunflower", 56 | "pelargonium", 57 | "bishop of llandaff", 58 | "gaura", 59 | "geranium", 60 | "orange dahlia", 61 | "pink-yellow dahlia?", 62 | "cautleya spicata", 63 | "japanese anemone", 64 | "black-eyed susan", 65 | "silverbush", 66 | "californian poppy", 67 | "osteospermum", 68 | "spring crocus", 69 | "bearded iris", 70 | "windflower", 71 | "tree poppy", 72 | "gazania", 73 | "azalea", 74 | "water lily", 75 | "rose", 76 | "thorn apple", 77 | "morning glory", 78 | "passion flower", 79 | "lotus", 80 | "toad lily", 81 | "anthurium", 82 | "frangipani", 83 | "clematis", 84 | "hibiscus", 85 | "columbine", 86 | "desert-rose", 87 | "tree mallow", 88 | "magnolia", 89 | "cyclamen", 90 | "watercress", 91 | "canna lily", 92 | "hippeastrum", 93 | "bee balm", 94 | "ball moss", 95 | "foxglove", 96 | "bougainvillea", 97 | "camellia", 98 | "mallow", 99 | "mexican petunia", 100 | "bromelia", 101 | "blanket flower", 102 | "trumpet creeper", 103 | "blackberry lily" 104 | ) 105 | 106 | #' 102 Category Flower Dataset 107 | #' 108 | #' The Oxford Flowers dataset consists of 102 flower categories. The flowers were 109 | #' chosen to be flowers commonly occurring in the United 110 | #' Kingdom. Each class consists of between 40 and 258 images. The details of the 111 | #' categories and the number of images for each class can be found on 112 | #' [this category statistics page](https://www.robots.ox.ac.uk/%7Evgg/data/flowers/102/categories.html). 113 | #' 114 | #' The images have large scale, pose and light variations. In addition, there are 115 | #' categories that have large variations within the category and several very 116 | #' similar categories. The dataset is visualized using isomap with shape and colour 117 | #' features. 118 | #' 119 | #' You can find more info on the dataset [webpage](https://www.robots.ox.ac.uk/%7Evgg/data/flowers/102/). 120 | #' 121 | #' @note The official split leaves far too many images in the test set.
Depending 122 | #' on your work you might want to create different train/valid/test splits. 123 | #' 124 | #' @inheritParams oxford_pet_dataset 125 | #' @param target_type Currently only 'categories' is supported. 126 | #' @importFrom torch dataset 127 | #' @export 128 | oxford_flowers102_dataset <- torch::dataset( 129 | "OxfordFlowers102", 130 | classes = flower_categories, 131 | initialize = function(root, split = "train", target_type = c("categories"), 132 | download = FALSE, ..., transform = NULL, target_transform = NULL) { 133 | rlang::check_installed(c("R.matlab")) 134 | 135 | data_path <- fs::path_expand(fs::path(root, "oxford-flowers102")) 136 | self$data_path <- data_path 137 | 138 | if (!fs::dir_exists(data_path) && download) { 139 | 140 | images <- download_file( 141 | "https://torch-cdn.mlverse.org/datasets/oxford_flowers102/102flowers.tgz", 142 | tempfile(fileext = ".tgz") 143 | ) 144 | 145 | targets <- download_file( 146 | "https://torch-cdn.mlverse.org/datasets/oxford_flowers102/imagelabels.mat", 147 | tempfile(fileext = ".mat"), 148 | mode = "wb" 149 | ) 150 | 151 | splits <- download_file( 152 | "https://torch-cdn.mlverse.org/datasets/oxford_flowers102/setid.mat", 153 | tempfile(fileext = ".mat"), 154 | mode = "wb" 155 | ) 156 | 157 | fs::dir_create(data_path) 158 | untar(images, exdir = data_path) 159 | fs::file_move(targets, fs::path(data_path, "imagelabels.mat")) 160 | fs::file_move(splits, fs::path(data_path, "setid.mat")) 161 | } 162 | 163 | if (!fs::dir_exists(data_path)) 164 | cli::cli_abort("No data found. Please use {.var download = TRUE}.") 165 | 166 | self$split <- split 167 | splits <- R.matlab::readMat(fs::path(self$data_path, "setid.mat")) 168 | splits <- lapply(splits, as.integer) 169 | names(splits) <- c("train", "valid", "test") 170 | 171 | self$target_type <- target_type 172 | targets <- R.matlab::readMat(fs::path(self$data_path, "imagelabels.mat")) 173 | targets <- as.integer(targets$labels) 174 | 175 | ids <- unlist(splits[names(splits) %in% self$split]) 176 | self$targets <- targets[ids] 177 | 178 | self$imgs <- fs::path( 179 | self$data_path, 180 | "jpg", 181 | sprintf("image_%05d.jpg", ids) 182 | ) 183 | 184 | self$transform <- if (is.null(transform)) identity else transform 185 | self$target_transform <- if (is.null(target_transform)) identity else target_transform 186 | }, 187 | .getitem = function(i) { 188 | list( 189 | x = self$transform(jpeg::readJPEG(self$imgs[i])), 190 | y = self$target_transform(self$targets[i]) 191 | ) 192 | }, 193 | .length = function() { 194 | length(self$imgs) 195 | } 196 | ) 197 | -------------------------------------------------------------------------------- /R/oxford-pet-dataset.R: -------------------------------------------------------------------------------- 1 | #' Oxford Pet Dataset 2 | #' 3 | #' The Oxford-IIIT Pet Dataset is a 37 category pet dataset with roughly 4 | #' 200 images for each class. The images have large variations in scale, 5 | #' pose and lighting. All images have an associated ground truth annotation of 6 | #' species (cat or dog), breed, and pixel-level trimap segmentation. 7 | #' 8 | #' @inheritParams cityscapes_pix2pix_dataset 9 | #' @param target_type The type of the target: 10 | #' - 'trimap': returns a mask array with one class per pixel. 11 | #' - 'species': returns the species id. 1 for cat and 2 for dog. 12 | #' - 'breed': returns the breed id. See `dataset$breed_classes`.
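#' Trimap mask values are integer codes whose meanings are listed in `dataset$trimap_classes`.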
13 | #' 14 | #' @export 15 | oxford_pet_dataset <- torch::dataset( 16 | "OxfordPet", 17 | 18 | trimap_classes = c("Pixel belonging to the pet", "Pixel bordering the pet", "Surrounding pixel"), 19 | species_classes = c("cat", "dog"), 20 | breed_classes = c("Abyssinian", "american_bulldog", "american_pit_bull_terrier", 21 | "basset_hound", "beagle", "Bengal", "Birman", "Bombay", "boxer", 22 | "British_Shorthair", "chihuahua", "Egyptian_Mau", "english_cocker_spaniel", 23 | "english_setter", "german_shorthaired", "great_pyrenees", "havanese", 24 | "japanese_chin", "keeshond", "leonberger", "Maine_Coon", "miniature_pinscher", 25 | "newfoundland", "Persian", "pomeranian", "pug", "Ragdoll", "Russian_Blue", 26 | "saint_bernard", "samoyed", "scottish_terrier", "shiba_inu", 27 | "Siamese", "Sphynx", "staffordshire_bull_terrier", "wheaten_terrier", 28 | "yorkshire_terrier"), 29 | 30 | initialize = function(root, split = "train", target_type = c("trimap", "species", "breed"), 31 | download = FALSE, ..., transform = NULL, target_transform = NULL) { 32 | 33 | rlang::check_installed("readr") 34 | 35 | data_path <- fs::path_expand(fs::path(root, "oxford-pet")) 36 | self$data_path <- data_path 37 | 38 | if (!fs::dir_exists(data_path) && download) { 39 | 40 | images <- download_file( 41 | "https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz", 42 | tempfile(fileext = ".tar.gz") 43 | ) 44 | 45 | targets <- download_file( 46 | "https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz", 47 | tempfile(fileext = ".tar.gz") 48 | ) 49 | 50 | fs::dir_create(data_path) 51 | untar(images, exdir = data_path) 52 | untar(targets, exdir = data_path) 53 | } 54 | 55 | if (!fs::dir_exists(data_path)) 56 | stop("No data found. Please use `download = TRUE`.") 57 | 58 | self$split <- split 59 | self$target_type <- rlang::arg_match(target_type) 60 | self$target_reader <- get( 61 | paste0("read_", self$target_type), 62 | envir = self 63 | ) 64 | self$classes <- get( 65 | paste0(self$target_type, "_classes"), 66 | envir = self 67 | ) 68 | 69 | if (self$split == "train") { 70 | img_list <- fs::path(data_path, "annotations", "trainval.txt") 71 | } else { 72 | img_list <- fs::path(data_path, "annotations", "test.txt") 73 | } 74 | 75 | self$imgs <- readr::read_delim( 76 | img_list, 77 | delim = " ", 78 | col_names = c("image", "class_id", "species_id", "breed_id"), 79 | col_types = readr::cols() 80 | ) 81 | 82 | self$transform <- if (is.null(transform)) identity else transform 83 | self$target_transform <- if (is.null(target_transform)) identity else target_transform 84 | 85 | # flag files that are actually PNGs despite their .jpg file names 86 | self$imgs$ext <- "jpg" 87 | pngs <- c("Egyptian_Mau_14", "Egyptian_Mau_156", "Egyptian_Mau_186", "Abyssinian_5") 88 | self$imgs$ext[self$imgs$image %in% pngs] <- "png" 89 | 90 | # remove a corrupt file 91 | self$imgs <- self$imgs[self$imgs$image != "beagle_116",] 92 | 93 | }, 94 | 95 | .getitem = function(i) { 96 | img <- self$imgs[i,] 97 | list( 98 | x = self$transform(self$read_img(img)), 99 | y = self$target_transform(self$target_reader(img)) 100 | ) 101 | }, 102 | 103 | .length = function() { 104 | nrow(self$imgs) 105 | }, 106 | 107 | read_img = function(img) { 108 | path <- fs::path(self$data_path, "images", paste0(img$image, ".jpg")) 109 | if (img$ext == "jpg") 110 | jpeg::readJPEG(path) 111 | else 112 | png::readPNG(path)[,,1:3] # we remove the alpha channel 113 | }, 114 | 115 | read_trimap = function(img) { 116 | mask <- png::readPNG(fs::path(self$data_path, "annotations", "trimaps",
paste0(img$image, ".png"))) 117 | dimensions <- dim(mask) 118 | mask <- as.integer(mask*255) 119 | dim(mask) <- dimensions 120 | mask 121 | }, 122 | 123 | read_species = function(img) { 124 | as.integer(img$species_id) 125 | }, 126 | 127 | read_breed = function(img) { 128 | as.integer(img$breed_id) 129 | } 130 | 131 | ) -------------------------------------------------------------------------------- /R/spam-dataset.R: -------------------------------------------------------------------------------- 1 | #' Spam Dataset Loader 2 | #' 3 | #' Defines the spam e-mail dataset from the 'Elements of Statistical Learning' website, commonly used in machine learning. 4 | #' 5 | #' @param url A character string representing the URL of the dataset. 6 | #' @param download Logical; whether to download the dataset. Defaults to FALSE. 7 | #' @param transform Function to apply transformations to the features. Defaults to NULL. 8 | #' @param target_transform Function to apply transformations to the labels. Defaults to NULL. 9 | #' @return A `torch::dataset` object for the spam dataset. 10 | #' @examples 11 | #' \dontrun{ 12 | #' # Simple usage: 13 | #' ds <- spam_dataset(download = TRUE) 14 | #' loader <- torch::dataloader(ds, batch_size = 32, shuffle = TRUE) 15 | #' batch <- torch::dataloader_next(torch::dataloader_make_iter(loader)) 16 | #' dim(batch$x) 17 | #' length(batch$y) 18 | #' } 19 | #' @export 20 | spam_dataset <- torch::dataset( 21 | name = "spam_dataset", 22 | 23 | initialize = function( 24 | url = "https://hastie.su.domains/ElemStatLearn/datasets/spam.data", 25 | download = FALSE, 26 | transform = NULL, 27 | target_transform = NULL 28 | ) { 29 | data_path <- tempfile(fileext = ".data") 30 | 31 | if (download) { 32 | download.file(url, data_path, mode = "wb") 33 | } else { 34 | data_path <- url 35 | } 36 | 37 | raw_spam_data <- read.table(data_path, header = FALSE) 38 | 39 | self$x_data <- as.matrix(raw_spam_data[, -ncol(raw_spam_data)]) 40 | self$y_data <- as.numeric(raw_spam_data[, ncol(raw_spam_data)]) 41 | 42 | self$transform <- transform 43 | self$target_transform <- target_transform 44 | }, 45 | 46 | .getitem = function(index) { 47 | x <- self$x_data[index, ] 48 | y <- self$y_data[index] 49 | 50 | if (!is.null(self$transform)) { 51 | x <- self$transform(x) 52 | } 53 | 54 | if (!is.null(self$target_transform)) { 55 | y <- self$target_transform(y) 56 | } 57 | 58 | list( 59 | x = torch::torch_tensor(x, dtype = torch::torch_float()), 60 | y = torch::torch_tensor(y, dtype = torch::torch_long()) 61 | ) 62 | }, 63 | 64 | .length = function() { 65 | nrow(self$x_data) 66 | } 67 | ) -------------------------------------------------------------------------------- /R/utils.R: -------------------------------------------------------------------------------- 1 | kaggle_download <- function(name, token = NULL) { 2 | 3 | if ("kaggle" %in% pins::board_list()) { 4 | file <- pins::pin_get(board = "kaggle", name, 5 | extract = FALSE) 6 | } else if (!is.null(token)) { 7 | pins::board_register_kaggle(name="torchdatasets-kaggle", token = token, 8 | cache = tempfile(pattern = "dir")) 9 | on.exit({pins::board_deregister("torchdatasets-kaggle")}, add = TRUE) 10 | file <- pins::pin_get(name, 11 | board = "torchdatasets-kaggle", 12 | extract = FALSE) 13 | } else { 14 | stop("Please register the Kaggle board or pass the `token` parameter.") 15 | } 16 | 17 | file 18 | } 19 | 20 | download_file <- function(url, destfile, ...) { 21 | withr::with_options(new = list(timeout = max(600, getOption("timeout"))), { 22 | utils::download.file(url, destfile, ...)
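# the raised timeout option (at least 600s, for large archives) is restored automatically when this block exits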
23 | }) 24 | destfile 25 | } 26 | 27 | maybe_download <- function(url, root, name, extract_fun, download) { 28 | data_path <- fs::path_expand(fs::path(root, name)) 29 | 30 | if (!fs::dir_exists(data_path) && download) { 31 | tmp <- tempfile() 32 | download_file(url, tmp) 33 | fs::dir_create(fs::path_dir(data_path), recurse = TRUE) 34 | extract_fun(tmp, data_path) 35 | } 36 | 37 | if (!fs::dir_exists(data_path)) 38 | stop("No data found. Please use `download = TRUE`.") 39 | 40 | data_path 41 | } 42 | 43 | unzip2 <- function(path, exdir) { 44 | if (grepl("linux", R.version$os)) { 45 | utils::unzip(path, exdir = exdir) 46 | } else { 47 | zip::unzip(path, exdir = exdir) 48 | } 49 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # torchdatasets 2 | 3 | 4 | 5 | [![R-CMD-check](https://github.com/mlverse/torchdatasets/workflows/R-CMD-check/badge.svg)](https://github.com/mlverse/torchdatasets/actions) [![CRAN status](https://www.r-pkg.org/badges/version/torchdatasets)](https://CRAN.R-project.org/package=torchdatasets) [![](https://cranlogs.r-pkg.org/badges/torchdatasets)](https://cran.r-project.org/package=torchdatasets) 6 | 7 | 8 | 9 | torchdatasets provides ready-to-use datasets compatible with the [torch](https://github.com/mlverse/torch) package. 10 | 11 | ## Installation 12 | 13 | The released version of torchdatasets can be installed with: 14 | 15 | ``` r 16 | install.packages("torchdatasets") 17 | ``` 18 | 19 | You can also install the development version with: 20 | 21 | ``` r 22 | remotes::install_github("mlverse/torchdatasets") 23 | ``` 24 | 25 | ## Datasets 26 | 27 | Currently, the following datasets are implemented: 28 | 29 | | Dataset | Domain | Type | Authentication | 30 | |---------------------------------|---------|----------------|----------------| 31 | | bird_species_dataset() | Images | Classification | Not required | 32 | | dogs_vs_cats_dataset() | Images | Classification | Not required | 33 | | guess_the_correlation_dataset() | Images | Regression | Not required | 34 | | cityscapes_pix2pix_dataset() | Images | Segmentation | Not required | 35 | | oxford_pet_dataset() | Images | Segmentation | Not required | 36 | | oxford_flowers102_dataset() | Images | Classification | Not required | 37 | | bank_marketing_dataset() | Tabular | Classification | Not required | 38 | | imdb_dataset() | Text | Classification | Not required | 39 | | spam_dataset() | Tabular | Classification | Not required | -------------------------------------------------------------------------------- /_pkgdown.yml: -------------------------------------------------------------------------------- 1 | url: https://mlverse.github.io/torchdatasets -------------------------------------------------------------------------------- /cran-comments.md: -------------------------------------------------------------------------------- 1 | * This is a new release.
2 | -------------------------------------------------------------------------------- /man/bank_marketing_dataset.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/bank-marketing.R 3 | \name{bank_marketing_dataset} 4 | \alias{bank_marketing_dataset} 5 | \title{Bank marketing dataset} 6 | \usage{ 7 | bank_marketing_dataset( 8 | root, 9 | split = "train", 10 | indexes = NULL, 11 | download = FALSE, 12 | with_call_duration = FALSE 13 | ) 14 | } 15 | \arguments{ 16 | \item{root}{path to the data location} 17 | 18 | \item{split}{string. Currently only 'train' is supported.} 19 | 20 | \item{indexes}{set of integers for subsampling (e.g. 1:41188)} 21 | 22 | \item{download}{whether to download or not} 23 | 24 | \item{with_call_duration}{whether the call duration should be included as a feature. Including it could lead to target leakage. Default: FALSE.} 25 | } 26 | \value{ 27 | A torch dataset that can be consumed with \code{\link[torch:dataloader]{torch::dataloader()}}. Each item is a named list with \code{x} (a 1d float tensor of features) and \code{y} (a scalar tensor with the binary response). 28 | } 29 | \description{ 30 | Prepares the Bank marketing dataset available on the UCI Machine Learning repository \href{https://archive.ics.uci.edu/ml/datasets/Bank+Marketing}{here}. 31 | The data is publicly available for download; there is no need to authenticate. 32 | Please cite the data as Moro et al., 2014: 33 | S. Moro, P. Cortez and P. Rita. A Data-Driven Approach to Predict the Success of Bank Telemarketing. Decision Support Systems, Elsevier, 62:22-31, June 2014. 34 | } 35 | \examples{ 36 | if (torch::torch_is_installed() && FALSE) { 37 | bank_mkt <- bank_marketing_dataset("./data", download = TRUE) 38 | length(bank_mkt) 39 | } 40 | } -------------------------------------------------------------------------------- /man/bird_species_dataset.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/bird-species.R 3 | \name{bird_species_dataset} 4 | \alias{bird_species_dataset} 5 | \title{Bird species dataset} 6 | \usage{ 7 | bird_species_dataset(root, split = "train", download = FALSE, ...) 8 | } 9 | \arguments{ 10 | \item{root}{path to the data location} 11 | 12 | \item{split}{string. 'train', 'test' or 'valid'} 13 | 14 | \item{download}{whether to download or not} 15 | 16 | \item{...}{other arguments passed to \code{\link[torchvision:image_folder_dataset]{torchvision::image_folder_dataset()}}.} 17 | } 18 | \value{ 19 | A \code{\link[torch:dataset]{torch::dataset()}} ready to be used with dataloaders. Items are \code{(x, y)} pairs as produced by \code{\link[torchvision:image_folder_dataset]{torchvision::image_folder_dataset()}}. 20 | } 21 | \description{ 22 | Downloads and prepares the 450 bird species dataset found on Kaggle. 23 | The dataset description, license, etc. can be found \href{https://www.kaggle.com/datasets/gpiosenka/100-bird-species}{here}.
24 | } 25 | \examples{ 26 | if (torch::torch_is_installed() && FALSE) { 27 | birds <- bird_species_dataset("./data", 28 | download = TRUE) 29 | length(birds) 30 | } 31 | } -------------------------------------------------------------------------------- /man/cityscapes_pix2pix_dataset.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/cityscapes-pix2pix.R 3 | \name{cityscapes_pix2pix_dataset} 4 | \alias{cityscapes_pix2pix_dataset} 5 | \title{Cityscapes Pix2Pix dataset} 6 | \usage{ 7 | cityscapes_pix2pix_dataset( 8 | root, 9 | split = "train", 10 | download = FALSE, 11 | ..., 12 | transform = NULL, 13 | target_transform = NULL 14 | ) 15 | } 16 | \arguments{ 17 | \item{root}{path to the data location} 18 | 19 | \item{split}{string. 'train' or 'valid'} 20 | 21 | \item{download}{whether to download or not} 22 | 23 | \item{...}{Currently unused.} 24 | 25 | \item{transform}{A function/transform that takes in an image and returns 26 | a transformed version. E.g., \code{\link[torchvision:transform_random_crop]{transform_random_crop()}}.} 27 | 28 | \item{target_transform}{A function/transform that takes in the target and 29 | transforms it.} 30 | } 31 | \description{ 32 | Downloads and prepares the cityscapes dataset that has been used in the 33 | \href{https://arxiv.org/abs/1611.07004}{pix2pix paper}. 34 | } 35 | \details{ 36 | Find more information on the \href{https://phillipi.github.io/pix2pix/}{project website}. 37 | } -------------------------------------------------------------------------------- /man/dogs_vs_cats_dataset.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/dogs-vs-cats.R 3 | \name{dogs_vs_cats_dataset} 4 | \alias{dogs_vs_cats_dataset} 5 | \title{Dogs vs. Cats dataset} 6 | \usage{ 7 | dogs_vs_cats_dataset( 8 | root, 9 | split = "train", 10 | download = FALSE, 11 | ..., 12 | transform = NULL, 13 | target_transform = NULL 14 | ) 15 | } 16 | \arguments{ 17 | \item{root}{path to the data location} 18 | 19 | \item{split}{string. 'train' or 'test'} 20 | 21 | \item{download}{whether to download or not} 22 | 23 | \item{...}{Currently unused.} 24 | 25 | \item{transform}{function that takes a torch tensor representing an image and returns another tensor, transformed.} 26 | 27 | \item{target_transform}{function that takes a scalar torch tensor and returns another tensor, transformed.} 28 | } 29 | \value{ 30 | A \code{\link[torch:dataset]{torch::dataset()}} ready to be used with dataloaders. Items are \code{(x, y)} lists where \code{x} is the image array and \code{y} the integer label: 1 for dog, 2 for cat.
31 | } 32 | \description{ 33 | Prepares the Dogs vs. Cats dataset available on Kaggle 34 | \href{https://www.kaggle.com/c/dogs-vs-cats}{here}. 35 | } 36 | \examples{ 37 | if (torch::torch_is_installed() && FALSE) { 38 | dogs_cats <- dogs_vs_cats_dataset("./data", 39 | download = TRUE) 40 | length(dogs_cats) 41 | } 42 | 43 | } -------------------------------------------------------------------------------- /man/guess_the_correlation_dataset.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/guess-the-correlation.R 3 | \name{guess_the_correlation_dataset} 4 | \alias{guess_the_correlation_dataset} 5 | \title{Guess The Correlation dataset} 6 | \usage{ 7 | guess_the_correlation_dataset( 8 | root, 9 | split = "train", 10 | transform = NULL, 11 | target_transform = NULL, 12 | indexes = NULL, 13 | download = FALSE 14 | ) 15 | } 16 | \arguments{ 17 | \item{root}{path to the data location} 18 | 19 | \item{split}{string. 'train' or 'submission'} 20 | 21 | \item{transform}{function that takes a torch tensor representing an image and returns another tensor, transformed.} 22 | 23 | \item{target_transform}{function that takes a scalar torch tensor and returns another tensor, transformed.} 24 | 25 | \item{indexes}{set of integers for subsampling (e.g. 1:140000)} 26 | 27 | \item{download}{whether to download or not} 28 | } 29 | \value{ 30 | A torch dataset that can be consumed with \code{\link[torch:dataloader]{torch::dataloader()}}. Each item is a named list with \code{x} (the grayscale plot image tensor), \code{y} (the scalar correlation) and \code{id} (the sample identifier). 31 | } 32 | \description{ 33 | Prepares the Guess The Correlation dataset available on Kaggle \href{https://www.kaggle.com/c/guess-the-correlation}{here}. 34 | A copy of this dataset is hosted in a public Google Cloud 35 | bucket so you don't need to authenticate. 36 | } 37 | \examples{ 38 | if (torch::torch_is_installed() && FALSE) { 39 | gtc <- guess_the_correlation_dataset("./data", download = TRUE) 40 | length(gtc) 41 | } 42 | } -------------------------------------------------------------------------------- /man/imdb_dataset.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/imdb.R 3 | \name{imdb_dataset} 4 | \alias{imdb_dataset} 5 | \title{IMDB movie review sentiment classification dataset} 6 | \usage{ 7 | imdb_dataset( 8 | root, 9 | download = FALSE, 10 | split = "train", 11 | shuffle = (split == "train"), 12 | num_words = Inf, 13 | skip_top = 0, 14 | maxlen = Inf, 15 | start_char = 2, 16 | oov_char = 3, 17 | index_from = 4 18 | ) 19 | } 20 | \arguments{ 21 | \item{root}{path to the data location} 22 | 23 | \item{download}{whether to download or not} 24 | 25 | \item{split}{string. 'train' or 'test'} 26 | 27 | \item{shuffle}{whether or not to shuffle the dataset. Defaults to \code{TRUE} if \code{split == "train"}.} 28 | 29 | \item{num_words}{Words are ranked by how often they occur (in the training set), 30 | and only the \code{num_words} most frequent words are kept. Any less frequent word 31 | will appear as the \code{oov_char} value in the sequence data. Defaults to \code{Inf}, 32 | so all words are kept.} 33 | 34 | \item{skip_top}{skip the top N most frequently occurring words (which may not be informative). 35 | These words will appear as the \code{oov_char} value in the dataset. Defaults to 0, so 36 | no words are skipped.} 37 | 38 | \item{maxlen}{int or \code{Inf}. Maximum sequence length.
Any longer sequence will 39 | be truncated. Defaults to \code{Inf}, which means no truncation.} 40 | 41 | \item{start_char}{The start of a sequence will be marked with this character. 42 | Defaults to 2, because 1 is usually the padding character. When \code{maxlen} is finite, shorter sequences are left-padded with the value 1.} 43 | 44 | \item{oov_char}{int. The out-of-vocabulary character. Words that were cut out 45 | because of the num_words or skip_top limits will be replaced with this character.} 46 | 47 | \item{index_from}{int. Index actual words with this index and higher.} 48 | } 49 | \description{ 50 | The format of this dataset is meant to replicate that provided by 51 | \href{https://keras.io/api/datasets/imdb/}{Keras}. 52 | } -------------------------------------------------------------------------------- /man/oxford_flowers102_dataset.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/oxford-flowers-dataset.R 3 | \name{oxford_flowers102_dataset} 4 | \alias{oxford_flowers102_dataset} 5 | \title{102 Category Flower Dataset} 6 | \usage{ 7 | oxford_flowers102_dataset( 8 | root, 9 | split = "train", 10 | target_type = c("categories"), 11 | download = FALSE, 12 | ..., 13 | transform = NULL, 14 | target_transform = NULL 15 | ) 16 | } 17 | \arguments{ 18 | \item{root}{path to the data location} 19 | 20 | \item{split}{string. 'train', 'valid' or 'test'} 21 | 22 | \item{target_type}{Currently only 'categories' is supported.} 23 | 24 | \item{download}{whether to download or not} 25 | 26 | \item{...}{Currently unused.} 27 | 28 | \item{transform}{A function/transform that takes in an image and returns 29 | a transformed version. E.g., \code{\link[torchvision:transform_random_crop]{transform_random_crop()}}.} 30 | 31 | \item{target_transform}{A function/transform that takes in the target and 32 | transforms it.} 33 | } 34 | \description{ 35 | The Oxford Flowers dataset consists of 102 flower categories. The flowers were 36 | chosen to be flowers commonly occurring in the United 37 | Kingdom. Each class consists of between 40 and 258 images. The details of the 38 | categories and the number of images for each class can be found on 39 | \href{https://www.robots.ox.ac.uk/\%7Evgg/data/flowers/102/categories.html}{this category statistics page}. 40 | } 41 | \details{ 42 | The images have large scale, pose and light variations. In addition, there are 43 | categories that have large variations within the category and several very 44 | similar categories. The dataset is visualized using isomap with shape and colour 45 | features. 46 | 47 | You can find more info on the dataset \href{https://www.robots.ox.ac.uk/\%7Evgg/data/flowers/102/}{webpage}. 48 | } 49 | \note{ 50 | The official split leaves far too many images in the test set. Depending
--------------------------------------------------------------------------------
/man/oxford_pet_dataset.Rd:
--------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/oxford-pet-dataset.R
3 | \name{oxford_pet_dataset}
4 | \alias{oxford_pet_dataset}
5 | \title{Oxford Pet Dataset}
6 | \usage{
7 | oxford_pet_dataset(
8 | root,
9 | split = "train",
10 | target_type = c("trimap", "species", "breed"),
11 | download = FALSE,
12 | ...,
13 | transform = NULL,
14 | target_transform = NULL
15 | )
16 | }
17 | \arguments{
18 | \item{root}{path to the data location}
19 | 
20 | \item{split}{train, test or valid}
21 | 
22 | \item{target_type}{The type of the target:
23 | \itemize{
24 | \item 'trimap': returns a mask array with one class per pixel.
25 | \item 'species': returns the species id: 1 for cat and 2 for dog.
26 | \item 'breed': returns the breed id. See \code{dataset$breed_classes}.
27 | }}
28 | 
29 | \item{download}{whether to download or not}
30 | 
31 | \item{...}{Currently unused.}
32 | 
33 | \item{transform}{A function/transform that takes in an image and returns
34 | a transformed version. E.g., \code{\link[torchvision:transform_random_crop]{transform_random_crop()}}.}
35 | 
36 | \item{target_transform}{A function/transform that takes in the target and
37 | transforms it.}
38 | }
39 | \description{
40 | The Oxford-IIIT Pet Dataset is a 37-category pet dataset with roughly
41 | 200 images for each class. The images have large variations in scale,
42 | pose and lighting. All images have an associated ground truth annotation of
43 | species (cat or dog), breed, and pixel-level trimap segmentation.
44 | }
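45 | % Usage sketch added for illustration; the root path and the target_type
46 | % value are assumptions, following the guarded-example convention used
47 | % elsewhere in this package.
48 | \examples{
49 | if (torch::torch_is_installed() && FALSE) {
50 | pets <- oxford_pet_dataset("./data", target_type = "species", download = TRUE)
51 | length(pets)
52 | }
53 | }
54 | 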
--------------------------------------------------------------------------------
/man/spam_dataset.Rd:
--------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/spam-dataset.R
3 | \name{spam_dataset}
4 | \alias{spam_dataset}
5 | \title{Spam Dataset Loader}
6 | \usage{
7 | spam_dataset(
8 | url = "https://hastie.su.domains/ElemStatLearn/datasets/spam.data",
9 | download = FALSE,
10 | transform = NULL,
11 | target_transform = NULL
12 | )
13 | }
14 | \arguments{
15 | \item{url}{A character string representing the URL of the dataset.}
16 | 
17 | \item{download}{Logical; whether to download the dataset. Defaults to FALSE.}
18 | 
19 | \item{transform}{Function to apply transformations to the features. Defaults to NULL.}
20 | 
21 | \item{target_transform}{Function to apply transformations to the labels. Defaults to NULL.}
22 | }
23 | \value{
24 | A \code{torch::dataset} object for the spam dataset.
25 | }
26 | \description{
27 | Defines the spam e-mail dataset from \emph{The Elements of Statistical
28 | Learning}: 4601 messages described by 57 numeric features, each labelled
29 | as spam or non-spam.
30 | }
31 | \examples{
32 | \dontrun{
33 | # Simple usage:
34 | library(torch)
35 | ds <- spam_dataset(download = TRUE)
36 | loader <- dataloader(ds, batch_size = 32, shuffle = TRUE)
37 | batch <- dataloader_next(dataloader_make_iter(loader))
38 | dim(batch$x)
39 | length(batch$y)
40 | }
41 | }
42 | 
--------------------------------------------------------------------------------
/tests/testthat.R:
--------------------------------------------------------------------------------
1 | library(testthat)
2 | library(torchdatasets)
3 | 
4 | if (Sys.getenv("TORCH_TEST", unset = 0) == 1)
5 | test_check("torchdatasets")
6 | 
--------------------------------------------------------------------------------
/tests/testthat/helper-torchdatasets.R:
--------------------------------------------------------------------------------
1 | is_torch_tensor <- function(x) {
2 | inherits(x, "torch_tensor")
3 | }
4 | 
5 | expect_no_error <- function(object, ...) {
6 | expect_error(object, NA, ...)
7 | }
8 | 
9 | expect_tensor_shape <- function(object, expected) {
10 | expect_tensor(object)
11 | expect_equal(object$shape, expected)
12 | }
13 | 
14 | expect_tensor <- function(object) {
15 | expect_true(is_torch_tensor(object))
16 | expect_no_error(torch::as_array(object))
17 | }
18 | 
19 | expect_equal_to_r <- function(object, expected) {
20 | expect_equal(torch::as_array(object), expected)
21 | }
22 | 
--------------------------------------------------------------------------------
/tests/testthat/test-bank-marketing.R:
--------------------------------------------------------------------------------
1 | test_that("bank marketing works", {
2 | 
3 | data <- bank_marketing_dataset(
4 | root = tempfile(),
5 | download = TRUE
6 | )
7 | 
8 | expect_length(data$.getitem(1), 2)
9 | 
10 | dl <- torch::dataloader(data, batch_size = 32)
11 | x <- coro::collect(dl, n = 1)
12 | 
13 | expect_equal(x[[1]]$x$shape, c(32, 55))
14 | expect_equal(x[[1]]$y$shape, c(32))
15 | 
16 | })
17 | 
--------------------------------------------------------------------------------
/tests/testthat/test-bird-species.R:
--------------------------------------------------------------------------------
1 | test_that("bird-species works", {
2 | 
3 | dataset <- bird_species_dataset(
4 | root = "./bird",
5 | download = TRUE
6 | )
7 | 
8 | expect_length(dataset$.getitem(1), 2)
9 | 
10 | })
11 | 
--------------------------------------------------------------------------------
/tests/testthat/test-cityscapes-pix2pix.R:
--------------------------------------------------------------------------------
1 | test_that("cityscapes_pix2pix works", {
2 | 
3 | root <- tempfile()
4 | 
5 | train <- cityscapes_pix2pix_dataset(
6 | root = root,
7 | download = TRUE,
8 | transform = torchvision::transform_to_tensor,
9 | target_transform = torchvision::transform_to_tensor
10 | )
11 | 
12 | valid <- cityscapes_pix2pix_dataset(
13 | root = root,
14 | split = "valid",
15 | download = FALSE,
16 | transform = torchvision::transform_to_tensor,
17 | target_transform = torchvision::transform_to_tensor
18 | )
19 | 
20 | expect_tensor_shape(train[1][[1]], c(3, 256, 256))
21 | expect_tensor_shape(train[1][[2]], c(3, 256, 256))
22 | 
23 | expect_tensor_shape(valid[1][[1]], c(3, 256, 256))
24 | expect_tensor_shape(valid[1][[2]], c(3, 256, 256))
25 | })
26 | 
--------------------------------------------------------------------------------
/tests/testthat/test-dogs-vs-cats.R:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | test_that("dogs-vs-cats dataset", {
4 | 
5 | dataset <- dogs_vs_cats_dataset(
6 | "./dogs-vs-cats",
7 | download = TRUE
8 | )
9 | 
10 | expect_length(dataset$.getitem(1), 2)
11 | 
12 | })
13 | 
14 | 
--------------------------------------------------------------------------------
/tests/testthat/test-guess-the-correlation.R:
--------------------------------------------------------------------------------
1 | test_that("guess_the_correlation_dataset works", {
2 | 
3 | tmp <- tempfile()
4 | dataset <- guess_the_correlation_dataset(
5 | root = tmp,
6 | download = TRUE,
7 | transform = function(x) torch::torch_zeros(3,3)
8 | )
9 | 
10 | expect_length(dataset$.getitem(1), 3)
11 | expect_equal(dim(dataset$.getitem(1)$x), c(3,3))
12 | expect_true(dataset$.getitem(1)$y$dtype == torch::torch_float())
13 | 
14 | })
15 | 
--------------------------------------------------------------------------------
/tests/testthat/test-imdb.R:
--------------------------------------------------------------------------------
1 | test_that("imdb dataset works", {
2 | 
3 | tmp <- tempfile()
4 | 
5 | dataset <- imdb_dataset(
6 | root = tmp,
7 | download = TRUE,
8 | num_words = 5000
9 | )
10 | 
11 | expect_equal(length(dataset), 25000)
12 | 
13 | # can use the cached dataset
14 | dataset <- imdb_dataset(
15 | root = tmp,
16 | download = TRUE,
17 | num_words = 3000,
18 | maxlen = 2500
19 | )
20 | 
21 | expect_equal(length(dataset), 25000)
22 | expect_equal(dataset[1]$x[1], 1)
23 | expect_equal(length(dataset[1]$x), 2500)
24 | 
25 | # can load a batch of observations
26 | dl <- torch::dataloader(dataset, batch_size = 32)
27 | x <- coro::collect(dl, 1)[[1]]
28 | 
29 | expect_tensor_shape(x$x, c(32, 2500))
30 | expect_tensor_shape(x$y, c(32))
31 | 
32 | # can load the test dataset
33 | dataset <- imdb_dataset(
34 | root = tmp,
35 | download = TRUE,
36 | num_words = 5000,
37 | split = "test"
38 | )
39 | 
40 | expect_equal(length(dataset), 25000)
41 | 
42 | })
43 | 
--------------------------------------------------------------------------------
/tests/testthat/test-oxford-flowers-dataset.R:
--------------------------------------------------------------------------------
1 | test_that("oxford flowers dataset", {
2 | 
3 | root <- tempfile()
4 | 
5 | train <- oxford_flowers102_dataset(
6 | root = root,
7 | download = TRUE,
8 | transform = torchvision::transform_to_tensor
9 | )
10 | 
11 | valid <- oxford_flowers102_dataset(
12 | root = root,
13 | split = "valid",
14 | download = FALSE,
15 | transform = torchvision::transform_to_tensor
16 | )
17 | 
18 | test <- oxford_flowers102_dataset(
19 | root = root,
20 | split = "test",
21 | download = FALSE,
22 | transform = torchvision::transform_to_tensor
23 | )
24 | 
25 | all <- oxford_flowers102_dataset(
26 | root = root,
27 | split = c("train", "valid", "test"),
28 | download = FALSE,
29 | transform = torchvision::transform_to_tensor
30 | )
31 | 
32 | expect_equal(train$classes[train[1]$y], "pink primrose")
33 | 
34 | expect_equal(length(all), 8189)
35 | expect_equal(length(valid), 1020)
36 | expect_equal(length(train), 1020)
37 | expect_equal(length(test), 6149)
38 | 
39 | expect_tensor_shape(train[1][[1]], c(3, 500, 754))
40 | expect_tensor_shape(valid[1][[1]], c(3, 500, 606))
41 | expect_tensor_shape(all[1][[1]], c(3, 500, 754))
42 | })
43 | 
--------------------------------------------------------------------------------
/tests/testthat/test-oxford-pet-dataset.R:
--------------------------------------------------------------------------------
1 | test_that("oxford pet dataset", {
2 | 
3 | root <- tempfile()
4 | 
5 | train <- oxford_pet_dataset(
6 | root = root,
7 | download = TRUE,
8 | transform = torchvision::transform_to_tensor,
9 | target_transform = torchvision::transform_to_tensor
10 | )
11 | 
12 | valid <- oxford_pet_dataset(
13 | root = root,
14 | split = "valid",
15 | download = FALSE,
16 | transform = torchvision::transform_to_tensor,
17 | target_transform = torchvision::transform_to_tensor
18 | )
19 | 
20 | expect_tensor_shape(train[1][[1]], c(3, 500, 394))
21 | expect_tensor_shape(train[1][[2]], c(1, 500, 394))
22 | 
23 | expect_tensor_shape(valid[1][[1]], c(3, 225, 300))
24 | expect_tensor_shape(valid[1][[2]], c(1, 225, 300))
25 | })
26 | 
--------------------------------------------------------------------------------
/tests/testthat/test-spam-dataset.R:
--------------------------------------------------------------------------------
1 | test_that("spam_dataset works as expected", {
2 | 
3 | dataset <- spam_dataset(download = TRUE)
4 | 
5 | iter <- dataloader_make_iter(dataset)
6 | batch <- dataloader_next(iter)
7 | 
8 | expect_equal(dim(batch$x), c(32, 57))
9 | 
10 | expect_equal(length(batch$y), 32)
11 | 
12 | expect_true(all(as.array(batch$y) %in% c(0, 1)))
13 | })
14 | 
--------------------------------------------------------------------------------
/torchdatasets.Rproj:
--------------------------------------------------------------------------------
1 | Version: 1.0
2 | 
3 | RestoreWorkspace: No
4 | SaveWorkspace: No
5 | AlwaysSaveHistory: Default
6 | 
7 | EnableCodeIndexing: Yes
8 | UseSpacesForTab: Yes
9 | NumSpacesForTab: 2
10 | Encoding: UTF-8
11 | 
12 | RnwWeave: Sweave
13 | LaTeX: pdfLaTeX
14 | 
15 | AutoAppendNewline: Yes
16 | StripTrailingWhitespace: Yes
17 | LineEndingConversion: Posix
18 | 
19 | BuildType: Package
20 | PackageUseDevtools: Yes
21 | PackageInstallArgs: --no-multiarch --with-keep.source
22 | PackageRoxygenize: rd,collate,namespace
23 | 
--------------------------------------------------------------------------------