├── .formatter.exs ├── .gitignore ├── LICENSE ├── README.md ├── lib └── scidata │ ├── caltech101.ex │ ├── cifar10.ex │ ├── cifar100.ex │ ├── fashionmnist.ex │ ├── imdb_reviews.ex │ ├── iris.ex │ ├── kuzushiji_mnist.ex │ ├── mnist.ex │ ├── squad.ex │ ├── utils.ex │ ├── wine.ex │ ├── yelp_full_reviews.ex │ └── yelp_polarity_reviews.ex ├── mix.exs ├── mix.lock └── test ├── caltech101_test.exs ├── cifar100_test.exs ├── cifar10_test.exs ├── fashionmnist_test.exs ├── imdb_reviews_test.exs ├── iris_test.exs ├── kuzushiji_mnist_test.exs ├── mnist_test.exs ├── squad_test.exs ├── test_helper.exs ├── wine_test.exs ├── yelp_full_reviews_test.exs └── yelp_polarity_reviews_test.exs /.formatter.exs: -------------------------------------------------------------------------------- 1 | # Used by "mix format" 2 | [ 3 | inputs: ["{mix,.formatter}.exs", "{config,lib,test}/**/*.{ex,exs}"] 4 | ] 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # The directory Mix will write compiled artifacts to. 2 | /_build/ 3 | 4 | # If you run "mix test --cover", coverage assets end up here. 5 | /cover/ 6 | 7 | # The directory Mix downloads your dependencies sources to. 8 | /deps/ 9 | 10 | # Where third-party dependencies like ExDoc output generated docs. 11 | /doc/ 12 | 13 | # Ignore .fetch files in case you like to edit your project deps locally. 14 | /.fetch 15 | 16 | # If the VM crashes, it generates a dump, let's ignore it too. 17 | erl_crash.dump 18 | 19 | # Also ignore archive artifacts (built via "mix archive.build"). 20 | *.ez 21 | 22 | # Ignore package tarball (built via "mix hex.build"). 23 | scidata-*.tar 24 | 25 | 26 | # Temporary files for e.g. 
tests 27 | /tmp 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 
175 | 176 | END OF TERMS AND CONDITIONS 177 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Scidata 2 | 3 | ## Usage 4 | 5 | Scidata currently supports the following training and test datasets: 6 | 7 | - Caltech101 8 | - CIFAR10 9 | - CIFAR100 10 | - FashionMNIST 11 | - IMDb Reviews 12 | - Kuzushiji-MNIST (KMNIST) 13 | - MNIST 14 | - SQuAD 15 | - Yelp Reviews (Full and Polarity) 16 | - Iris 17 | - Wine 18 | 19 | Download or fetch datasets locally: 20 | 21 | ```elixir 22 | {train_images, train_labels} = Scidata.MNIST.download() 23 | {test_images, test_labels} = Scidata.MNIST.download_test() 24 | 25 | # Unpack train_images like... 26 | {images_binary, tensor_type, shape} = train_images 27 | ``` 28 | 29 | Most often you will convert those results to `Nx` tensors: 30 | 31 | ```elixir 32 | {train_images, train_labels} = Scidata.MNIST.download() 33 | 34 | # Normalize and batch images 35 | {images_binary, images_type, images_shape} = train_images 36 | 37 | batched_images = 38 | images_binary 39 | |> Nx.from_binary(images_type) 40 | |> Nx.reshape(images_shape) 41 | |> Nx.divide(255) 42 | |> Nx.to_batched(32) 43 | 44 | # One-hot-encode and batch labels 45 | {labels_binary, labels_type, _shape} = train_labels 46 | 47 | batched_labels = 48 | labels_binary 49 | |> Nx.from_binary(labels_type) 50 | |> Nx.new_axis(-1) 51 | |> Nx.equal(Nx.tensor(Enum.to_list(0..9))) 52 | |> Nx.to_batched(32) 53 | ``` 54 | 55 | ## Installation 56 | 57 | ```elixir 58 | def deps do 59 | [ 60 | {:scidata, "~> 0.1.11"} 61 | ] 62 | end 63 | ``` 64 | 65 | ## Contributing 66 | 67 | PRs are encouraged! Consider using [utils](https://github.com/elixir-nx/scidata/blob/master/lib/scidata/utils.ex) to add your favorite dataset or one from [this list](https://github.com/elixir-nx/scidata/issues/16). 
68 | 69 | ## License 70 | 71 | Copyright (c) 2022 Tom Rutten 72 | 73 | Licensed under the Apache License, Version 2.0 (the "License"); 74 | you may not use this file except in compliance with the License. 75 | You may obtain a copy of the License at [http://www.apache.org/licenses/LICENSE-2.0](http://www.apache.org/licenses/LICENSE-2.0) 76 | 77 | Unless required by applicable law or agreed to in writing, software 78 | distributed under the License is distributed on an "AS IS" BASIS, 79 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 80 | See the License for the specific language governing permissions and 81 | limitations under the License. 82 | -------------------------------------------------------------------------------- /lib/scidata/caltech101.ex: -------------------------------------------------------------------------------- 1 | defmodule Scidata.Caltech101 do 2 | @moduledoc """ 3 | Module for downloading the [Caltech101 dataset](http://www.vision.caltech.edu/Image_Datasets/Caltech101). 
4 | """ 5 | 6 | require Scidata.Utils 7 | alias Scidata.Utils 8 | 9 | @base_url "https://s3.amazonaws.com/fast-ai-imageclas/" 10 | @dataset_file "caltech_101.tgz" 11 | @labels_shape {9144, 1} 12 | @label_mapping %{ 13 | accordion: 0, 14 | airplanes: 1, 15 | anchor: 2, 16 | ant: 3, 17 | background_google: 4, 18 | barrel: 5, 19 | bass: 6, 20 | beaver: 7, 21 | binocular: 8, 22 | bonsai: 9, 23 | brain: 10, 24 | brontosaurus: 11, 25 | buddha: 12, 26 | butterfly: 13, 27 | camera: 14, 28 | cannon: 15, 29 | car_side: 16, 30 | ceiling_fan: 17, 31 | cellphone: 18, 32 | chair: 19, 33 | chandelier: 20, 34 | cougar_body: 21, 35 | cougar_face: 22, 36 | crab: 23, 37 | crayfish: 24, 38 | crocodile: 25, 39 | crocodile_head: 26, 40 | cup: 27, 41 | dalmatian: 28, 42 | dollar_bill: 29, 43 | dolphin: 30, 44 | dragonfly: 31, 45 | electric_guitar: 32, 46 | elephant: 33, 47 | emu: 34, 48 | euphonium: 35, 49 | ewer: 36, 50 | faces: 37, 51 | faces_easy: 38, 52 | ferry: 39, 53 | flamingo: 40, 54 | flamingo_head: 41, 55 | garfield: 42, 56 | gerenuk: 43, 57 | gramophone: 44, 58 | grand_piano: 45, 59 | hawksbill: 46, 60 | headphone: 47, 61 | hedgehog: 48, 62 | helicopter: 49, 63 | ibis: 50, 64 | inline_skate: 51, 65 | joshua_tree: 52, 66 | kangaroo: 53, 67 | ketch: 54, 68 | lamp: 55, 69 | laptop: 56, 70 | leopards: 57, 71 | llama: 58, 72 | lobster: 59, 73 | lotus: 60, 74 | mandolin: 61, 75 | mayfly: 62, 76 | menorah: 63, 77 | metronome: 64, 78 | minaret: 65, 79 | motorbikes: 66, 80 | nautilus: 67, 81 | octopus: 68, 82 | okapi: 69, 83 | pagoda: 70, 84 | panda: 71, 85 | pigeon: 72, 86 | pizza: 73, 87 | platypus: 74, 88 | pyramid: 75, 89 | revolver: 76, 90 | rhino: 77, 91 | rooster: 78, 92 | saxophone: 79, 93 | schooner: 80, 94 | scissors: 81, 95 | scorpion: 82, 96 | sea_horse: 83, 97 | snoopy: 84, 98 | soccer_ball: 85, 99 | stapler: 86, 100 | starfish: 87, 101 | stegosaurus: 88, 102 | stop_sign: 89, 103 | strawberry: 90, 104 | sunflower: 91, 105 | tick: 92, 106 | trilobite: 93, 107 | umbrella: 
94, 108 | watch: 95, 109 | water_lilly: 96, 110 | wheelchair: 97, 111 | wild_cat: 98, 112 | windsor_chair: 99, 113 | wrench: 100, 114 | yin_yang: 101 115 | } 116 | 117 | @doc """ 118 | Downloads the Caltech101 training dataset or fetches it locally. 119 | 120 | Returns a tuple of format: 121 | 122 | {{images_binary, images_type, images_shape}, 123 | {labels_binary, labels_type, labels_shape}} 124 | 125 | If you want to one-hot encode the labels, you can: 126 | 127 | labels_binary 128 | |> Nx.from_binary(labels_type) 129 | |> Nx.new_axis(-1) 130 | |> Nx.equal(Nx.tensor(Enum.to_list(1..102))) 131 | 132 | ## Options. 133 | 134 | * `:base_url` - Dataset base URL. 135 | 136 | Defaults to `"https://s3.amazonaws.com/fast-ai-imageclas/"` 137 | 138 | * `:dataset_file` - Dataset filename. 139 | 140 | Defaults to `"caltech_101.tgz"` 141 | 142 | * `:cache_dir` - Cache directory. 143 | 144 | Defaults to `System.tmp_dir!()` 145 | 146 | """ 147 | def download(opts \\ []) do 148 | unless Code.ensure_loaded?(StbImage) do 149 | raise "StbImage is missing, please add `{:stb_image, \"~> 0.4\"}` as a dependency to your mix.exs" 150 | end 151 | 152 | download_dataset(:train, opts) 153 | end 154 | 155 | defp download_dataset(_dataset_type, opts) do 156 | base_url = opts[:base_url] || @base_url 157 | dataset_file = opts[:dataset_file] || @dataset_file 158 | 159 | # Skip first file since it's a temporary file. 
160 | [_ | files] = Utils.get!(base_url <> dataset_file, opts).body 161 | 162 | {images, shapes, labels} = 163 | files 164 | |> Enum.reverse() 165 | |> Task.async_stream(&generate_records/1, 166 | max_concurrency: Keyword.get(opts, :max_concurrency, System.schedulers_online()) 167 | ) 168 | |> Enum.reduce( 169 | {[], [], []}, 170 | fn {:ok, record}, {image_acc, shape_acc, label_acc} -> 171 | {%{data: image_bin, shape: shape}, label} = record 172 | {[image_bin | image_acc], [shape | shape_acc], [label | label_acc]} 173 | end 174 | ) 175 | 176 | {{images, {:u, 8}, shapes}, {IO.iodata_to_binary(labels), {:u, 8}, @labels_shape}} 177 | end 178 | 179 | @compile {:no_warn_undefined, StbImage} 180 | 181 | defp generate_records({fname, image}) do 182 | class_name = 183 | fname 184 | |> List.to_string() 185 | |> String.downcase() 186 | |> String.split("/") 187 | |> Enum.at(1) 188 | |> String.to_atom() 189 | 190 | label = Map.fetch!(@label_mapping, class_name) 191 | {:ok, stb_image} = StbImage.read_binary(image) 192 | 193 | {stb_image, label} 194 | end 195 | end 196 | -------------------------------------------------------------------------------- /lib/scidata/cifar10.ex: -------------------------------------------------------------------------------- 1 | defmodule Scidata.CIFAR10 do 2 | @moduledoc """ 3 | Module for downloading the [CIFAR10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html). 4 | """ 5 | 6 | require Scidata.Utils 7 | alias Scidata.Utils 8 | 9 | @base_url "https://www.cs.toronto.edu/~kriz/" 10 | @dataset_file "cifar-10-binary.tar.gz" 11 | @train_images_shape {50000, 3, 32, 32} 12 | @train_labels_shape {50000} 13 | @test_images_shape {10000, 3, 32, 32} 14 | @test_labels_shape {10000} 15 | 16 | @doc """ 17 | Downloads the CIFAR10 training dataset or fetches it locally. 
18 | 19 | Returns a tuple of format: 20 | 21 | {{images_binary, images_type, images_shape}, 22 | {labels_binary, labels_type, labels_shape}} 23 | 24 | If you want to one-hot encode the labels, you can: 25 | 26 | labels_binary 27 | |> Nx.from_binary(labels_type) 28 | |> Nx.new_axis(-1) 29 | |> Nx.equal(Nx.tensor(Enum.to_list(0..9))) 30 | 31 | ## Options. 32 | 33 | * `:base_url` - Dataset base URL. 34 | 35 | Defaults to `"https://www.cs.toronto.edu/~kriz/"` 36 | 37 | * `:dataset_file` - Dataset filename. 38 | 39 | Defaults to `"cifar-10-binary.tar.gz"` 40 | 41 | * `:cache_dir` - Cache directory. 42 | 43 | Defaults to `System.tmp_dir!()` 44 | 45 | ## Examples 46 | 47 | iex> Scidata.CIFAR10.download() 48 | {{<<59, 43, 50, 68, 98, 119, 139, 145, 149, 149, 131, 125, 142, 144, 137, 129, 49 | 137, 134, 124, 139, 139, 133, 136, 139, 152, 163, 168, 159, 158, 158, 152, 50 | 148, 16, 0, 18, 51, 88, 120, 128, 127, 126, 116, 106, 101, 105, 113, 109, 51 | 112, ...>>, {:u, 8}, {50000, 3, 32, 32}}, 52 | {<<6, 9, 9, 4, 1, 1, 2, 7, 8, 3, 4, 7, 7, 2, 9, 9, 9, 3, 2, 6, 4, 3, 6, 6, 2, 53 | 6, 3, 5, 4, 0, 0, 9, 1, 3, 4, 0, 3, 7, 3, 3, 5, 2, 2, 7, 1, 1, 1, ...>>, 54 | {:u, 8}, {50000}}} 55 | 56 | """ 57 | def download(opts \\ []) do 58 | download_dataset(:train, opts) 59 | end 60 | 61 | @doc """ 62 | Downloads the CIFAR10 test dataset or fetches it locally. 63 | 64 | Accepts the same options as `download/1`. 
65 | """ 66 | def download_test(opts \\ []) do 67 | download_dataset(:test, opts) 68 | end 69 | 70 | defp parse_images(content) do 71 | {images, labels} = 72 | for <>, reduce: {[], []} do 73 | {images, labels} -> 74 | <> = example 75 | {[image | images], [label | labels]} 76 | end 77 | 78 | {Enum.reverse(images), Enum.reverse(labels)} 79 | end 80 | 81 | defp download_dataset(dataset_type, opts) do 82 | base_url = opts[:base_url] || @base_url 83 | dataset_file = opts[:dataset_file] || @dataset_file 84 | 85 | files = Utils.get!(base_url <> dataset_file, opts).body 86 | 87 | {images, labels} = 88 | files 89 | |> Enum.filter(fn {fname, _} -> 90 | String.match?( 91 | List.to_string(fname), 92 | case dataset_type do 93 | :train -> ~r/data_batch/ 94 | :test -> ~r/test_batch/ 95 | end 96 | ) 97 | end) 98 | |> Enum.map(fn {_, content} -> Task.async(fn -> parse_images(content) end) end) 99 | |> Enum.map(&Task.await(&1, :infinity)) 100 | |> Enum.unzip() 101 | 102 | images = IO.iodata_to_binary(images) 103 | labels = IO.iodata_to_binary(labels) 104 | 105 | {{images, {:u, 8}, 106 | if(dataset_type == :test, do: @test_images_shape, else: @train_images_shape)}, 107 | {labels, {:u, 8}, 108 | if(dataset_type == :test, do: @test_labels_shape, else: @train_labels_shape)}} 109 | end 110 | end 111 | -------------------------------------------------------------------------------- /lib/scidata/cifar100.ex: -------------------------------------------------------------------------------- 1 | defmodule Scidata.CIFAR100 do 2 | @moduledoc """ 3 | Module for downloading the [CIFAR100 dataset](https://www.cs.toronto.edu/~kriz/cifar.html). 
4 | """ 5 | 6 | require Scidata.Utils 7 | alias Scidata.Utils 8 | 9 | @base_url "https://www.cs.toronto.edu/~kriz/" 10 | @dataset_file "cifar-100-binary.tar.gz" 11 | @train_images_shape {50000, 3, 32, 32} 12 | @train_labels_shape {50000, 2} 13 | @test_images_shape {10000, 3, 32, 32} 14 | @test_labels_shape {10000, 2} 15 | 16 | @doc """ 17 | Downloads the CIFAR100 training dataset or fetches it locally. 18 | 19 | Returns a tuple of format: 20 | 21 | {{images_binary, images_type, images_shape}, 22 | {labels_binary, labels_type, labels_shape}} 23 | 24 | ## Options. 25 | 26 | * `:base_url` - Dataset base URL. 27 | 28 | Defaults to `"https://www.cs.toronto.edu/~kriz/"` 29 | 30 | * `:dataset_file` - Dataset filename. 31 | 32 | Defaults to `"cifar-100-binary.tar.gz"` 33 | 34 | * `:cache_dir` - Cache directory. 35 | 36 | Defaults to `System.tmp_dir!()` 37 | 38 | ## Examples 39 | 40 | iex> Scidata.CIFAR100.download() 41 | {{<<59, 43, 50, 68, 98, 119, 139, 145, 149, 149, 131, 125, 142, 144, 137, 129, 42 | 137, 134, 124, 139, 139, 133, 136, 139, 152, 163, 168, 159, 158, 158, 152, 43 | 148, 16, 0, 18, 51, 88, 120, 128, 127, 126, 116, 106, 101, 105, 113, 109, 44 | 112, ...>>, {:u, 8}, {50000, 3, 32, 32}}, 45 | {<<6, 9, 9, 4, 1, 1, 2, 7, 8, 3, 4, 7, 7, 2, 9, 9, 9, 3, 2, 6, 4, 3, 6, 6, 2, 46 | 6, 3, 5, 4, 0, 0, 9, 1, 3, 4, 0, 3, 7, 3, 3, 5, 2, 2, 7, 1, 1, 1, ...>>, 47 | {:u, 8}, {50000, 2}}} 48 | 49 | """ 50 | def download(opts \\ []) do 51 | download_dataset(:train, opts) 52 | end 53 | 54 | @doc """ 55 | Downloads the CIFAR100 test dataset or fetches it locally. 56 | 57 | Accepts the same options as `download/1`. 
58 | """ 59 | def download_test(opts \\ []) do 60 | download_dataset(:test, opts) 61 | end 62 | 63 | defp parse_images(content) do 64 | {images, labels} = 65 | for <>, reduce: {[], []} do 66 | {images, labels} -> 67 | <> = example 68 | {[image | images], [label | labels]} 69 | end 70 | 71 | {Enum.reverse(images), Enum.reverse(labels)} 72 | end 73 | 74 | defp download_dataset(dataset_type, opts) do 75 | base_url = opts[:base_url] || @base_url 76 | dataset_file = opts[:dataset_file] || @dataset_file 77 | 78 | files = Utils.get!(base_url <> dataset_file, opts).body 79 | 80 | {images, labels} = 81 | files 82 | |> Enum.filter(fn {fname, _} -> 83 | String.match?( 84 | List.to_string(fname), 85 | case dataset_type do 86 | :train -> ~r/train.bin/ 87 | :test -> ~r/test.bin/ 88 | end 89 | ) 90 | end) 91 | |> Enum.map(fn {_, content} -> Task.async(fn -> parse_images(content) end) end) 92 | |> Enum.map(&Task.await(&1, :infinity)) 93 | |> Enum.unzip() 94 | 95 | images = IO.iodata_to_binary(images) 96 | labels = IO.iodata_to_binary(labels) 97 | 98 | {{images, {:u, 8}, 99 | if(dataset_type == :test, do: @test_images_shape, else: @train_images_shape)}, 100 | {labels, {:u, 8}, 101 | if(dataset_type == :test, do: @test_labels_shape, else: @train_labels_shape)}} 102 | end 103 | end 104 | -------------------------------------------------------------------------------- /lib/scidata/fashionmnist.ex: -------------------------------------------------------------------------------- 1 | defmodule Scidata.FashionMNIST do 2 | @moduledoc """ 3 | Module for downloading the [FashionMNIST dataset](https://github.com/zalandoresearch/fashion-mnist#readme). 
4 | """ 5 | 6 | require Scidata.Utils 7 | alias Scidata.Utils 8 | 9 | @base_url "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/" 10 | @train_image_file "train-images-idx3-ubyte.gz" 11 | @train_label_file "train-labels-idx1-ubyte.gz" 12 | @test_image_file "t10k-images-idx3-ubyte.gz" 13 | @test_label_file "t10k-labels-idx1-ubyte.gz" 14 | 15 | @doc """ 16 | Downloads the FashionMNIST training dataset or fetches it locally. 17 | 18 | Returns a tuple of format: 19 | 20 | {{images_binary, images_type, images_shape}, 21 | {labels_binary, labels_type, labels_shape}} 22 | 23 | If you want to one-hot encode the labels, you can: 24 | 25 | labels_binary 26 | |> Nx.from_binary(labels_type) 27 | |> Nx.new_axis(-1) 28 | |> Nx.equal(Nx.tensor(Enum.to_list(0..9))) 29 | 30 | ## Options. 31 | 32 | * `:base_url` - Dataset base URL. 33 | 34 | Defaults to `"http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"` 35 | 36 | * `:train_image_file` - Training set image filename. 37 | 38 | Defaults to `"train-images-idx3-ubyte.gz"` 39 | 40 | * `:train_label_file` - Training set label filename. 41 | 42 | Defaults to `"train-labels-idx1-ubyte.gz"` 43 | 44 | * `:cache_dir` - Cache directory. 
45 | 46 | Defaults to `System.tmp_dir!()` 47 | 48 | ## Examples 49 | 50 | iex> Scidata.FashionMNIST.download() 51 | {{<<105, 109, 97, 103, 101, 115, 45, 105, 100, 120, 51, 45, 117, 98, 121, 116, 52 | 101, 0, 236, 253, 7, 88, 84, 201, 215, 232, 11, 23, 152, 38, 57, 51, 166, 53 | 81, 71, 157, 209, 49, 135, 49, 141, 99, 206, 142, 57, 141, 89, 68, ...>>, 54 | {:u, 8}, {3739854681, 226418, 1634299437}}, 55 | {<<0, 3, 116, 114, 97, 105, 110, 45, 108, 97, 98, 101, 108, 115, 45, 105, 100, 56 | 120, 49, 45, 117, 98, 121, 116, 101, 0, 53, 221, 9, 130, 36, 73, 110, 100, 57 | 81, 219, 220, 150, 91, 214, 249, 251, 20, 141, 247, 53, 114, ...>>, {:u, 8}, 58 | {3739854681}}} 59 | 60 | """ 61 | def download(opts \\ []) do 62 | {download_images(:train, opts), download_labels(:train, opts)} 63 | end 64 | 65 | @doc """ 66 | Downloads the FashionMNIST test dataset or fetches it locally. 67 | 68 | ## Options. 69 | 70 | * `:base_url` - Dataset base URL. 71 | 72 | Defaults to `"http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"` 73 | 74 | * `:test_image_file` - Test set image filename. 75 | 76 | Defaults to `"t10k-images-idx3-ubyte.gz"` 77 | 78 | * `:test_label_file` - Test set label filename. 79 | 80 | Defaults to `"t10k-labels-idx1-ubyte.gz"` 81 | 82 | * `:cache_dir` - Cache directory. 
83 | 84 | Defaults to `System.tmp_dir!()` 85 | 86 | """ 87 | def download_test(opts \\ []) do 88 | {download_images(:test, opts), download_labels(:test, opts)} 89 | end 90 | 91 | defp download_images(:train, opts) do 92 | download_images(opts[:train_image_file] || @train_image_file, opts) 93 | end 94 | 95 | defp download_images(:test, opts) do 96 | download_images(opts[:test_image_file] || @test_image_file, opts) 97 | end 98 | 99 | defp download_images(filename, opts) do 100 | base_url = opts[:base_url] || @base_url 101 | 102 | data = Utils.get!(base_url <> filename, opts).body 103 | <<_::32, n_images::32, n_rows::32, n_cols::32, images::binary>> = data 104 | {images, {:u, 8}, {n_images, 1, n_rows, n_cols}} 105 | end 106 | 107 | defp download_labels(:train, opts) do 108 | download_labels(opts[:train_label_file] || @train_label_file, opts) 109 | end 110 | 111 | defp download_labels(:test, opts) do 112 | download_labels(opts[:test_label_file] || @test_label_file, opts) 113 | end 114 | 115 | defp download_labels(filename, opts) do 116 | base_url = opts[:base_url] || @base_url 117 | 118 | data = Utils.get!(base_url <> filename, opts).body 119 | <<_::32, n_labels::32, labels::binary>> = data 120 | {labels, {:u, 8}, {n_labels}} 121 | end 122 | end 123 | -------------------------------------------------------------------------------- /lib/scidata/imdb_reviews.ex: -------------------------------------------------------------------------------- 1 | defmodule Scidata.IMDBReviews do 2 | @moduledoc """ 3 | Module for downloading the [Large Movie Review Dataset](https://ai.stanford.edu/~amaas/data/sentiment/). 4 | """ 5 | 6 | @base_url "http://ai.stanford.edu/~amaas/data/sentiment/" 7 | @dataset_file "aclImdb_v1.tar.gz" 8 | 9 | alias Scidata.Utils 10 | 11 | @type train_sentiment :: :pos | :neg | :unsup 12 | @type test_sentiment :: :pos | :neg 13 | 14 | @doc """ 15 | Downloads the IMDB reviews training dataset or fetches it locally. 
16 | 17 | `example_types` specifies which examples in the dataset should be returned 18 | according to each example's label: `:pos` for positive examples, `:neg` for 19 | negative examples, and `:unsup` for unlabeled examples. If no `example_types` 20 | are provided, `:pos` and `:neg` examples are fetched. 21 | 22 | ## Options. 23 | 24 | * `:base_url` - Dataset base URL. 25 | 26 | Defaults to `"http://ai.stanford.edu/~amaas/data/sentiment/"` 27 | 28 | * `:dataset_file` - Dataset filename. 29 | 30 | Defaults to `"aclImdb_v1.tar.gz"` 31 | 32 | * `:cache_dir` - Cache directory. 33 | 34 | Defaults to `System.tmp_dir!()` 35 | 36 | """ 37 | @spec download(example_types: [train_sentiment]) :: %{ 38 | review: [binary(), ...], 39 | sentiment: [1 | 0 | nil] 40 | } 41 | def download(opts \\ []), do: download_dataset(:train, opts) 42 | 43 | @doc """ 44 | Downloads the IMDB reviews test dataset or fetches it locally. 45 | 46 | `example_types` is the same as in `download/1`, but `:unsup` is 47 | unavailable because all unlabeled examples are in the training set. 48 | 49 | Accepts the same options as `download/1`. 
50 | """ 51 | @spec download_test(example_types: [test_sentiment]) :: %{ 52 | review: [binary(), ...], 53 | sentiment: [1 | 0] 54 | } 55 | def download_test(opts \\ []), do: download_dataset(:test, opts) 56 | 57 | defp download_dataset(dataset_type, opts) do 58 | example_types = opts[:example_types] || [:pos, :neg] 59 | base_url = opts[:base_url] || @base_url 60 | dataset_file = opts[:dataset_file] || @dataset_file 61 | 62 | files = Utils.get!(base_url <> dataset_file, opts).body 63 | regex = ~r"#{dataset_type}/(#{Enum.join(example_types, "|")})/" 64 | 65 | {inputs, labels} = 66 | for {fname, contents} <- files, 67 | List.to_string(fname) =~ regex, 68 | reduce: {[], []} do 69 | {inputs, labels} -> 70 | {[contents | inputs], [get_label(fname) | labels]} 71 | end 72 | 73 | %{review: inputs, sentiment: labels} 74 | end 75 | 76 | defp get_label(fname) do 77 | fname = List.to_string(fname) 78 | 79 | cond do 80 | fname =~ "pos" -> 1 81 | fname =~ "neg" -> 0 82 | fname =~ "unsup" -> nil 83 | end 84 | end 85 | end 86 | -------------------------------------------------------------------------------- /lib/scidata/iris.ex: -------------------------------------------------------------------------------- 1 | defmodule Scidata.Iris do 2 | @moduledoc """ 3 | Module for downloading the [Iris Data Set](https://archive.ics.uci.edu/dataset/53/iris). 4 | """ 5 | 6 | @base_url "https://archive.ics.uci.edu/static/public/53/iris.zip" 7 | @dataset_file "iris.data" 8 | @data_hash <<111, 96, 139, 113, 167, 49, 114, 22, 49, 155, 77, 39, 180, 217, 188, 132, 230, 9 | 171, 215, 52, 237, 167, 135, 43, 113, 164, 88, 86, 158, 38, 86, 192>> 10 | 11 | alias Scidata.Utils 12 | 13 | @doc """ 14 | Downloads the Iris dataset or fetches it locally. 15 | 16 | ## Information about the dataset are available in file `iris.names` inside the 17 | [zip file](https://archive.ics.uci.edu/static/public/53/iris.zip). 18 | 19 | ### Attribute 20 | 21 | 1. sepal length in cm 22 | 2. sepal width in cm 23 | 3. 
petal length in cm 24 | 4. petal width in cm 25 | 26 | ### Label 27 | 28 | * 0: Iris Setosa 29 | * 1: Iris Versicolour 30 | * 2: Iris Virginica 31 | 32 | ## Options. 33 | 34 | * `:base_url` - Dataset base URL. 35 | 36 | Defaults to `"https://archive.ics.uci.edu/static/public/53/iris.zip"` 37 | 38 | * `:dataset_file` - Dataset filename. 39 | 40 | Defaults to `"iris.data"` 41 | 42 | * `:cache_dir` - Cache directory. 43 | 44 | Defaults to `System.tmp_dir!()` 45 | 46 | """ 47 | def download(opts \\ []) do 48 | base_url = opts[:base_url] || @base_url 49 | dataset_file = opts[:dataset_file] || @dataset_file 50 | 51 | # Temporary fix to cope with bad cert on source site 52 | opts = Keyword.put(opts, :ssl_verify, :verify_none) 53 | 54 | [{_, data}] = 55 | Utils.get!(base_url, opts).body 56 | |> Enum.filter(fn {fname, _} -> 57 | String.match?( 58 | List.to_string(fname), 59 | ~r/#{dataset_file}/ 60 | ) 61 | end) 62 | 63 | if :crypto.hash(:sha256, data) != @data_hash do 64 | raise RuntimeError, "Dataset hashed to unexpected value" 65 | end 66 | 67 | data 68 | |> String.split() 69 | |> Enum.reverse() 70 | |> Enum.reduce({[], []}, fn row_str, {feature_acc, label_acc} -> 71 | row = String.split(row_str, ",") 72 | {features, [label]} = Enum.split(row, 4) 73 | 74 | features = 75 | Enum.map(features, fn val -> 76 | {val, ""} = Float.parse(val) 77 | val 78 | end) 79 | 80 | label = get_label(label) 81 | {[features | feature_acc], [label | label_acc]} 82 | end) 83 | end 84 | 85 | defp get_label(label) do 86 | cond do 87 | label =~ "Iris-setosa" -> 0 88 | label =~ "Iris-versicolor" -> 1 89 | label =~ "Iris-virginica" -> 2 90 | end 91 | end 92 | end 93 | -------------------------------------------------------------------------------- /lib/scidata/kuzushiji_mnist.ex: -------------------------------------------------------------------------------- 1 | defmodule Scidata.KuzushijiMNIST do 2 | @moduledoc """ 3 | Module for downloading the [Kuzushiji-MNIST 
dataset](https://github.com/rois-codh/kmnist). 4 | """ 5 | 6 | alias Scidata.Utils 7 | 8 | @base_url "http://codh.rois.ac.jp/kmnist/dataset/kmnist/" 9 | @train_image_file "train-images-idx3-ubyte.gz" 10 | @train_label_file "train-labels-idx1-ubyte.gz" 11 | @test_image_file "t10k-images-idx3-ubyte.gz" 12 | @test_label_file "t10k-labels-idx1-ubyte.gz" 13 | 14 | @doc """ 15 | Downloads the Kuzushiji MNIST training dataset or fetches it locally. 16 | 17 | Returns a tuple of format: 18 | 19 | {{images_binary, images_type, images_shape}, 20 | {labels_binary, labels_type, labels_shape}} 21 | 22 | If you want to one-hot encode the labels, you can: 23 | 24 | labels_binary 25 | |> Nx.from_binary(labels_type) 26 | |> Nx.new_axis(-1) 27 | |> Nx.equal(Nx.tensor(Enum.to_list(0..9))) 28 | 29 | ## Options. 30 | 31 | * `:base_url` - Dataset base URL. 32 | 33 | Defaults to `"http://codh.rois.ac.jp/kmnist/dataset/kmnist/"` 34 | 35 | * `:train_image_file` - Training set image filename. 36 | 37 | Defaults to `"train-images-idx3-ubyte.gz"` 38 | 39 | * `:train_label_file` - Training set label filename. 40 | 41 | Defaults to `"train-labels-idx1-ubyte.gz"` 42 | 43 | * `:cache_dir` - Cache directory. 44 | 45 | Defaults to `System.tmp_dir!()` 46 | 47 | """ 48 | def download(opts \\ []) do 49 | {download_images(:train, opts), download_labels(:train, opts)} 50 | end 51 | 52 | @doc """ 53 | Downloads the Kuzushiji MNIST test dataset or fetches it locally. 54 | 55 | ## Options. 56 | 57 | * `:base_url` - Dataset base URL. 58 | 59 | Defaults to `"http://codh.rois.ac.jp/kmnist/dataset/kmnist/"` 60 | 61 | * `:test_image_file` - Test set image filename. 62 | 63 | Defaults to `"t10k-images-idx3-ubyte.gz"` 64 | 65 | * `:test_label_file` - Test set label filename. 66 | 67 | Defaults to `"t10k-labels-idx1-ubyte.gz"` 68 | 69 | * `:cache_dir` - Cache directory. 
70 | 71 | Defaults to `System.tmp_dir!()` 72 | 73 | """ 74 | def download_test(opts \\ []) do 75 | {download_images(:test, opts), download_labels(:test, opts)} 76 | end 77 | 78 | defp download_images(:train, opts) do 79 | download_images(opts[:train_image_file] || @train_image_file, opts) 80 | end 81 | 82 | defp download_images(:test, opts) do 83 | download_images(opts[:test_image_file] || @test_image_file, opts) 84 | end 85 | 86 | defp download_images(filename, opts) do 87 | base_url = opts[:base_url] || @base_url 88 | 89 | data = Utils.get!(base_url <> filename, opts).body 90 | <<_::32, n_images::32, n_rows::32, n_cols::32, images::binary>> = data 91 | {images, {:u, 8}, {n_images, 1, n_rows, n_cols}} 92 | end 93 | 94 | defp download_labels(:train, opts) do 95 | download_labels(opts[:train_label_file] || @train_label_file, opts) 96 | end 97 | 98 | defp download_labels(:test, opts) do 99 | download_labels(opts[:test_label_file] || @test_label_file, opts) 100 | end 101 | 102 | defp download_labels(filename, opts) do 103 | base_url = opts[:base_url] || @base_url 104 | 105 | data = Utils.get!(base_url <> filename, opts).body 106 | <<_::32, n_labels::32, labels::binary>> = data 107 | {labels, {:u, 8}, {n_labels}} 108 | end 109 | end 110 | -------------------------------------------------------------------------------- /lib/scidata/mnist.ex: -------------------------------------------------------------------------------- 1 | defmodule Scidata.MNIST do 2 | @moduledoc """ 3 | Module for downloading the [MNIST dataset](http://yann.lecun.com/exdb/mnist/). 4 | """ 5 | 6 | alias Scidata.Utils 7 | 8 | @base_url "https://storage.googleapis.com/cvdf-datasets/mnist/" 9 | @train_image_file "train-images-idx3-ubyte.gz" 10 | @train_label_file "train-labels-idx1-ubyte.gz" 11 | @test_image_file "t10k-images-idx3-ubyte.gz" 12 | @test_label_file "t10k-labels-idx1-ubyte.gz" 13 | 14 | @doc """ 15 | Downloads the MNIST training dataset or fetches it locally. 
16 | 17 | Returns a tuple of format: 18 | 19 | {{images_binary, images_type, images_shape}, 20 | {labels_binary, labels_type, labels_shape}} 21 | 22 | If you want to one-hot encode the labels, you can: 23 | 24 | labels_binary 25 | |> Nx.from_binary(labels_type) 26 | |> Nx.new_axis(-1) 27 | |> Nx.equal(Nx.tensor(Enum.to_list(0..9))) 28 | 29 | ## Options. 30 | 31 | * `:base_url` - Dataset base URL. 32 | 33 | Defaults to `"https://storage.googleapis.com/cvdf-datasets/mnist/"` 34 | 35 | * `:train_image_file` - Training set image filename. 36 | 37 | Defaults to `"train-images-idx3-ubyte.gz"` 38 | 39 | * `:train_label_file` - Training set label filename. 40 | 41 | Defaults to `"train-labels-idx1-ubyte.gz"` 42 | 43 | * `:cache_dir` - Cache directory. 44 | 45 | Defaults to `System.tmp_dir!()` 46 | 47 | """ 48 | def download(opts \\ []) do 49 | {download_images(:train, opts), download_labels(:train, opts)} 50 | end 51 | 52 | @doc """ 53 | Downloads the MNIST test dataset or fetches it locally. 54 | 55 | ## Options. 56 | 57 | * `:base_url` - Dataset base URL. 58 | 59 | Defaults to `"https://storage.googleapis.com/cvdf-datasets/mnist/"` 60 | 61 | * `:train_image_file` - Training set image filename. 62 | 63 | Defaults to `"train-images-idx3-ubyte.gz"` 64 | 65 | * `:train_label_file` - Training set label filename. 66 | 67 | Defaults to `"train-labels-idx1-ubyte.gz"` 68 | 69 | * `:test_image_file` - Test set image filename. 70 | 71 | Defaults to `"t10k-images-idx3-ubyte.gz"` 72 | 73 | * `:test_label_file` - Test set label filename. 74 | 75 | Defaults to `"t10k-labels-idx1-ubyte.gz"` 76 | 77 | * `:cache_dir` - Cache directory. 
78 | 79 | Defaults to `System.tmp_dir!()` 80 | 81 | """ 82 | def download_test(opts \\ []) do 83 | {download_images(:test, opts), download_labels(:test, opts)} 84 | end 85 | 86 | defp download_images(:train, opts) do 87 | download_images(opts[:train_image_file] || @train_image_file, opts) 88 | end 89 | 90 | defp download_images(:test, opts) do 91 | download_images(opts[:test_image_file] || @test_image_file, opts) 92 | end 93 | 94 | defp download_images(filename, opts) do 95 | base_url = opts[:base_url] || @base_url 96 | 97 | data = Utils.get!(base_url <> filename, opts).body 98 | <<_::32, n_images::32, n_rows::32, n_cols::32, images::binary>> = data 99 | {images, {:u, 8}, {n_images, 1, n_rows, n_cols}} 100 | end 101 | 102 | defp download_labels(:train, opts) do 103 | download_labels(opts[:train_label_file] || @train_label_file, opts) 104 | end 105 | 106 | defp download_labels(:test, opts) do 107 | download_labels(opts[:test_label_file] || @test_label_file, opts) 108 | end 109 | 110 | defp download_labels(filename, opts) do 111 | base_url = opts[:base_url] || @base_url 112 | 113 | data = Utils.get!(base_url <> filename, opts).body 114 | <<_::32, n_labels::32, labels::binary>> = data 115 | {labels, {:u, 8}, {n_labels}} 116 | end 117 | end 118 | -------------------------------------------------------------------------------- /lib/scidata/squad.ex: -------------------------------------------------------------------------------- 1 | defmodule Scidata.Squad do 2 | @moduledoc """ 3 | Module for downloading the [SQuAD1.1 dataset](https://rajpurkar.github.io/SQuAD-explorer). 4 | """ 5 | 6 | require Scidata.Utils 7 | alias Scidata.Utils 8 | 9 | @base_url "https://rajpurkar.github.io/SQuAD-explorer/dataset/" 10 | @train_dataset_file "train-v1.1.json" 11 | @test_dataset_file "dev-v1.1.json" 12 | 13 | @doc """ 14 | Downloads the SQuAD training dataset 15 | 16 | ## Options. 17 | 18 | * `:base_url` - Dataset base URL. 
19 | 20 | Defaults to `"https://rajpurkar.github.io/SQuAD-explorer/dataset/"` 21 | 22 | * `:train_dataset_file` - Training set filename. 23 | 24 | Defaults to `"train-v1.1.json"` 25 | 26 | * `:cache_dir` - Cache directory. 27 | 28 | Defaults to `System.tmp_dir!()` 29 | 30 | ## Examples 31 | 32 | iex> Scidata.Squad.download() 33 | [ 34 | %{ 35 | "paragraphs" => [ 36 | %{ 37 | "context" => "Architecturally, the school has a...", 38 | "qas" => [ 39 | %{ 40 | "answers" => [%{"answer_start" => 515, "text" => "Saint Bernadette Soubirous"}], 41 | "id" => "5733be284776f41900661182", 42 | "question" => "To whom did the..." 43 | }, ... 44 | ] 45 | } 46 | ], 47 | "title" => "University_of_Notre_Dame" 48 | }, ... 49 | ] 50 | """ 51 | 52 | def download(opts \\ []) do 53 | download_dataset(opts[:train_dataset_file] || @train_dataset_file, opts) 54 | end 55 | 56 | @doc """ 57 | Downloads the SQuAD test dataset 58 | 59 | ## Options. 60 | 61 | * `:base_url` - Dataset base URL. 62 | 63 | Defaults to `"https://rajpurkar.github.io/SQuAD-explorer/dataset/"` 64 | 65 | * `:test_dataset_file` - Test set filename. 66 | 67 | Defaults to `"dev-v1.1.json"` 68 | 69 | * `:cache_dir` - Cache directory. 70 | 71 | Defaults to `System.tmp_dir!()` 72 | 73 | ## Examples 74 | 75 | iex> Scidata.Squad.download_test() 76 | [ 77 | %{ 78 | "paragraphs" => [ 79 | %{ 80 | "context" => "Super Bowl 50 was an American football game t...", 81 | "qas" => [ 82 | %{ 83 | "answers" => [ 84 | %{"answer_start" => 177, "text" => "Denver Broncos"},... 85 | ], 86 | "id" => "56be4db0acb8001400a502ec", 87 | "question" => "Which NFL team represented the AFC at Super Bowl 50?" 88 | }, 89 | ] 90 | } 91 | ], 92 | "title" => "Super_Bowl_50" 93 | }, ... 
94 | ] 95 | """ 96 | 97 | def download_test(opts \\ []) do 98 | download_dataset(opts[:test_dataset_file] || @test_dataset_file, opts) 99 | end 100 | 101 | defp download_dataset(dataset_name, opts) do 102 | base_url = opts[:base_url] || @base_url 103 | 104 | content = 105 | Utils.get!(base_url <> dataset_name, opts).body 106 | |> Jason.decode!() 107 | 108 | content["data"] 109 | end 110 | 111 | @doc """ 112 | Convert result of `download/0` or `download_test/0` to map for use with [Explorer.DataFrame](https://github.com/elixir-nx/explorer). 113 | 114 | ## Examples 115 | 116 | iex> columns_for_df = Scidata.Squad.download() |> Scidata.Squad.to_columns() 117 | %{ 118 | "answer_start" => [515, ...], 119 | "context" => ["Architecturally, the...", ...], 120 | "id" => ["5733be284776f41900661182", ...], 121 | "question" => ["To whom did the Vir...", ...], 122 | "answer_text" => ["Saint Bernadette Soubirous", ...], 123 | "title" => ["University_of_Notre_Dame", ...] 124 | } 125 | iex> Explorer.DataFrame.from_map(columns_for_df) 126 | #Explorer.DataFrame< 127 | [rows: 87599, columns: 6] 128 | ... 
129 | > 130 | """ 131 | 132 | def to_columns(entries) do 133 | table = %{ 134 | "answer_start" => [], 135 | "context" => [], 136 | "id" => [], 137 | "question" => [], 138 | "answer_text" => [], 139 | "title" => [] 140 | } 141 | 142 | for %{"paragraphs" => paragraph, "title" => title} <- entries, 143 | %{"context" => context, "qas" => qas} <- paragraph, 144 | %{"id" => id, "question" => question, "answers" => answers} <- qas, 145 | %{"answer_start" => answer_start, "text" => answer_text} <- answers, 146 | reduce: table do 147 | %{ 148 | "answer_start" => answer_starts, 149 | "context" => contexts, 150 | "id" => ids, 151 | "question" => questions, 152 | "answer_text" => answer_texts, 153 | "title" => titles 154 | } -> 155 | %{ 156 | "answer_start" => [answer_start | answer_starts], 157 | "context" => [context | contexts], 158 | "id" => [id | ids], 159 | "question" => [question | questions], 160 | "answer_text" => [answer_text | answer_texts], 161 | "title" => [title | titles] 162 | } 163 | end 164 | |> Enum.map(fn {key, values} -> {key, :lists.reverse(values)} end) 165 | |> Enum.into(%{}) 166 | end 167 | end 168 | -------------------------------------------------------------------------------- /lib/scidata/utils.ex: -------------------------------------------------------------------------------- 1 | defmodule Scidata.Utils do 2 | @moduledoc false 3 | 4 | def get!(url, opts \\ []) do 5 | request = %{ 6 | url: url, 7 | headers: [] 8 | } 9 | 10 | request 11 | |> if_modified_since(opts) 12 | |> run!(opts) 13 | |> raise_errors() 14 | |> handle_cache(opts) 15 | |> decode() 16 | |> elem(1) 17 | end 18 | 19 | defp if_modified_since(request, opts) do 20 | case File.stat(cache_path(request, opts)) do 21 | {:ok, stat} -> 22 | value = stat.mtime |> NaiveDateTime.from_erl!() |> format_http_datetime() 23 | update_in(request.headers, &[{'if-modified-since', String.to_charlist(value)} | &1]) 24 | 25 | _ -> 26 | request 27 | end 28 | end 29 | 30 | defp format_http_datetime(datetime) 
do 31 | Calendar.strftime(datetime, "%a, %d %b %Y %H:%M:%S GMT") 32 | end 33 | 34 | defp run!(request, opts) do 35 | verify = opts[:ssl_verify] || :verify_peer 36 | 37 | http_opts = [ 38 | ssl: [ 39 | verify: verify, 40 | cacertfile: CAStore.file_path(), 41 | customize_hostname_check: [ 42 | match_fun: :public_key.pkix_verify_hostname_match_fun(:https) 43 | ] 44 | ] 45 | ] 46 | 47 | request_opts = [body_format: :binary] 48 | arg = {request.url, request.headers} 49 | 50 | case :httpc.request(:get, arg, http_opts, request_opts) do 51 | {:ok, {{_, status, _}, headers, body}} -> 52 | response = %{status: status, headers: headers, body: body} 53 | {request, response} 54 | 55 | {:error, reason} -> 56 | raise inspect(reason) 57 | end 58 | end 59 | 60 | defp raise_errors({request, response}) do 61 | if response.status >= 400 do 62 | raise "HTTP #{response.status} #{inspect(response.body)}" 63 | else 64 | {request, response} 65 | end 66 | end 67 | 68 | defp decode({request, response}) do 69 | cond do 70 | String.ends_with?(request.url, ".tar.gz") or String.ends_with?(request.url, ".tgz") -> 71 | {:ok, files} = :erl_tar.extract({:binary, response.body}, [:memory, :compressed]) 72 | response = %{response | body: files} 73 | {request, response} 74 | 75 | String.ends_with?(request.url, ".zip") -> 76 | {:ok, files} = :zip.extract(response.body, [:memory]) 77 | response = %{response | body: files} 78 | {request, response} 79 | 80 | Path.extname(request.url) == ".gz" -> 81 | body = :zlib.gunzip(response.body) 82 | response = %{response | body: body} 83 | {request, response} 84 | 85 | true -> 86 | {request, response} 87 | end 88 | end 89 | 90 | defp handle_cache({request, response}, opts) do 91 | path = cache_path(request, opts) 92 | 93 | if response.status == 304 do 94 | # Logger.debug(["loading cached ", path]) 95 | response = %{response | body: File.read!(path)} 96 | {request, response} 97 | else 98 | # Logger.debug(["writing cache ", path]) 99 | File.write!(path, response.body) 
100 | {request, response} 101 | end 102 | end 103 | 104 | defp cache_path(request, opts) do 105 | uri = URI.parse(request.url) 106 | path = Enum.join([uri.host, String.replace(uri.path, "/", "-")], "-") 107 | cache_dir = opts[:cache_dir] || System.tmp_dir!() 108 | File.mkdir_p!(cache_dir) 109 | Path.join(cache_dir, path) 110 | end 111 | end 112 | -------------------------------------------------------------------------------- /lib/scidata/wine.ex: -------------------------------------------------------------------------------- 1 | defmodule Scidata.Wine do 2 | @moduledoc """ 3 | Module for downloading the [Wine Data Set](https://archive.ics.uci.edu/dataset/109/wine). 4 | """ 5 | 6 | @base_url "https://archive.ics.uci.edu/static/public/109/wine.zip" 7 | @dataset_file "wine.data" 8 | @data_hash <<107, 230, 177, 32, 63, 61, 81, 223, 11, 85, 58, 112, 229, 123, 138, 114, 60, 9 | 212, 5, 104, 57, 88, 32, 79, 150, 210, 61, 124, 214, 174, 166, 89>> 10 | 11 | alias Scidata.Utils 12 | 13 | @doc """ 14 | Downloads the Wine dataset or fetches it locally. 15 | 16 | ## Information about the dataset is available in file `wine.names` inside the 17 | [zip file](https://archive.ics.uci.edu/static/public/109/wine.zip). 18 | 19 | ### Attribute 20 | 21 | 1. Alcohol 22 | 2. Malic acid 23 | 3. Ash 24 | 4. Alcalinity of ash 25 | 5. Magnesium 26 | 6. Total phenols 27 | 7. Flavanoids 28 | 8. Nonflavanoid phenols 29 | 9. Proanthocyanins 30 | 10. Color intensity 31 | 11. Hue 32 | 12. OD280/OD315 of diluted wines 33 | 13. Proline 34 | 35 | ### Label 36 | 37 | * 0 38 | * 1 39 | * 2 40 | 41 | ## Options. 42 | 43 | * `:base_url` - Dataset base URL. 44 | 45 | Defaults to `"https://archive.ics.uci.edu/static/public/109/wine.zip"` 46 | 47 | * `:dataset_file` - Dataset filename. 48 | 49 | Defaults to `"wine.data"` 50 | 51 | * `:cache_dir` - Cache directory. 
52 | 53 | Defaults to `System.tmp_dir!()` 54 | 55 | """ 56 | def download(opts \\ []) do 57 | base_url = opts[:base_url] || @base_url 58 | dataset_file = opts[:dataset_file] || @dataset_file 59 | 60 | # Temporary fix to cope with bad cert on source site 61 | opts = Keyword.put(opts, :ssl_verify, :verify_none) 62 | 63 | [{_, data}] = 64 | Utils.get!(base_url, opts).body 65 | |> Enum.filter(fn {fname, _} -> 66 | String.match?( 67 | List.to_string(fname), 68 | ~r/#{dataset_file}/ 69 | ) 70 | end) 71 | 72 | if :crypto.hash(:sha256, data) != @data_hash do 73 | raise RuntimeError, "Dataset hashed to unexpected value" 74 | end 75 | 76 | label_attr = 77 | data 78 | |> String.split() 79 | |> Enum.map(&String.split(&1, ",")) 80 | |> Enum.map(fn row -> 81 | [label | val_list] = row 82 | label = String.to_integer(label) 83 | 84 | val_list = 85 | Enum.map(val_list, fn val -> 86 | {val, ""} = Float.parse("0" <> val) 87 | val 88 | end) 89 | 90 | [label - 1 | val_list] 91 | end) 92 | 93 | labels = Enum.map(label_attr, &hd(&1)) 94 | attributes = Enum.map(label_attr, &tl(&1)) 95 | {attributes, labels} 96 | end 97 | end 98 | -------------------------------------------------------------------------------- /lib/scidata/yelp_full_reviews.ex: -------------------------------------------------------------------------------- 1 | defmodule Scidata.YelpFullReviews do 2 | @moduledoc """ 3 | Module for downloading the [Yelp Reviews dataset](https://www.yelp.com/dataset). 4 | """ 5 | 6 | @base_url "https://s3.amazonaws.com/fast-ai-nlp/" 7 | 8 | @dataset_file "yelp_review_full_csv.tgz" 9 | 10 | alias Scidata.Utils 11 | alias NimbleCSV.RFC4180, as: CSV 12 | 13 | @doc """ 14 | Downloads the Yelp Reviews training dataset or fetches it locally. 15 | 16 | ## Options. 17 | 18 | * `:base_url` - Dataset base URL. 19 | 20 | Defaults to `"https://s3.amazonaws.com/fast-ai-nlp/"` 21 | 22 | * `:dataset_file` - Dataset filename. 
23 | 24 | Defaults to `"yelp_review_full_csv.tgz"` 25 | 26 | * `:cache_dir` - Cache directory. 27 | 28 | Defaults to `System.tmp_dir!()` 29 | 30 | """ 31 | @spec download(Keyword.t()) :: %{review: [binary(), ...], rating: [5 | 4 | 3 | 2 | 1]} 32 | def download(opts \\ []), do: download_dataset(:train, opts) 33 | 34 | @doc """ 35 | Downloads the Yelp Reviews test dataset or fetches it locally. 36 | 37 | Accepts the same options as `download/1`. 38 | """ 39 | @spec download_test(Keyword.t()) :: %{ 40 | review: [binary(), ...], 41 | rating: [5 | 4 | 3 | 2 | 1] 42 | } 43 | def download_test(opts \\ []), do: download_dataset(:test, opts) 44 | 45 | defp download_dataset(dataset_type, opts) do 46 | base_url = opts[:base_url] || @base_url 47 | dataset_file = opts[:dataset_file] || @dataset_file 48 | 49 | files = Utils.get!(base_url <> dataset_file, opts).body 50 | regex = ~r"#{dataset_type}" 51 | 52 | records = 53 | for {fname, contents} <- files, 54 | List.to_string(fname) =~ regex, 55 | reduce: [[]] do 56 | _ -> CSV.parse_string(contents, skip_headers: false) 57 | end 58 | 59 | %{ 60 | review: records |> Enum.map(&List.last(&1)), 61 | rating: records |> Enum.map(fn x -> x |> List.first() |> String.to_integer() end) 62 | } 63 | end 64 | end 65 | -------------------------------------------------------------------------------- /lib/scidata/yelp_polarity_reviews.ex: -------------------------------------------------------------------------------- 1 | defmodule Scidata.YelpPolarityReviews do 2 | @moduledoc """ 3 | Module for downloading the [Yelp Polarity Reviews dataset](https://course.fast.ai/datasets#nlp). 4 | """ 5 | 6 | @base_url "https://s3.amazonaws.com/fast-ai-nlp/" 7 | 8 | @dataset_file "yelp_review_polarity_csv.tgz" 9 | 10 | alias Scidata.Utils 11 | alias NimbleCSV.RFC4180, as: CSV 12 | 13 | @doc """ 14 | Downloads the Yelp Polarity Reviews training dataset or fetches it locally. 15 | 16 | ## Options. 17 | 18 | * `:base_url` - Dataset base URL. 
19 | 20 | Defaults to `"https://s3.amazonaws.com/fast-ai-nlp/"` 21 | 22 | * `:dataset_file` - Dataset filename. 23 | 24 | Defaults to `"yelp_review_polarity_csv.tgz"` 25 | 26 | * `:cache_dir` - Cache directory. 27 | 28 | Defaults to `System.tmp_dir!()` 29 | 30 | """ 31 | @spec download(Keyword.t()) :: %{review: [binary(), ...], sentiment: [1 | 0]} 32 | def download(opts \\ []), do: download_dataset(:train, opts) 33 | 34 | @doc """ 35 | Downloads the Yelp Polarity Reviews test dataset or fetches it locally. 36 | 37 | Accepts the same options as `download/1`. 38 | """ 39 | @spec download_test(Keyword.t()) :: %{ 40 | review: [binary(), ...], 41 | sentiment: [1 | 0] 42 | } 43 | def download_test(opts \\ []), do: download_dataset(:test, opts) 44 | 45 | defp download_dataset(dataset_type, opts) do 46 | base_url = opts[:base_url] || @base_url 47 | dataset_file = opts[:dataset_file] || @dataset_file 48 | 49 | files = Utils.get!(base_url <> dataset_file, opts).body 50 | regex = ~r"#{dataset_type}" 51 | 52 | records = 53 | for {fname, contents} <- files, 54 | List.to_string(fname) =~ regex, 55 | reduce: [[]] do 56 | _ -> CSV.parse_string(contents, skip_headers: false) 57 | end 58 | 59 | %{ 60 | review: records |> Enum.map(&List.last(&1)), 61 | sentiment: get_rating(records) 62 | } 63 | end 64 | 65 | defp get_rating(records) do 66 | Enum.map(records, fn 67 | ["1" | _] -> 0 68 | ["2" | _] -> 1 69 | end) 70 | end 71 | end 72 | -------------------------------------------------------------------------------- /mix.exs: -------------------------------------------------------------------------------- 1 | defmodule Scidata.MixProject do 2 | use Mix.Project 3 | 4 | @version "0.1.11" 5 | @repo_url "https://github.com/elixir-nx/scidata" 6 | 7 | def project do 8 | [ 9 | app: :scidata, 10 | version: @version, 11 | elixir: "~> 1.11", 12 | start_permanent: Mix.env() == :prod, 13 | deps: deps(), 14 | 15 | # Hex 16 | package: package(), 17 | description: "Datasets for science", 18 | 19 | # 
Docs 20 | name: "Scidata", 21 | docs: docs() 22 | ] 23 | end 24 | 25 | def application do 26 | [ 27 | extra_applications: [:logger, :ssl, :inets] 28 | ] 29 | end 30 | 31 | defp deps do 32 | [ 33 | {:ex_doc, ">= 0.24.0", only: :dev, runtime: false}, 34 | {:nimble_csv, "~> 1.1"}, 35 | {:jason, "~> 1.0"}, 36 | {:stb_image, "~> 0.4", optional: true}, 37 | {:castore, "~> 0.1"} 38 | ] 39 | end 40 | 41 | defp package do 42 | [ 43 | licenses: ["Apache-2.0"], 44 | links: %{"GitHub" => @repo_url} 45 | ] 46 | end 47 | 48 | defp docs do 49 | [ 50 | source_ref: "v#{@version}", 51 | source_url: @repo_url, 52 | groups_for_modules: [ 53 | Text: [ 54 | Scidata.IMDBReviews, 55 | Scidata.Squad, 56 | Scidata.YelpFullReviews, 57 | Scidata.YelpPolarityReviews 58 | ], 59 | Vision: [ 60 | Scidata.Caltech101, 61 | Scidata.CIFAR10, 62 | Scidata.CIFAR100, 63 | Scidata.FashionMNIST, 64 | Scidata.KuzushijiMNIST, 65 | Scidata.MNIST 66 | ], 67 | Misc: [ 68 | Scidata.Iris, 69 | Scidata.Wine 70 | ] 71 | ] 72 | ] 73 | end 74 | end 75 | -------------------------------------------------------------------------------- /mix.lock: -------------------------------------------------------------------------------- 1 | %{ 2 | "castore": {:hex, :castore, "0.1.17", "ba672681de4e51ed8ec1f74ed624d104c0db72742ea1a5e74edbc770c815182f", [:mix], [], "hexpm", "d9844227ed52d26e7519224525cb6868650c272d4a3d327ce3ca5570c12163f9"}, 3 | "earmark_parser": {:hex, :earmark_parser, "1.4.26", "f4291134583f373c7d8755566122908eb9662df4c4b63caa66a0eabe06569b0a", [:mix], [], "hexpm", "48d460899f8a0c52c5470676611c01f64f3337bad0b26ddab43648428d94aabc"}, 4 | "elixir_make": {:hex, :elixir_make, "0.6.3", "bc07d53221216838d79e03a8019d0839786703129599e9619f4ab74c8c096eac", [:mix], [], "hexpm", "f5cbd651c5678bcaabdbb7857658ee106b12509cd976c2c2fca99688e1daf716"}, 5 | "ex_doc": {:hex, :ex_doc, "0.28.4", "001a0ea6beac2f810f1abc3dbf4b123e9593eaa5f00dd13ded024eae7c523298", [:mix], [{:earmark_parser, "~> 1.4.19", [hex: :earmark_parser, repo: 
"hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "bf85d003dd34911d89c8ddb8bda1a958af3471a274a4c2150a9c01c78ac3f8ed"}, 6 | "jason": {:hex, :jason, "1.3.0", "fa6b82a934feb176263ad2df0dbd91bf633d4a46ebfdffea0c8ae82953714946", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "53fc1f51255390e0ec7e50f9cb41e751c260d065dcba2bf0d08dc51a4002c2ac"}, 7 | "makeup": {:hex, :makeup, "1.1.0", "6b67c8bc2882a6b6a445859952a602afc1a41c2e08379ca057c0f525366fc3ca", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "0a45ed501f4a8897f580eabf99a2e5234ea3e75a4373c8a52824f6e873be57a6"}, 8 | "makeup_elixir": {:hex, :makeup_elixir, "0.16.0", "f8c570a0d33f8039513fbccaf7108c5d750f47d8defd44088371191b76492b0b", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "28b2cbdc13960a46ae9a8858c4bebdec3c9a6d7b4b9e7f4ed1502f8159f338e7"}, 9 | "makeup_erlang": {:hex, :makeup_erlang, "0.1.1", "3fcb7f09eb9d98dc4d208f49cc955a34218fc41ff6b84df7c75b3e6e533cc65f", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "174d0809e98a4ef0b3309256cbf97101c6ec01c4ab0b23e926a9e17df2077cbb"}, 10 | "nimble_csv": {:hex, :nimble_csv, "1.2.0", "4e26385d260c61eba9d4412c71cea34421f296d5353f914afe3f2e71cce97722", [:mix], [], "hexpm", "d0628117fcc2148178b034044c55359b26966c6eaa8e2ce15777be3bbc91b12a"}, 11 | "nimble_parsec": {:hex, :nimble_parsec, "1.2.3", "244836e6e3f1200c7f30cb56733fd808744eca61fd182f731eac4af635cc6d0b", [:mix], [], "hexpm", "c8d789e39b9131acf7b99291e93dae60ab48ef14a7ee9d58c6964f59efb570b0"}, 12 | "stb_image": {:hex, :stb_image, "0.5.2", 
"1751555da3a401c538a34de4bc9bc4d7c888be2013dd34e8ab03ff23b95a485a", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.1", [hex: :nx, repo: "hexpm", optional: true]}], "hexpm", "74c22e158dbc666cb50c64f53b99ae7205ae785ff4cf3f42cd5f7bc6fa80d1d6"}, 13 | } 14 | -------------------------------------------------------------------------------- /test/caltech101_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Caltech101Test do 2 | use ExUnit.Case 3 | 4 | @moduletag timeout: 120_000 5 | 6 | describe "download" do 7 | test "retrieves training set" do 8 | {{_images, {:u, 8}, shapes}, {labels, {:u, 8}, labels_shape}} = 9 | Scidata.Caltech101.download() 10 | 11 | assert length(shapes) == elem(labels_shape, 0) 12 | assert byte_size(labels) == elem(labels_shape, 0) 13 | end 14 | end 15 | end 16 | -------------------------------------------------------------------------------- /test/cifar100_test.exs: -------------------------------------------------------------------------------- 1 | defmodule CIFAR100 do 2 | use ExUnit.Case 3 | 4 | @moduletag timeout: 120_000 5 | 6 | describe "download" do 7 | test "retrieves training set" do 8 | {{_images, {:u, 8}, {n_images, n_channels, n_rows, n_cols}}, {_labels, {:u, 8}, {n_labels, n_classes}}} = 9 | Scidata.CIFAR100.download() 10 | 11 | assert n_images == 50000 12 | assert n_channels == 3 13 | assert n_rows == 32 14 | assert n_cols == 32 15 | assert n_labels == 50000 16 | assert n_classes == 2 17 | end 18 | 19 | test "retrieves test set" do 20 | {{_images, {:u, 8}, {n_images, n_channels, n_rows, n_cols}}, {_labels, {:u, 8}, {n_labels, n_classes}}} = 21 | Scidata.CIFAR100.download_test() 22 | 23 | assert n_images == 10000 24 | assert n_channels == 3 25 | assert n_rows == 32 26 | assert n_cols == 32 27 | assert n_labels == 10000 28 | assert n_classes == 2 29 | end 30 | end 31 | end 32 | 
-------------------------------------------------------------------------------- /test/cifar10_test.exs: -------------------------------------------------------------------------------- 1 | defmodule CIFAR10 do 2 | use ExUnit.Case 3 | 4 | @moduletag timeout: 120_000 5 | 6 | describe "download" do 7 | test "retrieves training set" do 8 | {{_images, {:u, 8}, {n_images, n_channels, n_rows, n_cols}}, {_labels, {:u, 8}, {n_labels}}} = 9 | Scidata.CIFAR10.download() 10 | 11 | assert n_images == 50000 12 | assert n_channels == 3 13 | assert n_rows == 32 14 | assert n_cols == 32 15 | assert n_labels == 50000 16 | end 17 | 18 | test "retrieves test set" do 19 | {{_images, {:u, 8}, {n_images, n_channels, n_rows, n_cols}}, {_labels, {:u, 8}, {n_labels}}} = 20 | Scidata.CIFAR10.download_test() 21 | 22 | assert n_images == 10000 23 | assert n_channels == 3 24 | assert n_rows == 32 25 | assert n_cols == 32 26 | assert n_labels == 10000 27 | end 28 | end 29 | end 30 | -------------------------------------------------------------------------------- /test/fashionmnist_test.exs: -------------------------------------------------------------------------------- 1 | defmodule FashionMNISTTest do 2 | use ExUnit.Case 3 | 4 | @moduletag timeout: 120_000 5 | 6 | describe "download" do 7 | test "retrieves training set" do 8 | {{_images, {:u, 8}, {n_images, 1, n_rows, n_cols}}, {_labels, {:u, 8}, {n_labels}}} = 9 | Scidata.FashionMNIST.download() 10 | 11 | assert n_images == 60000 12 | assert n_rows == 28 13 | assert n_cols == 28 14 | assert n_labels == 60000 15 | end 16 | 17 | test "retrieves test set" do 18 | {{_images, {:u, 8}, {n_images, 1, n_rows, n_cols}}, {_labels, {:u, 8}, {n_labels}}} = 19 | Scidata.FashionMNIST.download_test() 20 | 21 | assert n_images == 10000 22 | assert n_rows == 28 23 | assert n_cols == 28 24 | assert n_labels == 10000 25 | end 26 | end 27 | end 28 | -------------------------------------------------------------------------------- /test/imdb_reviews_test.exs: 
-------------------------------------------------------------------------------- 1 | defmodule IMDBReviewsTest do 2 | use ExUnit.Case 3 | 4 | @moduletag timeout: 120_000 5 | 6 | describe "download" do 7 | test "retrieves training set" do 8 | %{review: train_inputs, sentiment: train_targets} = Scidata.IMDBReviews.download() 9 | 10 | assert length(train_inputs) == 25000 11 | assert length(train_targets) == 25000 12 | 13 | %{review: train_inputs, sentiment: train_targets} = 14 | Scidata.IMDBReviews.download(example_types: [:pos, :neg]) 15 | 16 | assert length(train_inputs) == 25000 17 | assert length(train_targets) == 25000 18 | 19 | %{review: train_inputs, sentiment: train_targets} = 20 | Scidata.IMDBReviews.download(example_types: [:pos, :neg, :unsup]) 21 | 22 | assert length(train_inputs) == 75000 23 | assert length(train_targets) == 75000 24 | end 25 | 26 | test "retrieves test set" do 27 | %{review: test_inputs, sentiment: test_targets} = 28 | Scidata.IMDBReviews.download_test(example_types: [:pos, :neg]) 29 | 30 | assert length(test_inputs) == 25000 31 | assert length(test_targets) == 25000 32 | assert [0, 0, 0, 0, 0] = Enum.take(test_targets, -5) 33 | end 34 | 35 | test "examples are expected" do 36 | clip = fn example -> String.slice(example, 0..20) end 37 | 38 | %{review: reviews, sentiment: targets} = 39 | Scidata.IMDBReviews.download(example_types: [:pos], transform_inputs: clip) 40 | 41 | clipped_reviews = reviews |> Enum.take(10) |> Enum.map(&clip.(&1)) 42 | 43 | assert clipped_reviews == [ 44 | "The story centers aro", 45 | "'The Adventures Of Ba", 46 | "This film and it's se", 47 | "I love this movie lik", 48 | "A hit at the time but", 49 | "Very smart, sometimes", 50 | "With the mixed review", 51 | "This movie really kic", 52 | "I'd always wanted Dav", 53 | "Like I said its a hid" 54 | ] 55 | 56 | assert Enum.take(targets, 10) == [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] 57 | end 58 | end 59 | end 60 | 
-------------------------------------------------------------------------------- /test/iris_test.exs: -------------------------------------------------------------------------------- 1 | defmodule IrisTest do 2 | use ExUnit.Case 3 | 4 | @moduletag timeout: 120_000 5 | 6 | describe "download" do 7 | test "retrieves training set" do 8 | {features, labels} = Scidata.Iris.download() 9 | 10 | assert length(labels) == 150 11 | assert length(features) == length(labels) 12 | 13 | assert labels |> Enum.uniq() |> Enum.sort() == [0, 1, 2] 14 | assert features |> Enum.map(&length(&1)) |> Enum.uniq() |> Enum.sort() == [4] 15 | end 16 | end 17 | end 18 | -------------------------------------------------------------------------------- /test/kuzushiji_mnist_test.exs: -------------------------------------------------------------------------------- 1 | defmodule KuzushijiMNISTTest do 2 | use ExUnit.Case 3 | 4 | @moduletag timeout: 120_000 5 | 6 | describe "download" do 7 | test "retrieves training set" do 8 | {{_images, {:u, 8}, {n_images, 1, n_rows, n_cols}}, {_labels, {:u, 8}, {n_labels}}} = 9 | Scidata.KuzushijiMNIST.download() 10 | 11 | assert n_images == 60000 12 | assert n_rows == 28 13 | assert n_cols == 28 14 | assert n_labels == 60000 15 | end 16 | 17 | test "retrieves test set" do 18 | {{_images, {:u, 8}, {n_images, 1, n_rows, n_cols}}, {_labels, {:u, 8}, {n_labels}}} = 19 | Scidata.KuzushijiMNIST.download_test() 20 | 21 | assert n_images == 10000 22 | assert n_rows == 28 23 | assert n_cols == 28 24 | assert n_labels == 10000 25 | end 26 | end 27 | end 28 | -------------------------------------------------------------------------------- /test/mnist_test.exs: -------------------------------------------------------------------------------- 1 | defmodule MNISTTest do 2 | use ExUnit.Case 3 | 4 | @moduletag timeout: 120_000 5 | 6 | describe "download" do 7 | test "retrieves training set" do 8 | {{_images, {:u, 8}, {n_images, 1, n_rows, n_cols}}, {_labels, {:u, 8}, {n_labels}}} 
= 9 | Scidata.MNIST.download() 10 | 11 | assert n_images == 60000 12 | assert n_rows == 28 13 | assert n_cols == 28 14 | assert n_labels == 60000 15 | end 16 | 17 | test "retrieves test set" do 18 | {{_images, {:u, 8}, {n_images, 1, n_rows, n_cols}}, {_labels, {:u, 8}, {n_labels}}} = 19 | Scidata.MNIST.download_test() 20 | 21 | assert n_images == 10000 22 | assert n_rows == 28 23 | assert n_cols == 28 24 | assert n_labels == 10000 25 | end 26 | end 27 | end 28 | -------------------------------------------------------------------------------- /test/squad_test.exs: -------------------------------------------------------------------------------- 1 | defmodule SquadTest do 2 | use ExUnit.Case 3 | 4 | @moduletag timeout: 120_000 5 | 6 | describe "download/0" do 7 | test "retrieves training set" do 8 | examples = Scidata.Squad.download() 9 | 10 | assert length(examples) == 442 11 | 12 | first_example = hd(examples) 13 | last_example = List.last(examples) 14 | 15 | assert first_example["title"] == "University_of_Notre_Dame" 16 | assert length(first_example["paragraphs"]) == 55 17 | 18 | assert last_example["title"] == "Kathmandu" 19 | assert length(last_example["paragraphs"]) == 58 20 | end 21 | end 22 | 23 | describe "download_test/0" do 24 | test "retrieves test set" do 25 | examples = Scidata.Squad.download_test() 26 | 27 | assert length(examples) == 48 28 | 29 | first_example = hd(examples) 30 | last_example = List.last(examples) 31 | 32 | assert first_example["title"] == "Super_Bowl_50" 33 | assert length(first_example["paragraphs"]) == 54 34 | 35 | assert last_example["title"] == "Force" 36 | assert length(last_example["paragraphs"]) == 44 37 | end 38 | end 39 | 40 | describe "to_columns/1" do 41 | test "returns full map for each dataset" do 42 | train_map = Scidata.Squad.download() |> Scidata.Squad.to_columns() 43 | 44 | assert train_map |> Map.keys() |> Enum.sort() == [ 45 | "answer_start", 46 | "answer_text", 47 | "context", 48 | "id", 49 | "question", 50 | 
"title" 51 | ] 52 | 53 | Enum.each(train_map, fn {_k, entries} -> 54 | assert length(entries) == 87599 55 | end) 56 | end 57 | end 58 | end 59 | -------------------------------------------------------------------------------- /test/test_helper.exs: -------------------------------------------------------------------------------- 1 | ExUnit.start() 2 | -------------------------------------------------------------------------------- /test/wine_test.exs: -------------------------------------------------------------------------------- 1 | defmodule WineTest do 2 | use ExUnit.Case 3 | 4 | @moduletag timeout: 120_000 5 | 6 | describe "download" do 7 | test "retrieves training set" do 8 | {features, labels} = Scidata.Wine.download() 9 | 10 | assert length(labels) == 178 11 | assert length(features) == length(labels) 12 | 13 | assert labels |> Enum.uniq() |> Enum.sort() == [0, 1, 2] 14 | assert features |> Enum.map(&length(&1)) |> Enum.uniq() |> Enum.sort() == [13] 15 | end 16 | end 17 | end 18 | -------------------------------------------------------------------------------- /test/yelp_full_reviews_test.exs: -------------------------------------------------------------------------------- 1 | defmodule YelpFullReviewsTest do 2 | use ExUnit.Case 3 | 4 | @moduletag timeout: 120_000 5 | 6 | describe "download" do 7 | test "retrieves training set" do 8 | %{review: train_inputs, rating: train_targets} = Scidata.YelpFullReviews.download() 9 | 10 | assert length(train_inputs) == 650_000 11 | assert length(train_targets) == 650_000 12 | assert train_targets |> Enum.uniq() |> Enum.sort() == [1, 2, 3, 4, 5] 13 | end 14 | 15 | test "retrieves test set" do 16 | %{review: test_inputs, rating: test_targets} = Scidata.YelpFullReviews.download_test() 17 | 18 | assert length(test_inputs) == 50000 19 | assert length(test_targets) == 50000 20 | assert test_targets |> Enum.uniq() |> Enum.sort() == [1, 2, 3, 4, 5] 21 | end 22 | end 23 | end 24 | 
-------------------------------------------------------------------------------- /test/yelp_polarity_reviews_test.exs: -------------------------------------------------------------------------------- 1 | defmodule YelpPolarityReviewsTest do 2 | use ExUnit.Case 3 | 4 | @moduletag timeout: 120_000 5 | 6 | describe "download" do 7 | test "retrieves training set" do 8 | %{review: train_inputs, sentiment: train_targets} = Scidata.YelpPolarityReviews.download() 9 | 10 | assert length(train_inputs) == 560_000 11 | assert length(train_targets) == 560_000 12 | assert train_targets |> Enum.uniq() |> Enum.sort() == [0, 1] 13 | end 14 | 15 | test "retrieves test set" do 16 | %{review: test_inputs, sentiment: test_targets} = 17 | Scidata.YelpPolarityReviews.download_test() 18 | 19 | assert length(test_inputs) == 38000 20 | assert length(test_targets) == 38000 21 | assert test_targets |> Enum.uniq() |> Enum.sort() == [0, 1] 22 | end 23 | end 24 | end 25 | --------------------------------------------------------------------------------