├── .formatter.exs
├── .gitignore
├── README.md
├── lib
│   ├── cifar10.ex
│   ├── fashionmnist.ex
│   ├── mnist.ex
│   └── utils.ex
├── mix.exs
├── mix.lock
└── test
    ├── axon_datasets_test.exs
    └── test_helper.exs

/.formatter.exs:
--------------------------------------------------------------------------------
# Used by "mix format"
[
  inputs: ["{mix,.formatter}.exs", "{config,lib,test}/**/*.{ex,exs}"]
]
--------------------------------------------------------------------------------

/.gitignore:
--------------------------------------------------------------------------------
# The directory Mix will write compiled artifacts to.
/_build/

# If you run "mix test --cover", coverage assets end up here.
/cover/

# The directory Mix downloads your dependencies sources to.
/deps/

# Where third-party dependencies like ExDoc output generated docs.
/doc/

# Ignore .fetch files in case you like to edit your project deps locally.
/.fetch

# If the VM crashes, it generates a dump, let's ignore it too.
erl_crash.dump

# Also ignore archive artifacts (built via "mix archive.build").
*.ez

# Ignore package tarball (built via "mix hex.build").
axon_datasets-*.tar

# Temporary files for e.g. tests
/tmp
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
# Axon Datasets

Datasets have moved to [Scidata](https://github.com/elixir-nx/scidata)
--------------------------------------------------------------------------------

/lib/cifar10.ex:
--------------------------------------------------------------------------------
defmodule AxonDatasets.CIFAR10 do
  alias AxonDatasets.Utils

  @default_data_path "tmp/cifar10"
  @base_url 'https://www.cs.toronto.edu/~kriz/'
  @dataset_file 'cifar-10-binary.tar.gz'

  # Each CIFAR-10 binary record is 3073 bytes: a 1-byte label followed by a
  # 3072-byte (3 x 32 x 32) image.
  defp parse_images(content) do
    for <<example::size(3073)-binary <- content>>, reduce: {<<>>, <<>>} do
      {images, labels} ->
        <<label::size(8)-bitstring, image::size(3072)-binary>> = example

        {images <> image, labels <> label}
    end
  end

  @doc """
  Downloads the CIFAR10 dataset or fetches it locally.

  ## Options

    * `:data_path` - path where the dataset .gz should be stored locally
    * `:transform_images` - accepts a tuple like
      `{binary_data, tensor_type, data_shape}` which can be used for
      converting the `binary_data` to a tensor with a function like

          fn {labels_binary, type, _shape} ->
            labels_binary
            |> Nx.from_binary(type)
            |> Nx.new_axis(-1)
            |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
            |> Nx.to_batched_list(32)
          end

    * `:transform_labels` - similar to `:transform_images` but applied to
      dataset labels

  ## Examples

      iex> AxonDatasets.CIFAR10.download()
      Fetching cifar-10-binary.tar.gz from https://www.cs.toronto.edu/~kriz/

      {{<<59, 43, 50, 68, 98, 119, 139, 145, 149, 149, 131, 125, 142, 144, 137, 129,
          137, 134, 124, 139, 139, 133, 136, 139, 152, 163, 168, 159, 158, 158, 152,
          148, 16, 0, 18, 51, 88, 120, 128, 127, 126, 116, 106, 101, 105, 113, 109,
          112, ...>>, {:u, 8}, {50000, 3, 32, 32}},
       {<<6, 9, 9, 4, 1, 1, 2, 7, 8, 3, 4, 7, 7, 2, 9, 9, 9, 3, 2, 6, 4, 3, 6, 6, 2,
          6, 3, 5, 4, 0, 0, 9, 1, 3, 4, 0, 3, 7, 3, 3, 5, 2, 2, 7, 1, 1, 1, ...>>,
        {:u, 8}, {50000}}}
  """
  def download(opts \\ []) do
    data_path = opts[:data_path] || @default_data_path
    transform_images = opts[:transform_images] || fn out -> out end
    transform_labels = opts[:transform_labels] || fn out -> out end

    gz = Utils.unzip_cache_or_download(@base_url, @dataset_file, data_path)

    with {:ok, files} <- :erl_tar.extract({:binary, gz}, [:memory, :compressed]) do
      {imgs, labels} =
        files
        |> Enum.filter(fn {fname, _} -> String.match?(List.to_string(fname), ~r/data_batch/) end)
        |> Enum.map(fn {_, content} -> Task.async(fn -> parse_images(content) end) end)
        |> Enum.map(&Task.await(&1, :infinity))
        |> Enum.reduce({<<>>, <<>>}, fn {image, label}, {image_acc, label_acc} ->
          {image_acc <> image, label_acc <> label}
        end)

      {transform_images.({imgs, {:u, 8}, {50000, 3, 32, 32}}),
       transform_labels.({labels, {:u, 8}, {50000}})}
    end
  end
end
--------------------------------------------------------------------------------
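Note: the `@doc` above shows a transform function but not how it is wired into
`download/1`. The snippet below is a minimal usage sketch, not part of the original
module; it assumes `Nx` is available to the caller (it is pinned in `mix.lock`), and
the division by 255 is an illustrative normalization choice, not a library default.

    # Normalize pixel values to [0, 1] by building a float tensor from the raw bytes.
    transform_images = fn {binary, type, shape} ->
      binary
      |> Nx.from_binary(type)
      |> Nx.reshape(shape)
      |> Nx.divide(255)
    end

    {images, labels} = AxonDatasets.CIFAR10.download(transform_images: transform_images)
    # images is a {50000, 3, 32, 32} f32 tensor in [0, 1]; labels stays a raw
    # {binary, type, shape} tuple because no :transform_labels was given.
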
/lib/fashionmnist.ex:
--------------------------------------------------------------------------------
defmodule AxonDatasets.FashionMNIST do
  alias AxonDatasets.Utils

  @default_data_path "tmp/fashionmnist"
  @base_url 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'
  @image_file 'train-images-idx3-ubyte.gz'
  @label_file 'train-labels-idx1-ubyte.gz'

  defp download_images(opts) do
    data_path = opts[:data_path] || @default_data_path
    transform = opts[:transform_images] || fn out -> out end

    # IDX image header: magic number, image count, row count, column count
    # (all big-endian 32-bit integers), followed by the raw pixel bytes.
    <<_::32, n_images::32, n_rows::32, n_cols::32, images::binary>> =
      Utils.unzip_cache_or_download(@base_url, @image_file, data_path)

    transform.({images, {:u, 8}, {n_images, n_rows, n_cols}})
  end

  defp download_labels(opts) do
    data_path = opts[:data_path] || @default_data_path
    transform = opts[:transform_labels] || fn out -> out end

    # IDX label header: magic number and label count, followed by the raw labels.
    <<_::32, n_labels::32, labels::binary>> =
      Utils.unzip_cache_or_download(@base_url, @label_file, data_path)

    transform.({labels, {:u, 8}, {n_labels}})
  end

  @doc """
  Downloads the FashionMNIST dataset or fetches it locally.

  ## Options

    * `:data_path` - path where the dataset .gz should be stored locally
    * `:transform_images` - accepts a tuple like
      `{binary_data, tensor_type, data_shape}` which can be used for
      converting the `binary_data` to a tensor with a function like

          fn {labels_binary, type, _shape} ->
            labels_binary
            |> Nx.from_binary(type)
            |> Nx.new_axis(-1)
            |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
            |> Nx.to_batched_list(32)
          end

    * `:transform_labels` - similar to `:transform_images` but applied to
      dataset labels
  ## Examples

      iex> AxonDatasets.FashionMNIST.download()
      Fetching train-images-idx3-ubyte.gz from http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/

      Fetching train-labels-idx1-ubyte.gz from http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/

      {{<<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...>>,
        {:u, 8}, {60000, 28, 28}},
       {<<9, 0, 0, 3, 0, 2, 7, 2, 5, 5, 0, 9, 5, 5, 7, 9, 1, 0, 6, 4, 3, 1, 4, 8, 4,
          3, 0, 2, 4, 4, 5, 3, 6, 6, 0, 8, 5, 2, 1, 6, 6, 7, 9, 5, 9, 2, 7, ...>>,
        {:u, 8}, {60000}}}

      iex> transform_labels = fn {labels_binary, type, _shape} ->
      iex>   labels_binary
      iex>   |> Nx.from_binary(type)
      iex>   |> Nx.new_axis(-1)
      iex>   |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
      iex>   |> Nx.to_batched_list(32)
      iex> end
      #Function<7.126501267/1 in :erl_eval.expr/5>
      iex> AxonDatasets.FashionMNIST.download(transform_labels: transform_labels)
      Using train-images-idx3-ubyte.gz from tmp/fashionmnist

      Using train-labels-idx1-ubyte.gz from tmp/fashionmnist

      {{<<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...>>,
        {:u, 8}, {60000, 28, 28}},
       #Nx.Tensor<
         u8[60000][10]
         [
           [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
           [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
           [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
           [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
           [1, 0, 0, 0, 0, 0, 0, 0, ...],
           ...
         ]
       >}
  """
  def download(opts \\ []),
    do: {download_images(opts), download_labels(opts)}
end
--------------------------------------------------------------------------------
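/lib/mnist.ex:
--------------------------------------------------------------------------------
defmodule AxonDatasets.MNIST do
  alias AxonDatasets.Utils

  @default_data_path "tmp/mnist"
  @base_url 'https://storage.googleapis.com/cvdf-datasets/mnist/'
  @image_file 'train-images-idx3-ubyte.gz'
  @label_file 'train-labels-idx1-ubyte.gz'

  defp download_images(opts) do
    data_path = opts[:data_path] || @default_data_path
    transform = opts[:transform_images] || fn out -> out end

    <<_::32, n_images::32, n_rows::32, n_cols::32, images::binary>> =
      Utils.unzip_cache_or_download(@base_url, @image_file, data_path)

    transform.({images, {:u, 8}, {n_images, n_rows, n_cols}})
  end

  defp download_labels(opts) do
    data_path = opts[:data_path] || @default_data_path
    transform = opts[:transform_labels] || fn out -> out end

    <<_::32, n_labels::32, labels::binary>> =
      Utils.unzip_cache_or_download(@base_url, @label_file, data_path)

    transform.({labels, {:u, 8}, {n_labels}})
  end

  @doc """
  Downloads the MNIST dataset or fetches it locally.

  ## Options

    * `:data_path` - path where the dataset .gz should be stored locally
    * `:transform_images` - accepts a tuple like
      `{binary_data, tensor_type, data_shape}` which can be used for
      converting the `binary_data` to a tensor with a function like

          fn {labels_binary, type, _shape} ->
            labels_binary
            |> Nx.from_binary(type)
            |> Nx.new_axis(-1)
            |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
            |> Nx.to_batched_list(32)
          end

    * `:transform_labels` - similar to `:transform_images` but applied to
      dataset labels
  """
  def download(opts \\ []),
    do: {download_images(opts), download_labels(opts)}
end
--------------------------------------------------------------------------------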
/lib/utils.ex:
--------------------------------------------------------------------------------
defmodule AxonDatasets.Utils do
  @moduledoc false
  require Logger

  def unzip_cache_or_download(base_url, zip, data_path) do
    path = Path.join(data_path, zip)

    data =
      if File.exists?(path) do
        Logger.debug("Using #{zip} from #{data_path}\n")
        File.read!(path)
      else
        Logger.debug("Fetching #{zip} from #{base_url}\n")
        :inets.start()
        :ssl.start()

        {:ok, {_status, _response, data}} = :httpc.request(base_url ++ zip)
        File.mkdir_p!(data_path)
        File.write!(path, data)

        data
      end

    :zlib.gunzip(data)
  end
end
--------------------------------------------------------------------------------

/mix.exs:
--------------------------------------------------------------------------------
defmodule AxonDatasets.MixProject do
  use Mix.Project

  def project do
    [
      app: :axon_datasets,
      version: "0.1.0",
      elixir: "~> 1.11",
      start_permanent: Mix.env() == :prod,
      deps: deps()
    ]
  end

  def application do
    [
      extra_applications: [:logger, :ssl, :inets]
    ]
  end

  defp deps do
    []
  end
end
--------------------------------------------------------------------------------
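Note: unlike CIFAR10 and FashionMNIST, the MNIST docs stop at the option list. The
sketch below, which is not part of the original source, shows both transforms wired
into `download/1` together; it assumes `Nx` is available, and the normalization and
one-hot encoding mirror the transform shown in the docs rather than any default.

    # Assumption: Nx is a dependency of the consuming project; dividing by 255 and
    # one-hot encoding against 0..9 are illustrative choices.
    transform_images = fn {binary, type, shape} ->
      binary
      |> Nx.from_binary(type)
      |> Nx.reshape(shape)
      |> Nx.divide(255)
    end

    transform_labels = fn {binary, type, _shape} ->
      binary
      |> Nx.from_binary(type)
      |> Nx.new_axis(-1)
      |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
    end

    {images, labels} =
      AxonDatasets.MNIST.download(
        transform_images: transform_images,
        transform_labels: transform_labels
      )
    # images: f32 tensor of shape {60000, 28, 28}; labels: u8 one-hot tensor of
    # shape {60000, 10}.

Also note that `deps/0` in mix.exs is empty in this snapshot even though mix.lock
(below) pins `nx` from git, which the transform examples rely on. A deps entry
matching that lock entry would look roughly like the following; this is an assumption
reconstructed from the lock file, not part of the original mix.exs.

    defp deps do
      [
        # Git coordinates taken from the mix.lock entry below (assumption: this is
        # how the dependency would be declared, not the file's original contents).
        {:nx, git: "https://github.com/elixir-nx/nx.git", branch: "main", sparse: "nx"}
      ]
    end
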
/mix.lock:
--------------------------------------------------------------------------------
%{
  "nx": {:git, "https://github.com/elixir-nx/nx.git", "470f752dc07cc211cedbed97914c73a0d3a701d7", [branch: "main", sparse: "nx"]},
}
--------------------------------------------------------------------------------

/test/axon_datasets_test.exs:
--------------------------------------------------------------------------------
defmodule AxonDatasetsTest do
  use ExUnit.Case
  doctest AxonDatasets

  test "greets the world" do
    assert AxonDatasets.hello() == :world
  end
end
--------------------------------------------------------------------------------

/test/test_helper.exs:
--------------------------------------------------------------------------------
ExUnit.start()
--------------------------------------------------------------------------------