├── test ├── test_helper.exs ├── cuda │ └── memory_test.exs ├── template │ └── helpers_test.exs ├── graph │ ├── pin_test.exs │ └── factory_test.exs ├── template_test.exs ├── env_test.exs ├── compiler │ ├── gpu_node_test.exs │ ├── computation_graph_test.exs │ └── gpu_node_ptx_helpers_test.exs ├── support │ └── cuda_helpers.ex ├── float_16_test.exs ├── graph_test.exs ├── cuda_test.exs └── memory_test.exs ├── .tool-versions ├── lib ├── cuda │ ├── runner │ │ ├── protocols.ex │ │ ├── gpu_node.ex │ │ └── graph.ex │ ├── compiler │ │ ├── protocols.ex │ │ ├── gpu_node.ex │ │ ├── utils.ex │ │ └── context.ex │ ├── app.ex │ ├── graph │ │ ├── computation_graph.ex │ │ ├── gpu_node.ex │ │ ├── protocols.ex │ │ ├── pin.ex │ │ └── node.ex │ ├── template │ │ ├── helpers.ex │ │ └── ptx_helpers.ex │ ├── compiler.ex │ ├── env │ │ └── validation.ex │ ├── worker.ex │ ├── env.ex │ ├── template.ex │ ├── shared.ex │ ├── visualize │ │ └── dot.ex │ ├── float_16.ex │ ├── graph.ex │ └── memory.ex └── cuda.ex ├── c_src ├── .gcc-flags.json ├── runtime_port.h ├── cuda_driver_port.cpp ├── driver_port.h ├── utils.h ├── erlang_port.h ├── commands.h ├── command_info.cpp ├── utils.cpp ├── common.h ├── common.cpp ├── erlang_port.cpp ├── driver.h ├── commands.cpp ├── driver.cpp └── driver_port.cpp ├── .gitignore ├── README.md ├── mix.lock ├── config └── config.exs ├── Makefile ├── mix.exs └── .credo.exs /test/test_helper.exs: -------------------------------------------------------------------------------- 1 | ExUnit.start() 2 | -------------------------------------------------------------------------------- /.tool-versions: -------------------------------------------------------------------------------- 1 | elixir 1.4.4 2 | erlang 19.3 3 | -------------------------------------------------------------------------------- /lib/cuda/runner/protocols.ex: -------------------------------------------------------------------------------- 1 | defprotocol Cuda.Runner do 2 | def load(node, opts \\ []) 3 | def run(node, inputs, opts \\ []) 4 | end 5 | -------------------------------------------------------------------------------- /c_src/.gcc-flags.json: -------------------------------------------------------------------------------- 1 | { 2 | "execPath": "/usr/bin/g++", 3 | "gccDefaultCFlags": "-Wall -fsyntax-only", 4 | "gccDefaultCppFlags": "-Wall -std=c++11 -fsyntax-only", 5 | "gccErrorLimit": 15, 6 | "gccIncludePaths": "/usr/local/cuda-8.0/include", 7 | "gccSuppressWarnings": false 8 | } 9 | -------------------------------------------------------------------------------- /lib/cuda/compiler/protocols.ex: -------------------------------------------------------------------------------- 1 | defprotocol Cuda.Compiler.GPUUnit do 2 | @type source :: {:ptx | :c, String.t} 3 | @spec sources(item :: struct, context :: Cuda.Template.Context.t) :: {:ok, [source]} | {:error, any} 4 | def sources(item, ctx) 5 | end 6 | 7 | defprotocol Cuda.Compiler.Unit do 8 | @spec compile(item :: struct, context :: Cuda.Template.Context.t) :: {:ok, struct} | {:error, any} 9 | def compile(item, ctx) 10 | end 11 | -------------------------------------------------------------------------------- /lib/cuda/app.ex: -------------------------------------------------------------------------------- 1 | defmodule Cuda.App do 2 | @moduledoc false 3 | 4 | use Application 5 | 6 | def start(_type, _args) do 7 | import Supervisor.Spec 8 | 9 | children = [ 10 | worker(Cuda, [], restart: :temporary) 11 | ] 12 | Supervisor.start_link(children, strategy: :simple_one_for_one, name: 
__MODULE__) 13 | end 14 | 15 | def start_driver(opts \\ []) do 16 | Supervisor.start_child(__MODULE__, opts) 17 | end 18 | end 19 | -------------------------------------------------------------------------------- /c_src/runtime_port.h: -------------------------------------------------------------------------------- 1 | #ifndef __RUNTIME_PORT_H__ 2 | #define __RUNTIME_PORT_H__ 3 | 4 | #include "erlang_port.h" 5 | 6 | class RuntimePort: public ErlangPort { 7 | protected: 8 | virtual ETERM *HandleTermFunction(std::string name, ETERM *arg); 9 | virtual ETERM *HandleRawFunction(std::string name, RawData &data, size_t size); 10 | 11 | ETERM *Info(ETERM *arg); 12 | public: 13 | RuntimePort(int device); 14 | ~RuntimePort(); 15 | }; 16 | 17 | #endif // __RUNTIME_PORT_H__ 18 | -------------------------------------------------------------------------------- /test/cuda/memory_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Cuda.TestMemory do 2 | use ExUnit.Case 3 | 4 | describe "shared memory" do 5 | setup do 6 | {:ok, cuda1} = Cuda.start_link() 7 | {:ok, cuda2} = Cuda.start_link() 8 | [cuda1: cuda1, cuda2: cuda2] 9 | end 10 | 11 | test "shares memory", ctx do 12 | {:ok, a} = Cuda.memory_load(ctx[:cuda1], <<1, 2, 3, 4>>) 13 | {:ok, s} = Cuda.memory_share(ctx[:cuda1], a) 14 | {:ok, b} = Cuda.memory_load(ctx[:cuda2], s) 15 | {:ok, x} = Cuda.memory_read(ctx[:cuda2], b) 16 | assert x == <<1, 2, 3, 4>> 17 | end 18 | end 19 | end 20 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # The directory Mix will write compiled artifacts to. 2 | /_build 3 | 4 | # If you run "mix test --cover", coverage assets end up here. 5 | /cover 6 | 7 | # The directory Mix downloads your dependencies sources to. 8 | /deps 9 | 10 | # Where 3rd-party dependencies like ExDoc output generated docs. 11 | /doc 12 | 13 | # Ignore .fetch files in case you like to edit your project deps locally. 14 | /.fetch 15 | 16 | # If the VM crashes, it generates a dump, let's ignore it too. 17 | erl_crash.dump 18 | 19 | # Also ignore archive artifacts (built via "mix archive.build"). 20 | *.ez 21 | 22 | /c_src/*.gch 23 | /priv/* 24 | /tmp 25 | -------------------------------------------------------------------------------- /c_src/cuda_driver_port.cpp: -------------------------------------------------------------------------------- 1 | #include "common.h" 2 | #include "driver_port.h" 3 | 4 | void CleanupCuda() { 5 | DEBUG("Main function exited. Resetting CUDA device"); 6 | cudaDeviceReset(); 7 | } 8 | 9 | int main(int argc, char *argv[]) { 10 | std::atexit(CleanupCuda); 11 | 12 | int device = -1; 13 | if (argc > 1) { // argv[0] is the executable path; the device number is the first real argument 14 | std::string deviceStr(argv[1]); 15 | try { 16 | device = std::stoi(deviceStr); 17 | } catch (const std::invalid_argument &e) {} 18 | } 19 | 20 | try { 21 | DriverPort port(device); 22 | port.Loop(); 23 | } catch(...)
{} // exit normally on any errors - just die 24 | return 0; 25 | } 26 | -------------------------------------------------------------------------------- /lib/cuda/graph/computation_graph.ex: -------------------------------------------------------------------------------- 1 | defmodule Cuda.Graph.ComputationGraph do 2 | @moduledoc """ 3 | Implements a graph for internal node calculations 4 | """ 5 | 6 | use Cuda.Graph 7 | 8 | alias Cuda.Graph.Processing 9 | 10 | def __type__(_assigns), do: :computation_graph 11 | def __pins__(_assigns), do: [] 12 | def __graph__(graph), do: graph 13 | 14 | def __run__(%{id: gid} = graph) do 15 | with {:ok, nodes} <- Processing.topology_sort(graph) do 16 | nodes = nodes 17 | |> Enum.map(fn {node, _pin} -> node end) 18 | |> Enum.reject(& &1 == gid) 19 | graph = %{graph | nodes: nodes} 20 | graph 21 | else 22 | _ -> nil 23 | end 24 | end 25 | end 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Cuda 2 | 3 | NVIDIA GPU CUDA library bindings for Erlang and Elixir. 4 | 5 | ## Installation 6 | 7 | ```elixir 8 | def deps do 9 | [{:cuda, "~> 0.1.0"}] 10 | end 11 | ``` 12 | 13 | ## Prerequisites 14 | 15 | At least one of your video cards must not be in exclusive or prohibited compute 16 | mode. To check your video card's compute mode, run: 17 | 18 | ```sh 19 | nvidia-smi --format=csv --query-gpu="compute_mode" 20 | ``` 21 | 22 | To change the compute mode back to default, run: 23 | 24 | ```sh 25 | sudo nvidia-smi -c 0 26 | ``` 27 | 28 | ## Debugging 29 | 30 | To get debug messages from the C++ CUDA binding, compile the library with the 31 | `GPU_DEBUG=1` environment variable set, like this: 32 | 33 | ```sh 34 | mix clean 35 | GPU_DEBUG=1 mix compile 36 | ``` 37 | -------------------------------------------------------------------------------- /mix.lock: -------------------------------------------------------------------------------- 1 | %{"bunt": {:hex, :bunt, "0.2.0", "951c6e801e8b1d2cbe58ebbd3e616a869061ddadcc4863d0a2182541acae9a38", [:mix], [], "hexpm"}, 2 | "credo": {:hex, :credo, "0.7.2", "850463f126c09227994967fdcf8b8ad7684ab220f7727c00bcafc0ac37bd3660", [:mix], [{:bunt, "~> 0.2.0", [hex: :bunt, repo: "hexpm", optional: false]}], "hexpm"}, 3 | "earmark": {:hex, :earmark, "1.2.0", "bf1ce17aea43ab62f6943b97bd6e3dc032ce45d4f787504e3adf738e54b42f3a", [:mix], [], "hexpm"}, 4 | "ex_doc": {:hex, :ex_doc, "0.15.0", "e73333785eef3488cf9144a6e847d3d647e67d02bd6fdac500687854dd5c599f", [:mix], [{:earmark, "~> 1.1", [hex: :earmark, repo: "hexpm", optional: false]}], "hexpm"}, 5 | "uuid": {:hex, :uuid, "1.1.7", "007afd58273bc0bc7f849c3bdc763e2f8124e83b957e515368c498b641f7ab69", [:mix], [], "hexpm"}} 6 | -------------------------------------------------------------------------------- /c_src/driver_port.h: -------------------------------------------------------------------------------- 1 | #ifndef __DRIVER_PORT_H__ 2 | #define __DRIVER_PORT_H__ 3 | 4 | #include "erlang_port.h" 5 | 6 | class DriverPort: public ErlangPort { 7 | private: 8 | Driver *driver = NULL; 9 | 10 | template <typename T> T Unpack(ETERM *term); 11 | protected: 12 | virtual ETERM *HandleTermFunction(std::string name, ETERM *arg); 13 | virtual ETERM *HandleRawFunction(std::string name, RawData &data, size_t size); 14 | 15 | std::shared_ptr UnpackRunArguments(ETERM *term); 16 | 17 | ETERM *Compile(ETERM *arg); 18 | ETERM *MemoryRead(ETERM *arg); 19 | ETERM *MemoryUnload(ETERM *arg); 20 | ETERM *ModuleLoad(ETERM *arg); 21 | 
ETERM *MemoryShare(ETERM *arg); 22 | ETERM *Run(ETERM *arg); 23 | ETERM *MemoryLoad(RawData &data, size_t size); 24 | ETERM *MemoryLoad(ETERM *arg); 25 | ETERM *Stream(ETERM *arg); 26 | ETERM *DeviceInfo(); 27 | public: 28 | DriverPort(int device); 29 | ~DriverPort(); 30 | }; 31 | 32 | #endif // __DRIVER_PORT_H__ 33 | -------------------------------------------------------------------------------- /lib/cuda/template/helpers.ex: -------------------------------------------------------------------------------- 1 | defmodule Cuda.Template.Helpers do 2 | @moduledoc """ 3 | Represents a set of helper functions for EEx templates 4 | """ 5 | 6 | alias Cuda.Compiler.Context 7 | 8 | @doc """ 9 | Returns a Cuda environment variable value 10 | """ 11 | @spec env(context :: Context.t, variable_name :: String.t | atom | number) :: any 12 | def env(ctx, var_name) do 13 | Map.get(ctx.env, var_name) 14 | end 15 | 16 | @doc """ 17 | Returns a context variable value 18 | """ 19 | @spec var(context :: Context.t, variable_name :: String.t | atom | number) :: any 20 | def var(ctx, var_name) do 21 | with nil <- Context.find_assign(ctx, [:vars, var_name]) do 22 | get_in(ctx.assigns, [:vars, var_name]) 23 | end# |> IO.inspect(label: "VAR #{inspect var_name}") 24 | end 25 | 26 | defmacro var(var_name) do 27 | quote do 28 | var(var!(ctx), unquote(var_name)) 29 | end 30 | end 31 | 32 | defmacro env(var_name) do 33 | quote do 34 | env(var!(ctx), unquote(var_name)) 35 | end 36 | end 37 | end 38 | -------------------------------------------------------------------------------- /lib/cuda/compiler.ex: -------------------------------------------------------------------------------- 1 | defmodule Cuda.Compiler do 2 | def compile(sources, opts \\ []) do 3 | tmp = Path.join(System.tmp_dir!, "CudaCompiler-#{UUID.uuid1()}") 4 | File.mkdir_p!(tmp) 5 | nvcc = Keyword.get(opts, :nvcc, "/usr/local/cuda/bin/nvcc") 6 | cubin = Path.join(tmp, "#{UUID.uuid1}.cubin") 7 | files = sources 8 | |> Enum.reduce([], fn 9 | {:ptx, src}, acc -> 10 | id = UUID.uuid1() 11 | file = Path.join(tmp, "#{id}.ptx") 12 | :ok = File.write(file, src) 13 | [file | acc] 14 | {:c, src}, acc -> 15 | file = Path.join(tmp, "#{UUID.uuid1()}.cu") 16 | :ok = File.write(file, src) 17 | [file | acc] 18 | _, acc -> 19 | acc 20 | end) 21 | args = ~w(-dlink --cubin -gencode arch=compute_30,code=sm_30 -o #{cubin}) ++ files 22 | result = with {_, 0} <- System.cmd(nvcc, args, opts) do 23 | File.read(cubin) 24 | end 25 | File.rm_rf!(tmp) 26 | result 27 | end 28 | end 29 | -------------------------------------------------------------------------------- /test/template/helpers_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Cuda.Template.HelpersTest do 2 | use ExUnit.Case 3 | alias Cuda.Compiler.Context 4 | #import Kernel, except: ["@": 2] 5 | import Cuda.Template.Helpers, except: ["@": 2] 6 | 7 | describe "env/2" do 8 | test "Get environment variable from context" do 9 | ctx = %Context{env: %Cuda.Env{int_size: 16}} 10 | assert 16 == env(ctx, :int_size) 11 | end 12 | end 13 | 14 | describe "var/2" do 15 | test "gets variable from context" do 16 | ctx = %Context{assigns: %{vars: %{var: 16}}} 17 | assert 16 == var(ctx, :var) 18 | end 19 | 20 | test "gets variable from context path" do 21 | root = %Cuda.Graph{id: :root, assigns: %{vars: %{c: 30}}, nodes: [ 22 | %Cuda.Graph{id: :b, assigns: %{vars: %{b: 20}}, nodes: [ 23 | %Cuda.Graph.Node{id: :a, assigns: %{vars: %{a: 10}}} 24 | ]} 25 | ]} 26 | ctx = %Context{root: root, path: 
[:a, :b]} 27 | assert 10 = var(ctx, :a) 28 | assert 20 = var(ctx, :b) 29 | assert 30 = var(ctx, :c) 30 | end 31 | end 32 | end 33 | -------------------------------------------------------------------------------- /test/graph/pin_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Cuda.Graph.PinTest do 2 | use ExUnit.Case 3 | alias Cuda.Graph.Pin 4 | import Pin 5 | 6 | describe "data_size/1" do 7 | test "accepts simple types" do 8 | assert data_size(%Pin{data_type: :i8}) == 1 9 | assert data_size(%Pin{data_type: :i16}) == 2 10 | assert data_size(%Pin{data_type: :i32}) == 4 11 | assert data_size(%Pin{data_type: :i64}) == 8 12 | assert data_size(%Pin{data_type: :u8}) == 1 13 | assert data_size(%Pin{data_type: :u16}) == 2 14 | assert data_size(%Pin{data_type: :u32}) == 4 15 | assert data_size(%Pin{data_type: :u64}) == 8 16 | assert data_size(%Pin{data_type: :f16}) == 2 17 | assert data_size(%Pin{data_type: :f32}) == 4 18 | assert data_size(%Pin{data_type: :f64}) == 8 19 | end 20 | 21 | test "accepts unknown types with size qualifier" do 22 | assert data_size(%Pin{data_type: :test128}) == 16 23 | end 24 | 25 | test "accepts tuples" do 26 | assert data_size(%Pin{data_type: {:i16, :i16}}) == 4 27 | assert data_size(%Pin{data_type: {:f16, :i32, :i32}}) == 32 28 | assert data_size(%Pin{data_type: {:i16, {:i32, :i32}, :i16}}) == 64 29 | end 30 | end 31 | end 32 | -------------------------------------------------------------------------------- /config/config.exs: -------------------------------------------------------------------------------- 1 | # This file is responsible for configuring your application 2 | # and its dependencies with the aid of the Mix.Config module. 3 | use Mix.Config 4 | 5 | # This configuration is loaded before any dependency and is restricted 6 | # to this project. If another project depends on this project, this 7 | # file won't be loaded nor affect the parent project. For this reason, 8 | # if you want to provide default values for your application for 9 | # 3rd-party users, it should be done in your "mix.exs" file. 10 | 11 | # You can configure your application as: 12 | # 13 | # config :cuda, key: :value 14 | # 15 | # And access this configuration in your application as: 16 | # 17 | # Application.get_env(:cuda, :key) 18 | # 19 | # Or configure a 3rd-party app: 20 | # 21 | # config :logger, level: :info 22 | # 23 | 24 | config :cuda, :values_for_cuda_testing, 25 | float_size: 16, 26 | optimize: :memory 27 | 28 | 29 | # It is also possible to import configuration files, relative to this 30 | # directory. For example, you can emulate configuration per environment 31 | # by uncommenting the line below and defining dev.exs, test.exs and such. 32 | # Configuration from the imported file will override the ones defined 33 | # here (which is why it is important to import them last). 34 | # 35 | # import_config "#{Mix.env}.exs" 36 | -------------------------------------------------------------------------------- /test/template_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Cuda.TemplateTest do 2 | use ExUnit.Case 3 | import Cuda.Test.CudaHelpers 4 | 5 | alias Cuda.Compiler.Context 6 | import Cuda.Template 7 | 8 | describe "ptx_eval/2" do 9 | test "calls String.upcase/1 with a parameter stored in a context variable" do 10 | template = ~s[<%= upcase(var(:text)) %>] 11 | ctx = %Context{env: env(), assigns: %{vars: %{text: "Hello, EEx!"}}} 12 | assert "HELLO, EEX!" 
== ptx_eval(template, [context: ctx, ptx_helpers: [String]]) 13 | end 14 | 15 | test "Add 10 to environment variable int_size" do 16 | template = ~s[<%= env(:int_size) + var(:number) %>] 17 | ctx = %Context{env: env(int_size: 10), assigns: %{vars: %{number: 10}}} 18 | assert "20" == ptx_eval(template, [context: ctx]) 19 | end 20 | end 21 | 22 | describe "c_eval/2" do 23 | test "calls String.upcase/1 with a parameter stored in a context variable" do 24 | template = ~s[<%= upcase(var(ctx, :text)) %>] 25 | ctx = %Context{env: env(), assigns: %{vars: %{text: "Hello, EEx!"}}} 26 | assert "HELLO, EEX!" == c_eval(template, [context: ctx, c_helpers: [String]]) 27 | end 28 | 29 | test "Add 10 to environment variable int_size" do 30 | template = ~s[<%= env(ctx, :int_size) + var(ctx, :number) %>] 31 | ctx = %Context{env: env(int_size: 10), assigns: %{vars: %{number: 10}}} 32 | assert "20" == c_eval(template, [context: ctx]) 33 | end 34 | end 35 | end 36 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | MIX = mix 2 | 3 | SOURCES ?= 4 | LIBS ?= 5 | CXXFLAGS ?= 6 | GPU_DEBUG ?= "0" 7 | 8 | SOURCES += $(wildcard c_src/*.cpp) 9 | 10 | # common c++ compiler flags 11 | CXXFLAGS += -g -O3 -ansi -std=c++11 -pedantic -Wall -Wextra -Wno-long-long -DGPU_DEBUG=$(GPU_DEBUG) 12 | # LIBS = 13 | 14 | # os specific flags 15 | ifneq ($(OS), Windows_NT) 16 | CXXFLAGS += -fPIC 17 | 18 | ifeq ($(shell uname), Darwin) 19 | LDFLAGS += -dynamiclib -undefined dynamic_lookup 20 | endif 21 | endif 22 | 23 | # erl_interface library 24 | EI_INCL = $(shell erl -eval 'io:format("~s", [code:lib_dir(erl_interface, include)])' -s init stop -noshell) 25 | EI_LIBS = $(shell erl -eval 'io:format("~s", [code:lib_dir(erl_interface, lib)])' -s init stop -noshell) 26 | CXXFLAGS += -I$(EI_INCL) 27 | LIBS += -L$(EI_LIBS) -lerl_interface -lei -lpthread 28 | 29 | # CUDA library 30 | CUDA ?= "cuda" 31 | CXXFLAGS += $(shell pkg-config --cflags $(CUDA)) 32 | LIBS += $(shell pkg-config --libs-only-L $(CUDA)) -lcudart -lcuda 33 | 34 | .PHONY: all port clean 35 | 36 | all: port 37 | 38 | port: 39 | $(MIX) compile 40 | 41 | priv: 42 | mkdir -p priv 43 | 44 | priv/cuda_driver_port: priv $(SOURCES) 45 | $(CXX) $(CXXFLAGS) $(LDFLAGS) -o $@ $(SOURCES) $(LIBS) 46 | 47 | priv/cuda_driver_port.exe: priv $(SOURCES) 48 | $(CXX) $(CXXFLAGS) $(LDFLAGS) -o $@ $(SOURCES) $(LIBS) 49 | 50 | priv/cuda_driver_port_test: priv #$(SOURCES) 51 | @echo $(SOURCES) 52 | 53 | clean: 54 | $(RM) -f priv/cuda_driver_port 55 | -------------------------------------------------------------------------------- /c_src/utils.h: -------------------------------------------------------------------------------- 1 | #ifndef __UTILS_H__ 2 | #define __UTILS_H__ 3 | 4 | inline int SMVer2Cores(int major, int minor) { 5 | // Defines for GPU Architecture types (using the SM version to determine the # of cores per SM) 6 | typedef struct { 7 | int SM; // 0xMm (hexadecimal notation), M = SM Major version, and m = SM minor version 8 | int Cores; 9 | } SM2Cores; 10 | 11 | SM2Cores coresPerSM[] = { 12 | {0x20, 32 }, // Fermi Generation (SM 2.0) GF100 class 13 | {0x21, 48 }, // Fermi Generation (SM 2.1) GF10x class 14 | {0x30, 192}, // Kepler Generation (SM 3.0) GK10x class 15 | {0x32, 192}, // Kepler Generation (SM 3.2) GK10x class 16 | {0x35, 192}, // Kepler Generation (SM 3.5) GK11x class 17 | {0x37, 192}, // Kepler Generation (SM 3.7) GK21x class 18 | {0x50, 128}, // Maxwell 
Generation (SM 5.0) GM10x class 19 | {0x52, 128}, // Maxwell Generation (SM 5.2) GM20x class 20 | {0x53, 128}, // Maxwell Generation (SM 5.3) GM20x class 21 | {0x60, 64 }, // Pascal Generation (SM 6.0) GP100 class 22 | {0x61, 128}, // Pascal Generation (SM 6.1) GP10x class 23 | {0x62, 128}, // Pascal Generation (SM 6.2) GP10x class 24 | { -1, -1 } 25 | }; 26 | 27 | int idx = 0; 28 | for(idx = 0; coresPerSM[idx].SM != -1; idx++) { 29 | if (coresPerSM[idx].SM == ((major << 4) + minor)) { 30 | return coresPerSM[idx].Cores; 31 | } 32 | } 33 | return coresPerSM[idx - 1].Cores; 34 | } 35 | 36 | int BestDevice(); 37 | 38 | #endif // __UTILS_H__ 39 | -------------------------------------------------------------------------------- /lib/cuda/env/validation.ex: -------------------------------------------------------------------------------- 1 | defmodule Cuda.Env.Validation do 2 | @moduledoc """ 3 | Represents a module for validating environment values 4 | """ 5 | 6 | @optimize_values ~w(memory speed adaptive none)a 7 | @int_size_values [8, 16, 32] 8 | @float_size_values [16, 32] 9 | 10 | @doc """ 11 | Validates environment variables 12 | """ 13 | @spec validate(atom, any) :: {:ok, any} | Cuda.error_tuple 14 | def validate(:optimize, value) do 15 | if Enum.member?(@optimize_values, value) do 16 | {:ok, value} 17 | else 18 | values = @optimize_values |> Enum.map(&Atom.to_string/1) |> Enum.join(", ") 19 | validate_error("wrong optimize value, permitted values: #{values}") 20 | end 21 | end 22 | def validate(:int_size, value) do 23 | if Enum.member?(@int_size_values, value) do 24 | {:ok, value} 25 | else 26 | values = @int_size_values |> Enum.map(&Integer.to_string/1) |> Enum.join(", ") 27 | validate_error("wrong int_size value, permitted values: #{values}") 28 | end 29 | end 30 | def validate(:float_size, value) do 31 | if Enum.member?(@float_size_values, value) do 32 | {:ok, value} 33 | else 34 | values = @float_size_values |> Enum.map(&Integer.to_string/1) |> Enum.join(", ") 35 | validate_error("wrong float_size value, permitted values: #{values}") 36 | end 37 | end 38 | def validate(_, value), do: {:ok, value} 39 | 40 | defp validate_error(error_message), do: {:error, error_message} 41 | end 42 | -------------------------------------------------------------------------------- /c_src/erlang_port.h: -------------------------------------------------------------------------------- 1 | #ifndef __ERLANG_PORT_H__ 2 | #define __ERLANG_PORT_H__ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "common.h" 11 | #include "driver.h" 12 | 13 | #define PORTIN_FILENO 3 14 | #define PORTOUT_FILENO 4 15 | 16 | #define MEMORY_LOAD 1 17 | 18 | typedef std::shared_ptr RawData; 19 | 20 | class ErlangPort { 21 | private: 22 | std::istream input; 23 | std::ostream output; 24 | ETERM *tuple = NULL; 25 | ETERM *funcAtom = NULL; 26 | ETERM *arg = NULL; 27 | ETERM *result = NULL; 28 | 29 | uint32_t ReadPacketLength(); 30 | uint8_t ReadPacketType(); 31 | uint8_t ReadRawFunc(); 32 | ETERM *ReadTermPacket(uint32_t len); 33 | void WritePacketLength(uint32_t len); 34 | 35 | protected: 36 | virtual ETERM *HandleTermFunction(std::string name, ETERM *arg) = 0; 37 | virtual ETERM *HandleRawFunction(std::string name, RawData &data, size_t size) = 0; 38 | 39 | public: 40 | ErlangPort(); 41 | ~ErlangPort(); 42 | void WriteTermPacket(ETERM *packet); 43 | void WriteRawPacket(void *data, size_t size); 44 | void Loop(); 45 | }; 46 | 47 | // API functions 48 | ETERM *Info(ErlangPort *port, ETERM 
*arg); 49 | ETERM *Compile(ErlangPort *port, ETERM *arg); 50 | ETERM *ModuleLoad(ErlangPort *port, ETERM *arg); 51 | ETERM *MemoryRead(ErlangPort *port, ETERM *arg); 52 | ETERM *MemoryUnload(ErlangPort *port, ETERM *arg); 53 | ETERM *Run(ErlangPort *port, ETERM *arg); 54 | 55 | // raw API function 56 | ETERM *MemoryLoad(ErlangPort *port, std::shared_ptr &data, size_t size); 57 | 58 | #endif // __ERLANG_PORT_H__ 59 | -------------------------------------------------------------------------------- /test/graph/factory_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Cuda.Compiler.FactoryTest do 2 | use ExUnit.Case 3 | 4 | alias Cuda.{Graph, Graph.Factory, Graph.Node, Graph.NodeProto} 5 | 6 | def pins() do 7 | [Node.input(:i, :i16, :inputs), Node.output(:o, :i16, :outputs)] 8 | end 9 | 10 | defmodule TestNode do 11 | use Node 12 | def __assigns__(opts, _env), do: %{options: opts} 13 | def __pins__(assigns), do: assigns.options[:pins] 14 | def __type__(_assigns), do: :gpu 15 | end 16 | 17 | defmodule TestGraph do 18 | use Graph 19 | def __pins__(_opts) do 20 | [%{input(:i, nil) | alias: {:group, :inputs}}, 21 | %{output(:o, nil) | alias: {:group, :outputs}}] 22 | end 23 | def __graph__(graph) do 24 | child_pins = [input(:i, :i16, :inputs), output(:o, :i16, :outputs)] 25 | graph 26 | |> chain(:n1, TestNode, [pins: child_pins]) 27 | |> chain(:n2, TestNode, [pins: child_pins]) 28 | |> close() 29 | end 30 | end 31 | 32 | describe "Factory.new/5" do 33 | test "creates node" do 34 | n = Factory.new(%Node{}, :n1, TestNode, pins: pins()) 35 | assert n.__struct__() == Node 36 | assert n.id == :n1 37 | assert n.module == TestNode 38 | end 39 | 40 | test "creates graph" do 41 | g = Factory.new(%Graph{}, :g, TestGraph) 42 | assert g.__struct__() == Graph 43 | assert g.id == :g 44 | assert g.module == TestGraph 45 | end 46 | 47 | test "substitutes aliased types in pins" do 48 | g = Factory.new(%Graph{}, :g, TestGraph) 49 | assert NodeProto.pin(g, :i).data_type == %{n1: %{i: :i16}, n2: %{i: :i16}} 50 | end 51 | end 52 | end 53 | -------------------------------------------------------------------------------- /lib/cuda/graph/gpu_node.ex: -------------------------------------------------------------------------------- 1 | defmodule Cuda.Graph.GPUNode do 2 | @moduledoc """ 3 | Represents a graph node that will be executed on the GPU. 
4 | 5 | ``` 6 | defmodule MyNode do 7 | use Cuda.Graph.GPUNode 8 | 9 | def __pins__(_assigns) do 10 | [input(:in), output(:out)] 11 | end 12 | 13 | def __ptx__(_assigns) do 14 | \"\"\" 15 | some ptx code 16 | \"\"\" 17 | end 18 | end 19 | ``` 20 | """ 21 | 22 | alias Cuda.Graph 23 | alias Cuda.Node 24 | alias Cuda.Graph.Pin 25 | alias Cuda.Graph.NodeProto 26 | 27 | require Cuda 28 | 29 | @type t :: %__MODULE__{ 30 | id: Graph.id, 31 | module: module, 32 | type: Node.type, 33 | pins: [Pin.t], 34 | assigns: map 35 | } 36 | @type source :: String.t | [String.t] | nil 37 | 38 | @callback __ptx__(node :: struct) :: source 39 | @callback __c__(node :: struct) :: source 40 | @callback __batch__(node :: struct) :: [any] 41 | 42 | @derive [NodeProto] 43 | defstruct [:id, :module, :type, pins: [], assigns: %{}] 44 | 45 | defmacro __using__(_opts) do 46 | quote do 47 | use Cuda.Graph.Node 48 | @behaviour unquote(__MODULE__) 49 | def __ptx__(_node), do: [] 50 | def __c__(_node), do: [] 51 | def __batch__(_node), do: [] 52 | def __proto__(), do: unquote(__MODULE__) 53 | def __type__(_assigns), do: :gpu 54 | defoverridable __batch__: 1, __c__: 1, __ptx__: 1 55 | end 56 | end 57 | end 58 | 59 | defimpl Cuda.Graph.Factory, for: Cuda.Graph.GPUNode do 60 | alias Cuda.Graph.Node 61 | alias Cuda.Graph.Factory 62 | 63 | @doc """ 64 | Creates a new gpu node 65 | """ 66 | def new(_, id, module, opts, env) do 67 | node = %Node{} 68 | |> Factory.new(id, module, opts, env) 69 | |> Map.from_struct 70 | module 71 | |> Node.proto() 72 | |> struct(node) 73 | end 74 | end 75 | -------------------------------------------------------------------------------- /test/env_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Cuda.EnvTest do 2 | use ExUnit.Case 3 | import Cuda.Env 4 | import Cuda.Env.Validation 5 | 6 | describe "create" do 7 | test "env map with default values" do 8 | default = Map.merge(%Cuda.Env{}, get_default()) 9 | 10 | assert {:ok, default} == create() 11 | end 12 | end 13 | 14 | describe "load" do 15 | setup do 16 | tmp_var = System.get_env("CUDA_ENV") 17 | System.put_env("CUDA_ENV", "values_for_cuda_testing") 18 | 19 | on_exit fn -> 20 | if tmp_var == nil do 21 | System.delete_env("CUDA_ENV") 22 | else 23 | System.put_env("CUDA_ENV", tmp_var) 24 | end 25 | end 26 | end 27 | 28 | test "env map with test config values" do 29 | answer = Map.merge(%Cuda.Env{}, get_default()) 30 | answer = :cuda 31 | |> Application.get_env(:values_for_cuda_testing) 32 | |> Enum.reduce(answer, fn ({key, val}, acc) -> Map.put(acc, key, val) end) 33 | 34 | {:ok, env} = load() 35 | 36 | assert answer == env 37 | end 38 | end 39 | 40 | describe "merge/2" do 41 | test "env map merges with test keywordlist" do 42 | kw = Application.get_env(:cuda, :values_for_cuda_testing) 43 | initial = Map.merge(%Cuda.Env{}, get_default()) 44 | answer = Enum.reduce(kw, initial, fn ({key, val}, acc) -> Map.put(acc, key, val) end) 45 | 46 | assert {:ok, answer} == merge(initial, kw) 47 | end 48 | 49 | test "get an error when try to load not permitted value" do 50 | assert {:error, _} = merge(%Cuda.Env{}, [float_size: 0]) 51 | end 52 | end 53 | 54 | describe "validate/2" do 55 | test "get an error when try to pass wrong value" do 56 | assert {:error, _} = validate(:float_size, 0) 57 | assert {:error, _} = validate(:int_size, 0) 58 | assert {:error, _} = validate(:optimize, 0) 59 | end 60 | end 61 | end 62 | -------------------------------------------------------------------------------- 
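Before moving on to the C++ command layer, here is a minimal end-to-end sketch of how the Elixir pieces above fit together: a GPU node defined with the `Cuda.Graph.GPUNode` callbacks and driven through `Cuda.Worker`. `MyNode`, its PTX body and the input value are hypothetical placeholders; the `__pins__/1` and `__ptx__/1` callbacks, `Factory.new/5` and the worker API are the ones defined in this repository, and a working CUDA device is assumed.

```elixir
defmodule MyNode do
  use Cuda.Graph.GPUNode

  # one 16-bit integer input pin and one 32-bit integer output pin
  def __pins__(_assigns), do: [input(:i, :i16), output(:o, :i32)]

  # PTX source is an EEx template; the var/env helpers are available via ctx
  def __ptx__(_assigns), do: "... PTX code here ..."
end

# build the node (mirrors test/compiler/gpu_node_test.exs) and run it
node = Cuda.Graph.Factory.new(%Cuda.Graph.GPUNode{}, :node, MyNode, [], %Cuda.Env{})
{:ok, worker} = Cuda.Worker.start_link(graph: node)
{:ok, output} = Cuda.Worker.run(worker, %{i: 42})
```

The worker compiles the node with `Cuda.Compiler.Unit.compile/2`, loads the resulting cubin via `Cuda.Runner.load/2`, and each `run/2` call packs the inputs, streams the batches and unpacks the outputs, as shown in lib/cuda/worker.ex and lib/cuda/runner/gpu_node.ex below.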
/c_src/commands.h: -------------------------------------------------------------------------------- 1 | #ifndef __COMMANDS_H__ 2 | #define __COMMANDS_H__ 3 | 4 | #include <map> 5 | #include <vector> 6 | 7 | #include "common.h" 8 | #include "driver.h" 9 | 10 | namespace Commands { 11 | 12 | class Events { 13 | public: 14 | Events() { DEBUG("Events created"); } 15 | ~Events(); 16 | CUevent Get(std::string name); 17 | private: 18 | std::map<std::string, CUevent> events; 19 | }; 20 | 21 | struct Context { 22 | std::string id; 23 | CUmodule module; 24 | CUstream stream; 25 | Events *events = NULL; 26 | std::string finishEvent; 27 | }; 28 | 29 | class Command { 30 | public: 31 | static Command *Create(Driver *driver, ETERM *batch); 32 | Command(Driver *driver): driver(driver) {}; 33 | virtual void Run(Context &ctx) = 0; 34 | protected: 35 | Driver *driver; 36 | private: 37 | Command(); 38 | }; 39 | 40 | class Batch : public Command { 41 | public: 42 | Batch(Driver *driver, ETERM *args); 43 | void Run(Context &ctx); 44 | private: 45 | std::vector<Command *> commands; 46 | }; 47 | 48 | class BatchList : public Command { 49 | public: 50 | BatchList(Driver *driver, ETERM *args); 51 | void Run(Context &ctx); 52 | private: 53 | std::vector<Batch *> batches; 54 | }; 55 | 56 | class RunCommand: public Command { 57 | public: 58 | RunCommand(Driver *driver, ETERM *args); 59 | void Run(Context &ctx); 60 | private: 61 | std::string kernel; 62 | unsigned int bx; 63 | unsigned int by; 64 | unsigned int bz; 65 | unsigned int gx; 66 | unsigned int gy; 67 | unsigned int gz; 68 | std::shared_ptr arguments; 69 | }; 70 | 71 | class EventCommand: public Command { 72 | public: 73 | EventCommand(Driver *driver, ETERM *args); 74 | virtual void Run(Context &ctx); 75 | private: 76 | std::string name; 77 | }; 78 | 79 | class WaitCommand: public Command { 80 | public: 81 | WaitCommand(Driver *driver, ETERM *args); 82 | virtual void Run(Context &ctx); 83 | private: 84 | std::string name; 85 | }; 86 | 87 | } // namespace Commands 88 | #endif // __COMMANDS_H__ 89 | -------------------------------------------------------------------------------- /test/compiler/gpu_node_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Cuda.Compiler.GPUNodeTest do 2 | use ExUnit.Case 3 | 4 | import Cuda.Test.CudaHelpers 5 | 6 | alias Cuda.Compiler.GPUUnit 7 | alias Cuda.Graph.Factory 8 | 9 | defmodule PTXNode do 10 | use Cuda.Graph.GPUNode 11 | def __pins__(_), do: [input(:i, :i16), output(:o, :i32)] 12 | def __ptx__(_), do: "PTX-<%= var(ctx, :x) %>" 13 | end 14 | 15 | defmodule CNode do 16 | use Cuda.Graph.GPUNode 17 | def __pins__(_), do: [input(:i, :i16), output(:o, :i32)] 18 | def __c__(_), do: "C-<%= var(ctx, :x) %>" 19 | end 20 | 21 | defmodule PTXCNode do 22 | use Cuda.Graph.GPUNode 23 | def __pins__(_), do: [input(:i, :i16), output(:o, :i32)] 24 | def __ptx__(_), do: "PTX-<%= var(ctx, :x) %>" 25 | def __c__(_), do: "C-<%= var(ctx, :y) %>" 26 | end 27 | 28 | describe "sources/2" do 29 | test "returns ptx sources" do 30 | node = Factory.new(%Cuda.Graph.GPUNode{}, :node, PTXNode, [], env()) 31 | node = %{node | assigns: %{vars: %{x: 10}}} 32 | context = context(root: node, path: []) 33 | {:ok, %{assigns: %{sources: [{:ptx, ptx}]}}} = GPUUnit.sources(node, context) 34 | assert parse_ptx(ptx) == ["PTX-10"] 35 | end 36 | 37 | test "returns c sources" do 38 | node = Factory.new(%Cuda.Graph.GPUNode{}, :node, CNode, [], env()) 39 | node = %{node | assigns: %{vars: %{x: 20}}} 40 | context = context(root: node, path: []) 41 | {:ok, %{assigns: 
%{sources: [{:c, c}]}}} = GPUUnit.sources(node, context) 42 | assert parse_c(c) == ["C-20"] 43 | end 44 | 45 | test "returns both c and ptx sources" do 46 | node = Factory.new(%Cuda.Graph.GPUNode{}, :node, PTXCNode, [], env()) 47 | node = %{node | assigns: %{vars: %{x: 10, y: 20}}} 48 | context = context(root: node, path: []) 49 | {:ok, %{assigns: %{sources: [{:ptx, ptx}, {:c, c}]}}} = GPUUnit.sources(node, context) 50 | assert parse_ptx(ptx) == ["PTX-10"] 51 | assert parse_c(c) == ["C-20"] 52 | end 53 | end 54 | end 55 | -------------------------------------------------------------------------------- /c_src/command_info.cpp: -------------------------------------------------------------------------------- 1 | #include "erlang_port.h" 2 | #include 3 | 4 | ETERM *GetDeviceCount() { 5 | int devCount; 6 | cudaError_t result = cudaGetDeviceCount(&devCount); 7 | switch (result) { 8 | case cudaSuccess: return FORMAT("{~a,~i}", OK_STR, devCount); 9 | case cudaErrorNoDevice: return FORMAT("{~a,~i}", OK_STR, 0); 10 | default: throw RuntimeError(result); 11 | } 12 | } 13 | 14 | ETERM *GetMemory() { 15 | size_t freeMem, totalMem; 16 | cudaError_t result = cudaMemGetInfo(&freeMem, &totalMem); 17 | if (result != cudaSuccess) throw RuntimeError(result); 18 | return FORMAT("{~a,{~i,~i}}", OK_STR, freeMem, totalMem); 19 | } 20 | 21 | ETERM *GetDriverVersion() { 22 | int version; 23 | cudaError_t result = cudaDriverGetVersion(&version); 24 | if (result != cudaSuccess) throw RuntimeError(result); 25 | return FORMAT("{~a,~i}", OK_STR, version); 26 | } 27 | 28 | ETERM *GetRuntimeVersion() { 29 | int version; 30 | cudaError_t result = cudaRuntimeGetVersion(&version); 31 | if (result != cudaSuccess) throw RuntimeError(result); 32 | return FORMAT("{~a,~i}", OK_STR, version); 33 | } 34 | 35 | ETERM *Info(ErlangPort *, ETERM *arg) { 36 | if (IS_NIL(arg)) { 37 | auto deviceCount = GetDeviceCount(); 38 | auto memory = GetMemory(); 39 | auto driverVersion = GetDriverVersion(); 40 | auto runtimeVersion = GetRuntimeVersion(); 41 | return FORMAT("[{~a,~w},{~a,~w},{~a,~w},{~a,~w}]", 42 | C_STR("device_count"), erl_element(2, deviceCount), 43 | C_STR("driver_version"), erl_element(2, driverVersion), 44 | C_STR("memory"), erl_element(2, memory), 45 | C_STR("runtime_version"), erl_element(2, runtimeVersion)); 46 | } else if (ERL_IS_ATOM(arg)) { 47 | if (ATOM_EQ(arg, "device_count")) return GetDeviceCount(); 48 | else if (ATOM_EQ(arg, "driver_version")) return GetDriverVersion(); 49 | else if (ATOM_EQ(arg, "memory")) return GetMemory(); 50 | else if (ATOM_EQ(arg, "runtime_version")) return GetRuntimeVersion(); 51 | } 52 | throw StringError("bad argument"); 53 | } 54 | -------------------------------------------------------------------------------- /lib/cuda/worker.ex: -------------------------------------------------------------------------------- 1 | defmodule Cuda.Worker do 2 | use GenServer 3 | alias Cuda.{Runner, Shared} 4 | alias Cuda.Compiler.{Context, Unit} 5 | 6 | def start_link(opts \\ []) do 7 | {opts, args} = Keyword.split(opts, ~w(name)a) 8 | GenServer.start_link(__MODULE__, args, opts) 9 | end 10 | 11 | def init(opts) do 12 | cuda = case Keyword.get(opts, :cuda) do 13 | nil -> with {:ok, cuda} <- Cuda.start_link(), do: cuda 14 | cuda -> cuda 15 | end 16 | with {:ok, graph} <- Keyword.fetch(opts, :graph) do 17 | #{:ok, info} <- Cuda.device_info(cuda) do 18 | st = %{cuda: cuda, graph: graph} 19 | load_graph(st, opts) 20 | end 21 | end 22 | 23 | def run(pid, input, args \\ %{}) do 24 | GenServer.call(pid, {:run, 
input, args}) 25 | end 26 | 27 | #def gpu_info(pid) do 28 | # GenServer.call(pid, :info) 29 | #end 30 | 31 | def handle_call({:run, input, args}, _from, st) do 32 | opts = [cuda: st.cuda, args: args] 33 | result = Runner.run(st.graph, input, opts) 34 | {:reply, result, st} 35 | end 36 | 37 | #def handle_call(:info, _from, st) do 38 | # {:reply, {:ok, st.info}, st} 39 | #end 40 | 41 | defp load_graph(st, opts) do 42 | env = Keyword.get(opts, :env, %Cuda.Env{}) 43 | env = with {:ok, info} <- Cuda.device_info(st.cuda) do 44 | %{env | gpu_info: info} 45 | else 46 | _ -> env 47 | end 48 | vars = %{float_size: env.float_size, f: Cuda.Env.f(env)} 49 | ctx = %Context{env: env, assigns: %{vars: vars}} 50 | args = case Keyword.get(opts, :shared) do 51 | nil -> 52 | %{} 53 | shared -> 54 | shared 55 | |> Enum.map(fn {k, pid} -> 56 | with {:ok, ref} <- Shared.share(pid), 57 | {:ok, mem} <- Cuda.memory_load(st.cuda, ref) do 58 | {k, mem} 59 | end 60 | end) 61 | |> Enum.into(%{}) 62 | end 63 | with {:ok, graph} <- Unit.compile(st.graph, ctx), 64 | # load compiled cubins into GPU 65 | {:ok, graph} <- Runner.load(graph, args: args, cuda: st.cuda) do 66 | {:ok, %{st | graph: graph}} 67 | end 68 | end 69 | end 70 | -------------------------------------------------------------------------------- /c_src/utils.cpp: -------------------------------------------------------------------------------- 1 | #include "utils.h" 2 | #include "common.h" 3 | #include "cuda_runtime.h" 4 | 5 | int BestDevice() { 6 | int maxPerfDevice = 0; 7 | int smPerMultiproc = 0; 8 | int currentDevice; 9 | int bestArch = 0; 10 | int prohibited = 0; 11 | int deviceCount = 0; 12 | unsigned long long maxComputePerf = 0; 13 | cudaError_t result; 14 | 15 | cudaDeviceProp deviceProp; 16 | result = cudaGetDeviceCount(&deviceCount); 17 | if (result != cudaSuccess) throw RuntimeError(result); 18 | if (deviceCount == 0) throw StringError("No devices supporting CUDA"); 19 | 20 | // Find the best major SM Architecture GPU device 21 | for (currentDevice = 0; currentDevice < deviceCount; currentDevice++) { 22 | result = cudaGetDeviceProperties(&deviceProp, currentDevice); 23 | if (result != cudaSuccess) throw RuntimeError(result); 24 | if (deviceProp.computeMode == cudaComputeModeProhibited) { 25 | prohibited++; 26 | continue; 27 | } 28 | // If this GPU is not running on Compute Mode prohibited, then we can add it to the list 29 | if (deviceProp.major > 0 && deviceProp.major < 9999) { 30 | bestArch = std::max(bestArch, deviceProp.major); 31 | } 32 | } 33 | 34 | if (prohibited == deviceCount) { 35 | throw StringError("All devices have compute mode prohibited"); 36 | } 37 | 38 | for (currentDevice = 0; currentDevice < deviceCount; currentDevice++) { 39 | cudaGetDeviceProperties(&deviceProp, currentDevice); 40 | if (deviceProp.computeMode == cudaComputeModeProhibited) continue; 41 | smPerMultiproc = deviceProp.major == 9999 && deviceProp.minor == 9999 ? 
42 | 1 : 43 | SMVer2Cores(deviceProp.major, deviceProp.minor); 44 | unsigned long long computePerf = (unsigned long long) deviceProp.multiProcessorCount * smPerMultiproc * deviceProp.clockRate; 45 | if (computePerf > maxComputePerf) { 46 | // If we found a GPU with SM major > 2, search only within that architecture 47 | // If this device matches bestArch, choose it; otherwise skip it 48 | if ((bestArch > 2 && deviceProp.major == bestArch) || bestArch <= 2) { 49 | maxComputePerf = computePerf; 50 | maxPerfDevice = currentDevice; 51 | } 52 | } 53 | } 54 | return maxPerfDevice; 55 | } 56 | -------------------------------------------------------------------------------- /c_src/common.h: -------------------------------------------------------------------------------- 1 | #ifndef __COMMON_H__ 2 | #define __COMMON_H__ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #if GPU_DEBUG > 0 11 | #include <iostream> 12 | #endif 13 | 14 | extern "C" { 15 | #include "erl_interface.h" 16 | #include "ei.h" 17 | } 18 | 19 | #include "cuda.h" 20 | #include "cuda_runtime.h" 21 | 22 | #define C_STR(str) ((char*)std::string(str).c_str()) 23 | #define FORMAT(fmt, ...) erl_format(C_STR(fmt), ##__VA_ARGS__) 24 | #define OK_STR C_STR("ok") 25 | #define ERROR_STR C_STR("error") 26 | #define MAKE_BINARY(str) erl_mk_binary(C_STR(str), sizeof(str) - 1) 27 | #define ATOM_EQ(term, str) (strncmp(ERL_ATOM_PTR(term), str, sizeof(str) - 1) == 0) 28 | #define IS_NIL(term) (ERL_IS_ATOM(term) && strncmp(ERL_ATOM_PTR(term), "nil", 3) == 0) 29 | #define IS_OK_TUPLE(term) (strncmp(ERL_ATOM_PTR(erl_element(1, term)), "ok", 2) == 0) 30 | #if GPU_DEBUG > 0 31 | #define DEBUG(msg) std::cout << msg << "\n" 32 | #else 33 | #define DEBUG(msg) do {} while(0) 34 | #endif 35 | 36 | class Error { 37 | protected: 38 | std::string source; 39 | public: 40 | Error(const char *src = NULL) : source(src ? 
src : "") {} 41 | virtual ETERM *AsTerm() = 0; 42 | }; 43 | 44 | class TermError : public Error { 45 | public: 46 | ETERM *term; 47 | TermError(ETERM *error, const char *src = NULL): Error(src), term(error) {} 48 | virtual ETERM *AsTerm() { return term; } 49 | }; 50 | 51 | class StringError : public Error { 52 | public: 53 | std::string message; 54 | StringError(const char *errorMessage, const char *src = NULL): Error(src), message(errorMessage) {} 55 | virtual ETERM *AsTerm(); 56 | }; 57 | 58 | class RuntimeError : public Error { 59 | public: 60 | cudaError_t code; 61 | RuntimeError(cudaError_t errorNo, const char *src = NULL): Error(src), code(errorNo) {} 62 | virtual ETERM *AsTerm(); 63 | }; 64 | 65 | class DriverError : public Error { 66 | public: 67 | CUresult code; 68 | DriverError(CUresult errorNo, const char *src = NULL): Error(src), code(errorNo) {} 69 | virtual ETERM *AsTerm(); 70 | }; 71 | 72 | typedef std::map Keywords; 73 | Keywords GetKeywords(ETERM *list); 74 | 75 | template T Get(ETERM *); 76 | int GetModuleIndex(ETERM *); 77 | int GetMemoryIndex(ETERM *); 78 | 79 | #endif // __COMMON_H__ 80 | -------------------------------------------------------------------------------- /lib/cuda/runner/gpu_node.ex: -------------------------------------------------------------------------------- 1 | defimpl Cuda.Runner, for: Cuda.Graph.GPUNode do 2 | alias Cuda.Graph.NodeProto 3 | alias Cuda.Graph.Pin 4 | 5 | def load(%{assigns: assigns} = node, opts) do 6 | with cuda when is_pid(cuda) <- Keyword.get(opts, :cuda) do 7 | # load cubin into GPU 8 | {:ok, module} = Cuda.module_load(cuda, assigns.cubin) 9 | # load args into GPU 10 | args = opts 11 | |> Keyword.get(:args, %{}) 12 | |> Enum.reduce(%{}, fn 13 | {k, {m, _} = loaded}, args when m in ~w(memory shared_memory)a -> 14 | Map.put(args, k, loaded) 15 | {k, {type, value}}, args -> 16 | bin = Pin.pack(type, value) 17 | with {:ok, marg} <- Cuda.memory_load(cuda, bin) do 18 | Map.put(args, k, marg) 19 | else 20 | _ -> 21 | # TODO: warning here 22 | args 23 | end 24 | _, args -> 25 | args 26 | end) 27 | {:ok, NodeProto.assign(node, cuda_module: module, cuda_args: args)} 28 | end 29 | end 30 | 31 | def run(%{assigns: assigns}, inputs, opts) do 32 | with cuda when is_pid(cuda) <- Keyword.get(opts, :cuda) do 33 | pins = inputs 34 | |> Cuda.Compiler.Utils.wrap_pins 35 | |> Pin.pack(assigns.inputs_shape) 36 | 37 | {:ok, mpins} = Cuda.memory_load(cuda, pins) 38 | 39 | args = Map.merge(Map.get(assigns, :cuda_args, %{}), 40 | Keyword.get(opts, :args, %{})) 41 | 42 | batches = assigns.batches |> Enum.map(fn batch -> 43 | Enum.map(batch, fn 44 | {:run, {name, k, b, params}} -> 45 | params = Enum.map(params, & Map.get(args, &1)) 46 | {:run, {name, k, b, [mpins | params]}} 47 | {:run, {name, k, b}} -> 48 | {:run, {name, k, b, [mpins]}} 49 | x -> 50 | x 51 | end) 52 | end) 53 | 54 | :ok = Cuda.stream(cuda, assigns.cuda_module, batches) 55 | {:ok, pins} = Cuda.memory_read(cuda, mpins) 56 | 57 | output = pins 58 | |> Pin.unpack(assigns.outputs_shape) 59 | |> Cuda.Compiler.Utils.unwrap_pins 60 | 61 | {:ok, output} 62 | else 63 | _ -> {:error, :no_cuda_specified} 64 | end 65 | end 66 | end 67 | -------------------------------------------------------------------------------- /mix.exs: -------------------------------------------------------------------------------- 1 | defmodule Cuda.Mixfile do 2 | use Mix.Project 3 | 4 | def project do 5 | [app: :cuda, 6 | version: "0.1.0", 7 | elixir: "~> 1.4", 8 | build_embedded: Mix.env == :prod, 9 | start_permanent: 
Mix.env == :prod, 10 | compilers: [:port, :elixir, :app], 11 | elixirc_paths: paths(), 12 | deps: deps(), 13 | aliases: aliases(), 14 | docs: docs()] 15 | end 16 | 17 | def application do 18 | [mod: {Cuda.App, []}, 19 | extra_applications: [:logger]] 20 | end 21 | 22 | defp deps do 23 | [{:uuid, "~> 1.1"}, 24 | # {:cpp_port, path: "../cpp_port"}, 25 | {:credo, "~> 0.7", only: [:dev, :test]}, 26 | {:ex_doc, "~> 0.15", only: :dev, runtime: false}] 27 | end 28 | 29 | defp aliases do 30 | [clean: ["clean.port", "clean"]] 31 | end 32 | 33 | defp docs do 34 | [main: "Cuda", 35 | #logo: "path/to/logo.png", 36 | extras: ["README.md"]] 37 | end 38 | 39 | defp paths do 40 | ["lib", Path.join(~w(test support))] 41 | end 42 | 43 | # defp cpp_ports do 44 | # [[module: A, 45 | # src: "src", 46 | # target: "priv/cuda_driver_port_test", 47 | # env: %{"CUDA" => "cuda-8.0"}]] 48 | # end 49 | end 50 | 51 | defmodule Mix.Tasks.Compile.Port do 52 | @cuda_version_file "/usr/local/cuda/version.txt" 53 | @cuda_version_re ~r/\s+(\d+\.\d+)/ 54 | def run(_) do 55 | cuda = with true <- File.exists?(@cuda_version_file), 56 | {:ok, version} <- File.read(@cuda_version_file), 57 | [_, version] <- Regex.run(@cuda_version_re, version) do 58 | "cuda-#{version}" 59 | else 60 | _ -> "cuda" 61 | end 62 | opts = [stderr_to_stdout: true, 63 | env: [{"CUDA", cuda}]] 64 | if match? {:win32, _}, :os.type do 65 | {result, _error_code} = System.cmd("nmake", ["priv\\cuda_driver_port.exe"], opts) 66 | Mix.shell.info result 67 | else 68 | {result, _error_code} = System.cmd("make", ["priv/cuda_driver_port"], opts) 69 | Mix.shell.info result 70 | end 71 | end 72 | end 73 | 74 | defmodule Mix.Tasks.Clean.Port do 75 | def run(_) do 76 | opts = [stderr_to_stdout: true] 77 | if match? {:win32, _}, :os.type do 78 | {result, _error_code} = System.cmd("nmake", ["clean"], opts) 79 | Mix.shell.info result 80 | else 81 | {result, _error_code} = System.cmd("make", ["clean"], opts) 82 | Mix.shell.info result 83 | end 84 | end 85 | end 86 | -------------------------------------------------------------------------------- /lib/cuda/env.ex: -------------------------------------------------------------------------------- 1 | defmodule Cuda.Env do 2 | @moduledoc """ 3 | Represents the environment variables management module 4 | """ 5 | alias Cuda.Env.Validation 6 | 7 | @type optimize :: :memory | :speed | :adaptive | :none 8 | @type float_size :: 16 | 32 | 64 9 | @type int_size :: 8 | 16 | 32 | 64 10 | 11 | @type t :: %__MODULE__{ 12 | float_size: float_size, 13 | int_size: int_size, 14 | optimize: optimize, 15 | gpu_info: keyword 16 | } 17 | 18 | defstruct [float_size: 4, int_size: 1, optimize: :none, gpu_info: []] 19 | 20 | @env_var "CUDA_ENV" 21 | @default %{ 22 | float_size: 4, 23 | int_size: 1, 24 | optimize: :none} 25 | 26 | @doc """ 27 | Creates a default-filled env map 28 | """ 29 | @spec create() :: {:ok, t} 30 | def create() do 31 | {:ok, Map.merge(%__MODULE__{}, @default)} 32 | end 33 | 34 | @doc """ 35 | Returns env map filled from :cuda config (config.exs) 36 | with key loaded from system env CUDA_ENV 37 | """ 38 | @spec load() :: {:ok, t} | Cuda.error_tuple 39 | def load() do 40 | case get_env() do 41 | nil -> 42 | create() 43 | env -> 44 | keys = get_keys() |> MapSet.new() 45 | {:ok, init} = create() 46 | fill_in(env, keys, init) 47 | end 48 | end 49 | 50 | @doc """ 51 | Merges env map with keyword list opts 52 | """ 53 | @spec merge(t, [keyword]) :: {:ok, t} | Cuda.error_tuple 54 | def merge(env, opts) do 55 | keys = get_keys() |> MapSet.new() 56 | 
fill_in(opts, keys, env) 57 | end 58 | 59 | @doc """ 60 | Returns default env values map 61 | """ 62 | @spec get_default() :: map 63 | def get_default(), do: @default 64 | 65 | def f(%__MODULE__{float_size: size}) when is_integer(size) do 66 | "f#{size * 8}" 67 | end 68 | def f(_), do: "f32" 69 | 70 | defp get_keys() do 71 | %__MODULE__{} 72 | |> Map.from_struct() 73 | |> Map.keys() 74 | end 75 | 76 | defp get_env() do 77 | with val when not is_nil(val) <- System.get_env(@env_var) do 78 | val = String.to_atom(val) 79 | Application.get_env(:cuda, val) 80 | end 81 | end 82 | 83 | defp fill_in([], _, env), do: {:ok, env} 84 | defp fill_in([{key, value} | rest], keys, env) do 85 | with {:ok, value} <- Validation.validate(key, value) do 86 | if MapSet.member?(keys, key) do 87 | fill_in(rest, keys, Map.put(env, key, value)) 88 | else 89 | {:error, "unexpected value name in config"} 90 | end 91 | end 92 | end 93 | end 94 | -------------------------------------------------------------------------------- /test/support/cuda_helpers.ex: -------------------------------------------------------------------------------- 1 | defmodule Cuda.Test.CudaHelpers do 2 | alias Cuda.Compiler.Context 3 | 4 | def env(values \\ []) do 5 | {:ok, env} = Cuda.Env.create() 6 | env 7 | |> Map.merge(%{gpu_info: gpu_info()}) 8 | |> Map.merge(values |> Enum.into(%{})) 9 | end 10 | 11 | def context(values \\ []) do 12 | values = values |> Enum.into(%{}) 13 | %Context{env: env(Map.get(values, :env, [])), assigns: %{vars: %{}}} 14 | |> Map.merge(values, &context_merge/3) 15 | end 16 | 17 | @header_directives ~w(.version .target .address_size) 18 | def parse_ptx(ptx) do 19 | ptx 20 | |> String.split("\n") 21 | |> Enum.map(&String.trim/1) 22 | |> Enum.map(& String.replace(&1, ~r/\s+/, " ")) 23 | |> Enum.map(& String.split(&1, " ")) 24 | |> Enum.reject(& List.first(&1) in @header_directives) 25 | |> Enum.map(& Enum.join(&1, " ")) 26 | |> Enum.join() 27 | |> String.split(";") 28 | end 29 | 30 | def parse_c(c) do 31 | c 32 | |> String.split("\n") 33 | |> Enum.map(&String.trim/1) 34 | |> Enum.map(& String.replace(&1, ~r/\s+/, " ")) 35 | |> Enum.join() 36 | |> String.split(";") 37 | end 38 | 39 | def gpu_info() do 40 | [max_threads_per_block: 1024, max_block: {1024, 1024, 64}, 41 | max_grid: {2147483647, 65535, 65535}, max_shared_memory_per_block: 49152, 42 | total_constant_memory: 65536, warp_size: 32, max_pitch: 2147483647, 43 | max_registers_per_block: 65536, clock_rate: 1006000, gpu_overlap: true, 44 | miltiprocessor_count: 2, kernel_exec_timeout: true, integrated: false, 45 | can_map_host_memory: true, compute_mode: :default, concurrent_kernels: true, 46 | ecc_enabled: false, pci_bus_id: 1, pci_device_id: 0, tcc_driver: false, 47 | memory_clock_rate: 2505000, global_memory_bus_width: 64, l2_cache_size: 524288, 48 | max_threads_per_multiprocessor: 2048, unified_arressing: true, 49 | compute_capability: {3, 5}, global_l1_cache_supported: false, 50 | glocal_l1_cache_supported: true, max_shared_memory_per_multiprocessor: 49152, 51 | max_registers_per_multiprocessor: 65536, managed_memory: true, 52 | multi_gpu_board: false, multi_gpu_board_group_id: 0, 53 | host_native_atomic_supported: false, single_to_double_precision_perf_ratio: 24, 54 | pageable_memory_access: false, concurrent_managed_access: false, 55 | compute_preemption_supported: false, 56 | can_use_host_pointer_for_registered_mem: false] 57 | end 58 | 59 | defp context_merge(:env, v1, v2), do: Map.merge(v1, v2) 60 | defp context_merge(:assigns, v1, v2), do: Map.merge(v1, v2) 
61 | defp context_merge(_, _v1, v2), do: v2 62 | end 63 | -------------------------------------------------------------------------------- /test/compiler/computation_graph_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Cuda.Compiler.ComputationGraphTest do 2 | use ExUnit.Case 3 | 4 | import Cuda.Test.CudaHelpers 5 | 6 | alias Cuda.Compiler.GPUUnit 7 | alias Cuda.{Graph, Graph.ComputationGraph, Graph.Factory, Graph.GPUNode, Graph.Pin} 8 | 9 | defmodule Node1 do 10 | use GPUNode 11 | def __pins__(_), do: [input(:i, :i16), output(:o, :i32)] 12 | def __ptx__(_), do: "node1-<%= offset(ctx, :pins, :i) %>-<%= offset(ctx, :pins, :o) %>" 13 | end 14 | 15 | defmodule Node2 do 16 | use GPUNode 17 | def __pins__(_), do: [input(:i, :i32), output(:o, :i64)] 18 | def __ptx__(_), do: "node2-<%= offset(ctx, :pins, :i) %>-<%= offset(ctx, :pins, :o) %>" 19 | end 20 | 21 | defmodule Node3 do 22 | use GPUNode 23 | def __pins__(_), do: [input(:i, :i16, :inputs), output(:o, :i32, :inputs)] 24 | def __ptx__(_), do: "node3-<%= offset(ctx, :pins, :i) %>-<%= offset(ctx, :pins, :o) %>" 25 | end 26 | 27 | defmodule Node4 do 28 | use GPUNode 29 | def __pins__(_), do: [input(:i, :i32, :inputs), output(:o, :i64, :inputs)] 30 | def __ptx__(_), do: "node4-<%= offset(ctx, :pins, :i) %>-<%= offset(ctx, :pins, :o) %>" 31 | end 32 | 33 | describe "sources/2" do 34 | test "returns chained sources" do 35 | graph_pins = [%Pin{id: :i, type: :input, data_type: :i16}, 36 | %Pin{id: :o, type: :output, data_type: :i64}, 37 | %Pin{id: :gi, type: :input, data_type: :i16}, 38 | %Pin{id: :go, type: :output, data_type: :i64}, 39 | %Pin{id: :inputs, type: :output, alias: :inputs}] 40 | graph = Factory.new(%Cuda.Graph{}, :g, ComputationGraph, [], env()) 41 | |> Map.put(:pins, graph_pins) 42 | |> Graph.add(:node1, Node1) 43 | |> Graph.add(:node2, Node2) 44 | |> Graph.add(:node3, Node3) 45 | |> Graph.add(:node4, Node4) 46 | |> Graph.link(:i, {:node1, :i}) 47 | |> Graph.link({:node1, :o}, {:node2, :i}) 48 | |> Graph.link({:node2, :o}, :o) 49 | |> Graph.link(:gi, {:node3, :i}) 50 | |> Graph.link({:node3, :o}, {:node4, :i}) 51 | |> Graph.link({:node4, :o}, :go) 52 | ctx = context(root: graph, vars: %{x: 10}) 53 | {:ok, %{assigns: %{sources: sources}}} = GPUUnit.sources(graph, ctx) 54 | [{:ptx, n4}, {:ptx, n3}, {:ptx, n2}, {:ptx, n1}] = sources 55 | assert parse_ptx(n1) == ["node1-14-22"] 56 | assert parse_ptx(n2) == ["node2-22-14"] 57 | assert parse_ptx(n3) == ["node3-0-2"] 58 | assert parse_ptx(n4) == ["node4-2-6"] 59 | end 60 | end 61 | end 62 | -------------------------------------------------------------------------------- /lib/cuda/template.ex: -------------------------------------------------------------------------------- 1 | defmodule Cuda.Template do 2 | @moduledoc """ 3 | Represents the EEx templates processing module 4 | """ 5 | 6 | @type options :: [ 7 | context: Cuda.Compiler.Context.t, 8 | ptx_helpers: [], # ? - subject to remove - use helpers option 9 | c_helpers: [], # ? 
- subject to remove - use helpers option 10 | helpers: []] 11 | 12 | @doc """ 13 | Preprocesses PTX template 14 | """ 15 | @spec ptx_preprocess(template :: String.t, opts :: options) :: {:ptx, String.t} 16 | def ptx_preprocess(template, opts) do 17 | {:ptx, ptx_eval(template, opts)} 18 | end 19 | 20 | @doc """ 21 | Preprocesses C template 22 | """ 23 | @spec c_preprocess(template :: String.t, opts :: options) :: {:c, String.t} 24 | def c_preprocess(template, opts) do 25 | {:c, c_eval(template, opts)} 26 | end 27 | 28 | @doc """ 29 | Returns the evaluated PTX template with helper modules included, etc. 30 | """ 31 | @spec ptx_eval(template :: String.t, opts :: options) :: String.t 32 | def ptx_eval(template, opts) do 33 | hlprs = [Kernel, Cuda.Template.Helpers] ++ 34 | Keyword.get(opts, :ptx_helpers, []) ++ 35 | Keyword.get(opts, :helpers, []) 36 | ctx = Keyword.get(opts, :context) 37 | eval(template, ctx, hlprs) 38 | end 39 | 40 | @doc """ 41 | Returns the evaluated C template with helper modules included, etc. 42 | """ 43 | @spec c_eval(template :: String.t, opts :: options) :: String.t 44 | def c_eval(template, opts) do 45 | hlprs = [Kernel, Cuda.Template.Helpers] ++ 46 | Keyword.get(opts, :c_helpers, []) ++ 47 | Keyword.get(opts, :helpers, []) 48 | ctx = Keyword.get(opts, :context) 49 | eval(template, ctx, hlprs) 50 | end 51 | 52 | defp eval(template, context, helpers) do 53 | opts = [functions: get_funcs(helpers), 54 | macros: get_macros(helpers)] 55 | EEx.eval_string(template, [ctx: context, assigns: context.assigns], opts) 56 | end 57 | 58 | defp get_funcs([]), do: [] 59 | defp get_funcs([module | rest]) do 60 | {:ok, funcs} = get_funcs(module) 61 | [funcs | get_funcs(rest)] 62 | end 63 | defp get_funcs(module) do 64 | funcs = module.__info__(:functions) 65 | {:ok, {module, funcs}} 66 | end 67 | 68 | defp get_macros([]), do: [] 69 | defp get_macros([module | rest]) do 70 | {:ok, macros} = get_macros(module) 71 | [macros | get_macros(rest)] 72 | end 73 | defp get_macros(module) do 74 | macros = module.__info__(:macros) 75 | #macros = case module do 76 | # Kernel -> macros |> Enum.reject(fn {n, _} -> n == :@ end) 77 | # _ -> macros 78 | #end 79 | {:ok, {module, macros}} 80 | end 81 | end 82 | -------------------------------------------------------------------------------- /test/float_16_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Cuda.Float16Test do 2 | use ExUnit.Case 3 | import Cuda.Float16 4 | 5 | describe "pack/1" do 6 | test "Positive value without fraction" do 7 | assert pack(2) == <<0::size(1),16::size(5),0::size(10)>> 8 | end 9 | 10 | test "Negative value without fraction" do 11 | assert pack(-2) == <<1::size(1),16::size(5),0::size(10)>> 12 | end 13 | 14 | test "Positive float value" do 15 | assert pack(155.625) == <<0::size(1),22::size(5),221::size(10)>> 16 | end 17 | 18 | test "Negative float value" do 19 | assert pack(-155.625) == <<1::size(1),22::size(5),221::size(10)>> 20 | end 21 | 22 | test "Maximal normal number" do 23 | assert pack(65504) == <<0::size(1),30::size(5),1023::size(10)>> 24 | end 25 | 26 | test "Minimal normal number" do 27 | assert pack(0.000061) == <<0::size(1),0::size(5),1022::size(10)>> 28 | end 29 | 30 | test "Minimal subnormal number" do 31 | assert pack(0.00000006) == <<0::size(1),0::size(5),1::size(10)>> 32 | end 33 | 34 | test "Zero" do 35 | assert pack(0) == <<0::size(1),0::size(5),0::size(10)>> 36 | end 37 | end 38 | 39 | describe "unpack/1" do 40 | test "Positive value without 
fraction" do 41 | assert unpack(<<0::size(1),16::size(5),0::size(10)>>) == 2.0 42 | end 43 | 44 | test "Negative value without fraction" do 45 | assert unpack(<<1::size(1),16::size(5),0::size(10)>>) == -2.0 46 | end 47 | 48 | test "Positive float value" do 49 | assert unpack(<<0::size(1),22::size(5),221::size(10)>>) == 155.625 50 | end 51 | 52 | test "Negative float value" do 53 | assert unpack(<<1::size(1),22::size(5),221::size(10)>>) == -155.625 54 | end 55 | 56 | test "Maximal normal number" do 57 | assert unpack(<<0::size(1),30::size(5),1023::size(10)>>) == 65504 58 | end 59 | 60 | test "Minimal normal number" do 61 | assert <<0::size(1),0::size(5),1022::size(10)>> 62 | |> unpack() 63 | |> Float.round(6) == 0.000061 64 | end 65 | 66 | test "Minimal subnormal number" do 67 | assert <<0::size(1),0::size(5),1::size(10)>> 68 | |> unpack() 69 | |> Float.round(9) == 0.00000006 70 | end 71 | 72 | test "Zero" do 73 | assert unpack(<<0::size(1),0::size(5),0::size(10)>>) == 0 74 | end 75 | 76 | test "Not a number" do 77 | assert unpack(<<0::size(1),31::size(5),1::size(10)>>) == :not_a_number 78 | end 79 | 80 | test "Positive infinity" do 81 | assert unpack(<<0::size(1),31::size(5),0::size(10)>>) == :positive_infinity 82 | end 83 | 84 | test "Negative infinity" do 85 | assert unpack(<<1::size(1),31::size(5),0::size(10)>>) == :negative_infinity 86 | end 87 | end 88 | end 89 | -------------------------------------------------------------------------------- /test/graph_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Cuda.GraphTest do 2 | use ExUnit.Case 3 | alias Cuda.Graph 4 | alias Cuda.Graph.Node 5 | alias Cuda.Graph.Pin 6 | 7 | import Cuda.Test.GraphHelpers 8 | import Graph, except: [graph: 1, graph: 2] 9 | 10 | alias Cuda.Test.GraphHelpers.Single 11 | alias Cuda.Test.GraphHelpers.Double 12 | 13 | describe "add/4" do 14 | test "adds nodes to graph" do 15 | graph = graph() |> add(:a, Single) 16 | assert [%Node{id: :a}] = graph.nodes 17 | end 18 | 19 | test "rejects nodes with id that already in the graph" do 20 | graph = graph() |> add(:a, Single) 21 | assert_raise(CompileError, fn -> graph |> add(:a, Double) end) 22 | end 23 | end 24 | 25 | describe "link/2" do 26 | test "links graph input to node input" do 27 | graph = graph(pins: [%Pin{id: :i, type: :input, data_type: :i8}]) 28 | |> add(:a, Single) 29 | |> link(:i, {:a, :input}) 30 | assert [{{:__self__, :i}, {:a, :input}}] = graph.links 31 | end 32 | 33 | test "links node output to graph output" do 34 | graph = graph(pins: [%Pin{id: :o, type: :output, data_type: :i8}]) 35 | |> add(:a, Single) 36 | |> link({:a, :output}, :o) 37 | assert [{{:a, :output}, {:__self__, :o}}] = graph.links 38 | end 39 | 40 | test "links graph input to graph output" do 41 | graph = graph(pins: [%Pin{id: :i, type: :input, data_type: :i8}, 42 | %Pin{id: :o, type: :output, data_type: :i8}]) 43 | |> link(:i, :o) 44 | assert [{{:__self__, :i}, {:__self__, :o}}] = graph.links 45 | end 46 | 47 | test "links node output to node input" do 48 | graph = graph() 49 | |> add(:a, Single) 50 | |> add(:b, Single) 51 | |> link({:a, :output}, {:b, :input}) 52 | assert [{{:a, :output}, {:b, :input}}] = graph.links 53 | end 54 | 55 | test "rejects wrong pin type connection" do 56 | graph = graph(pins: [%Pin{id: :i, type: :input, data_type: :i8}, 57 | %Pin{id: :o, type: :output, data_type: :i8}]) 58 | |> add(:a, Single) 59 | |> add(:b, Single) 60 | assert_raise(CompileError, fn -> graph |> link(:o, :i) end) 61 | 
assert_raise(CompileError, fn -> graph |> link(:i, {:a, :output}) end) 62 | assert_raise(CompileError, fn -> graph |> link({:a, :input}, {:b, :input}) end) 63 | assert_raise(CompileError, fn -> graph |> link({:a, :output}, {:b, :output}) end) 64 | assert_raise(CompileError, fn -> graph |> link({:a, :output}, :i) end) 65 | end 66 | 67 | test "rejects wrong pin data_type connection" do 68 | graph = graph(pins: [%Pin{id: :i, type: :input, data_type: :i16}, 69 | %Pin{id: :o, type: :output, data_type: :i8}]) 70 | |> add(:a, Single) 71 | |> add(:b, Single) 72 | assert_raise(CompileError, fn -> graph |> link(:i, {:a, :input}) end) 73 | end 74 | end 75 | end 76 | -------------------------------------------------------------------------------- /c_src/common.cpp: -------------------------------------------------------------------------------- 1 | #include "common.h" 2 | 3 | ETERM *StringError::AsTerm() { 4 | return FORMAT("{error,~w,~w}", 5 | erl_mk_binary(source.c_str(), source.size()), 6 | erl_mk_binary(message.c_str(), message.size())); 7 | } 8 | 9 | ETERM *RuntimeError::AsTerm() { 10 | const char *name = cudaGetErrorName(code); 11 | const char *str = cudaGetErrorString(code); 12 | return FORMAT("{error,~w,~w,~w}", 13 | erl_mk_binary(source.c_str(), source.size()), 14 | erl_mk_binary(name, strlen(name)), 15 | erl_mk_binary(str, strlen(str))); 16 | } 17 | 18 | ETERM *DriverError::AsTerm() { 19 | const char *name, *str; 20 | if (cuGetErrorName(code, &name) == CUDA_SUCCESS && 21 | cuGetErrorString(code, &str) == CUDA_SUCCESS) { 22 | // DEBUG("DeviceError: " << name << ", " << str); 23 | return FORMAT("{error,~w,~w,~w}", 24 | erl_mk_binary(source.c_str(), source.size()), 25 | erl_mk_binary(name, strlen(name)), 26 | erl_mk_binary(str, strlen(str))); 27 | } 28 | return FORMAT("{error,~w,~w}", 29 | erl_mk_binary(source.c_str(), source.size()), 30 | MAKE_BINARY("Unknown error")); 31 | } 32 | 33 | Keywords GetKeywords(ETERM *list) { 34 | if (!ERL_IS_LIST(list)) throw StringError("Bad argument"); 35 | Keywords map; 36 | auto size = erl_length(list); 37 | for (int i = 0; i < size; i++) { 38 | auto tuple = erl_hd(list); 39 | list = erl_tl(list); 40 | if (!ERL_IS_TUPLE(tuple) || erl_size(tuple) != 2) throw StringError("Bad argument"); 41 | auto keyAtom = erl_element(1, tuple); 42 | if (!ERL_IS_ATOM(keyAtom)) throw StringError("Bad argument"); 43 | std::string key(ERL_ATOM_PTR(keyAtom)); 44 | map.insert(std::pair<std::string, ETERM *>(key, erl_element(2, tuple))); 45 | } 46 | return map; 47 | } 48 | 49 | template <> int Get<int>(ETERM *value) { 50 | if (!ERL_IS_INTEGER(value)) throw StringError("Bad argument"); 51 | return ERL_INT_VALUE(value); 52 | } 53 | 54 | template <> unsigned int Get<unsigned int>(ETERM *value) { 55 | if (!ERL_IS_INTEGER(value)) throw StringError("Bad argument"); 56 | return ERL_INT_UVALUE(value); 57 | } 58 | 59 | template <> bool Get<bool>(ETERM *value) { 60 | if (!ERL_IS_ATOM(value)) throw StringError("Bad argument"); 61 | if (ATOM_EQ(value, "true")) return true; 62 | if (ATOM_EQ(value, "false")) return false; 63 | throw StringError("Bad argument"); 64 | } 65 | 66 | int GetModuleIndex(ETERM *value) { 67 | if (!ERL_IS_TUPLE(value) || erl_size(value) != 2) { 68 | throw StringError("Invalid module handle"); 69 | } 70 | auto a = erl_element(1, value); 71 | auto v = erl_element(2, value); 72 | if (!ERL_IS_ATOM(a) || !ATOM_EQ(a, "module")) { 73 | throw StringError("Invalid module handle"); 74 | } 75 | return Get<int>(v); 76 | } 77 | 78 | int GetMemoryIndex(ETERM *value) { 79 | if (!ERL_IS_TUPLE(value) || erl_size(value) != 2) { 80 | throw 
StringError("Invalid memory handle"); 81 | } 82 | auto a = erl_element(1, value); 83 | auto v = erl_element(2, value); 84 | if (!ERL_IS_ATOM(a) || !ATOM_EQ(a, "memory")) { 85 | throw StringError("Invalid memory handle"); 86 | } 87 | return Get<int>(v); 88 | } 89 | -------------------------------------------------------------------------------- /test/compiler/gpu_node_ptx_helpers_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Cuda.Compiler.GPUNodePTXHelpersTest do 2 | use ExUnit.Case 3 | 4 | import Cuda.Test.CudaHelpers 5 | 6 | alias Cuda.Compiler.GPUUnit 7 | alias Cuda.Graph.{NodeProto, GPUNode} 8 | alias Cuda.Graph.Factory 9 | alias Cuda.Memory 10 | 11 | defmodule PTXNode do 12 | use GPUNode 13 | def __pins__(_), do: [input(:i, :i16), output(:o, :i32)] 14 | def __ptx__(node), do: Keyword.get(node.assigns.options, :ptx) 15 | end 16 | 17 | defp new_node(ptx) do 18 | Factory.new(%GPUNode{}, :node, PTXNode, [ptx: ptx], env()) 19 | end 20 | 21 | defp gen_ptx(text, opts \\ []) do 22 | node = new_node(text) |> NodeProto.assign(Keyword.get(opts, :node_assigns, %{})) 23 | ctx = context(root: node, path: [], assigns: Keyword.get(opts, :ctx_assigns, %{})) 24 | {:ok, %{assigns: %{sources: [{:ptx, ptx}]}}} = GPUUnit.sources(node, ctx) 25 | parse_ptx(ptx) 26 | end 27 | 28 | describe "offset/2" do 29 | test "returns memory offset" do 30 | assert gen_ptx(~s{<%= offset(ctx, :pins, :i) %>}) == ["0"] 31 | assert gen_ptx(~s{<%= offset(ctx, :pins, :o) %>}) == ["2"] 32 | end 33 | end 34 | 35 | describe "shared_offset/2" do 36 | test "returns shared offset" do 37 | memory = %Memory{vars: [ 38 | {:a, {10, %{node1: :i16, node2: :i32}}}, 39 | {:b, {30, %{node1: :i16, node2: :i32}}} 40 | ]} 41 | assigns = %{vars: %{layer: :node1}, memory: %{shared: memory}} 42 | assert gen_ptx(~s{<%= shared_offset(ctx, :a) %>}, ctx_assigns: assigns) == ["10"] 43 | assert gen_ptx(~s{<%= shared_offset(ctx, :b) %>}, ctx_assigns: assigns) == ["30"] 44 | assigns = %{vars: %{layer: :node2}, memory: %{shared: memory}} 45 | assert gen_ptx(~s{<%= shared_offset(ctx, :a) %>}, ctx_assigns: assigns) == ["12"] 46 | assert gen_ptx(~s{<%= shared_offset(ctx, :b) %>}, ctx_assigns: assigns) == ["32"] 47 | end 48 | end 49 | 50 | describe "defkernel/2" do 51 | test "expands to kernel function declaration" do 52 | ptx = gen_ptx(~s{<%= defkernel(ctx, "x") do %>\n<% end %>}) 53 | assert ptx == [".visible .entry node__x (.param .u64 .ptr pins) {}"] 54 | end 55 | 56 | test "accepts additional parameters" do 57 | ptx = gen_ptx(~s{<%= defkernel(ctx, "x", a: :u8) do %>\n<% end %>}) 58 | assert ptx == [".visible .entry node__x (.param .u64 .ptr pins, .param .u8 a) {}"] 59 | ptx = gen_ptx(~s{<%= defkernel(ctx, "x", a: u8) do %>\n<% end %>}) 60 | assert ptx == [".visible .entry node__x (.param .u64 .ptr pins, .param .u8 a) {}"] 61 | ptx = gen_ptx(~s{<%= defkernel(ctx, "x", a: u8.ptr.local) do %>\n<% end %>}) 62 | assert ptx == [".visible .entry node__x (.param .u64 .ptr pins, .param .u8 .ptr .local a) {}"] 63 | ptx = gen_ptx(~s{<%= defkernel(ctx, "x", a: u8.ptr.align-16) do %>\n<% end %>}) 64 | assert ptx == [".visible .entry node__x (.param .u64 .ptr pins, .param .u8 .ptr .align 16 a) {}"] 65 | ptx = gen_ptx(~s{<%= defkernel(ctx, "x", a: u8.ptr.local.align-8) do %>\n<% end %>}) 66 | assert ptx == [".visible .entry node__x (.param .u64 .ptr pins, .param .u8 .ptr .local .align 8 a) {}"] 67 | end 68 | end 69 | end 70 | --------------------------------------------------------------------------------
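For illustration only, a minimal GPU node sketch (not a file from this repository) showing how the pieces exercised by the tests above fit together: __pins__/1 declares the pins, and the __ptx__/1 EEx template uses the defkernel/offset helpers, which expand to the .visible .entry declaration and the pin offsets seen in the assertions. Module and kernel names here are hypothetical.

defmodule MyApp.ExampleNode do
  # hypothetical module; `use` pulls in the input/output pin helpers
  use Cuda.Graph.GPUNode

  def __pins__(_), do: [input(:i, :i16), output(:o, :i32)]

  # The template is evaluated by Cuda.Template with the PTX helpers in
  # scope, so defkernel/2 expands to a ".visible .entry <node>__copy"
  # declaration and offset/3 resolves byte offsets inside the pins buffer.
  def __ptx__(_) do
    """
    <%= defkernel(ctx, "copy") do %>
      // PTX body would go here, addressing the pins buffer at
      // offsets <%= offset(ctx, :pins, :i) %> and <%= offset(ctx, :pins, :o) %>
    <% end %>
    """
  end
end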
/lib/cuda/compiler/gpu_node.ex: -------------------------------------------------------------------------------- 1 | defimpl Cuda.Compiler.GPUUnit, for: Cuda.Graph.GPUNode do 2 | alias Cuda.Compiler.Context 3 | alias Cuda.{Template, Template.PtxHelpers} 4 | alias Cuda.Graph.{Node, NodeProto} 5 | 6 | import Cuda.Compiler.Utils 7 | 8 | def sources(node, ctx) do 9 | node = put_pins_shapes(node) 10 | ctx = Context.replace_current(ctx, node) 11 | helpers = [PtxHelpers] ++ Map.get(node.assigns, :helpers, []) 12 | opts = [context: ctx, helpers: helpers] 13 | ptx = case node.module.__ptx__(node) do 14 | src when is_bitstring(src) -> [src] 15 | src when is_list(src) -> src 16 | _ -> [] 17 | end 18 | ptx = ptx 19 | |> Enum.map(& Template.ptx_preprocess(&1, opts)) 20 | |> Enum.map(&include_header(ctx, &1)) 21 | c = case node.module.__c__(node) do 22 | src when is_bitstring(src) -> [src] 23 | src when is_list(src) -> src 24 | _ -> [] 25 | end 26 | c = c |> Enum.map(& Template.c_preprocess(&1, opts)) 27 | {:ok, NodeProto.assign(node, :sources, ptx ++ c)} 28 | end 29 | 30 | @line_re ["\n", "\r\n", "\n\r"] 31 | @space_re ~r/\s+/ 32 | @header_directives ~w(.version .target .address_size) 33 | defp include_header(ctx, {:ptx, src}) do 34 | directives = src 35 | |> String.split(@line_re) 36 | |> Enum.map(&String.trim/1) 37 | |> Enum.map(&String.split(&1, @space_re)) 38 | |> Enum.map(&List.first/1) 39 | |> Enum.filter(& &1 in @header_directives) 40 | src = if ".address_size" in directives do 41 | src 42 | else 43 | PtxHelpers.address_size(ctx) <> src 44 | end 45 | src = if ".target" in directives do 46 | src 47 | else 48 | PtxHelpers.target(ctx) <> src 49 | end 50 | src = if ".version" in directives do 51 | src 52 | else 53 | PtxHelpers.version() <> src 54 | end 55 | {:ptx, src} 56 | end 57 | end 58 | 59 | defimpl Cuda.Compiler.Unit, for: Cuda.Graph.GPUNode do 60 | alias Cuda.Compiler 61 | alias Cuda.Compiler.{Context, GPUUnit} 62 | alias Cuda.Graph.{Node, NodeProto} 63 | require Logger 64 | 65 | def compile(node, ctx) do 66 | Logger.info("CUDA: Compiling GPU code for node #{node.module} (#{node.id})") 67 | with {:ok, node} <- node.module.__compile__(node), 68 | ctx = Context.for_node(ctx, node), 69 | {:ok, node} <- GPUUnit.sources(node, ctx), 70 | {:ok, cubin} <- Compiler.compile(node.assigns.sources) do 71 | batch = node.module.__batch__(node) 72 | |> Enum.map(fn 73 | {:run, {name, g, b, args}} -> 74 | {:run, {"#{Node.string_id(node.id)}__#{name}", g, b, args}} 75 | {:run, {name, g, b}} -> 76 | {:run, {"#{Node.string_id(node.id)}__#{name}", g, b, []}} 77 | end) 78 | node = node 79 | |> NodeProto.assign(:cubin, cubin) 80 | |> NodeProto.assign(:batch, batch) 81 | {:ok, node} 82 | else 83 | _ -> 84 | Logger.warn("CUDA: Error occurred while compiling GPU code for node " <> 85 | "#{node.module} (#{node.id})") 86 | {:error, :compile_error} 87 | end 88 | end 89 | end 90 | -------------------------------------------------------------------------------- /lib/cuda/shared.ex: -------------------------------------------------------------------------------- 1 | defmodule Cuda.Shared do 2 | use GenServer 3 | alias Cuda.Memory 4 | 5 | def start_link(opts \\ []) do 6 | {opts, args} = Keyword.split(opts, ~w(name)a) 7 | GenServer.start_link(__MODULE__, args, opts) 8 | end 9 | 10 | def init(opts) do 11 | with {:ok, cuda} <- Cuda.start_link() do 12 | st = %{cuda: cuda, memory: Keyword.get(opts, :memory), ref: nil} 13 | opts |> Keyword.get(:vars, %{}) |> load_vars(st) 14 | end 15 | end 16 | 17 | def load(pid, vars) do 18 | 
GenServer.call(pid, {:load, vars}) 19 | end 20 | def load(pid, memory, vars) do 21 | GenServer.call(pid, {:load, memory, vars}) 22 | end 23 | 24 | def unload(pid) do 25 | GenServer.call(pid, :unload) 26 | end 27 | 28 | def handle(pid) do 29 | GenServer.call(pid, :handle) 30 | end 31 | 32 | def memory(pid) do 33 | GenServer.call(pid, :memory) 34 | end 35 | 36 | def data(pid) do 37 | GenServer.call(pid, :data) 38 | end 39 | 40 | def share(pid) do 41 | GenServer.call(pid, :share) 42 | end 43 | 44 | def vars(pid) do 45 | GenServer.call(pid, :vars) 46 | end 47 | 48 | def handle_call({:load, vars}, _from, st) do 49 | with {:ok, st} <- load_vars(vars, st) do 50 | {:reply, {:ok, st.ref}, st} 51 | end 52 | end 53 | 54 | def handle_call({:load, memory, vars}, _from, st) do 55 | st = %{st | memory: memory} 56 | with {:ok, st} <- load_vars(vars, st) do 57 | {:reply, {:ok, st.ref}, st} 58 | end 59 | end 60 | 61 | def handle_call(:unload, _from, %{ref: ref} = st) when not is_nil(ref) do 62 | with :ok <- Cuda.memory_unload(st.cuda, st.ref) do 63 | {:reply, :ok, %{st | ref: nil}} 64 | else 65 | result -> {:reply, result, st} 66 | end 67 | end 68 | def handle_call(:unload, _from, st) do 69 | {:reply, :ok, st} 70 | end 71 | 72 | def handle_call(:handle, _from, st) do 73 | {:reply, {:ok, st.ref}, st} 74 | end 75 | 76 | def handle_call(:memory, _from, st) do 77 | {:reply, {:ok, st.memory}, st} 78 | end 79 | 80 | def handle_call(:data, _from, %{ref: ref} = st) when not is_nil(ref) do 81 | result = Cuda.memory_read(st.cuda, ref) 82 | {:reply, result, st} 83 | end 84 | def handle_call(:data, _from, st) do 85 | {:reply, {:ok, nil}, st} 86 | end 87 | 88 | def handle_call(:share, _from, st) do 89 | result = Cuda.memory_share(st.cuda, st.ref) 90 | {:reply, result, st} 91 | end 92 | 93 | def handle_call(:vars, _from, %{ref: ref} = st) when not is_nil(ref) do 94 | result = with {:ok, data} <- Cuda.memory_read(st.cuda, ref) do 95 | {:ok, Memory.unpack(data, st.memory)} 96 | end 97 | {:reply, result, st} 98 | end 99 | 100 | defp load_vars(_, %{memory: nil} = st), do: {:ok, st} 101 | defp load_vars(vars, %{memory: memory} = st) do 102 | #IO.inspect({vars, memory}) 103 | bin = Memory.pack(vars, memory) 104 | #IO.inspect({memory, vars, (for <>, do: x)}) 105 | if byte_size(bin) > 0 do 106 | unload = case st.ref do 107 | nil -> :ok 108 | ref -> Cuda.memory_unload(st.cuda, ref) 109 | end 110 | with :ok <- unload, 111 | {:ok, ref} <- Cuda.memory_load(st.cuda, bin) do 112 | {:ok, %{st | ref: ref}} 113 | end 114 | else 115 | {:ok, %{st | ref: nil}} 116 | end 117 | end 118 | end 119 | -------------------------------------------------------------------------------- /test/cuda_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Cuda.CudaTest do 2 | use ExUnit.Case 3 | doctest Cuda 4 | 5 | describe "stream/3" do 6 | # TODO: Sometimes this test fails with the following message: 7 | # 1) test stream/3 process batch calculations (Cuda.CudaTest) 8 | # test/cuda_test.exs:6 9 | # Assertion with == failed 10 | # code: result == <<3::little-32, 4::little-32, 5::little-32, 6::little-32, 7::little-32, 8::little-32, 9::little-32, 10::little-32>> 11 | # left: <<1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 6, 0, 12 | # 0, 0, 7, 0, 0, 0, 8, 0, 0, 0>> 13 | # right: <<3, 0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 6, 0, 0, 0, 7, 0, 0, 0, 8, 0, 14 | # 0, 0, 9, 0, 0, 0, 10, 0, 0, 0>> 15 | # stacktrace: 16 | # test/cuda_test.exs:67: (test) 17 | # 18 | # We should reproduce this error 
and rewrite test 19 | 20 | test "process batch calculations" do 21 | # x[i] = x[i] + 2 22 | ptx1 = """ 23 | .version 4.3 24 | .target sm_20 25 | .address_size 64 26 | .visible .entry ptx1( 27 | .param .u64 .ptr .global input, 28 | .param .u64 .ptr .global output 29 | ) { 30 | .reg .u64 %p<3>; 31 | .reg .u32 %r<2>; 32 | ld.param.u64 %p0, [input]; 33 | ld.param.u64 %p1, [output]; 34 | mov.u32 %r1, %tid.x; 35 | mul.wide.u32 %p2, %r1, 4; 36 | add.u64 %p0, %p0, %p2; 37 | add.u64 %p1, %p1, %p2; 38 | ld.global.u32 %r0, [%p0]; 39 | add.u32 %r0, %r0, 2; 40 | st.global.u32 [%p1], %r0; 41 | } 42 | """ 43 | 44 | # x[i] = x[i] - 1 45 | ptx2 = """ 46 | .version 4.3 47 | .target sm_20 48 | .address_size 64 49 | .visible .entry ptx2( 50 | .param .u64 .ptr .global input, 51 | .param .u64 .ptr .global output 52 | ) { 53 | .reg .u64 %p<3>; 54 | .reg .u32 %r<2>; 55 | ld.param.u64 %p0, [input]; 56 | ld.param.u64 %p1, [output]; 57 | mov.u32 %r1, %tid.x; 58 | mul.wide.u32 %p2, %r1, 4; 59 | add.u64 %p0, %p0, %p2; 60 | add.u64 %p1, %p1, %p2; 61 | ld.global.u32 %r0, [%p0]; 62 | sub.u32 %r0, %r0, 1; 63 | st.global.u32 [%p1], %r0; 64 | } 65 | """ 66 | 67 | data = <<1::little-32, 2::little-32, 3::little-32, 4::little-32, 68 | 5::little-32, 6::little-32, 7::little-32, 8::little-32>> 69 | {:ok, cuda} = Cuda.start_link() 70 | {:ok, input} = Cuda.memory_load(cuda, data) 71 | {:ok, output} = Cuda.memory_load(cuda, <<0::size(8)-unit(32)>>) 72 | {:ok, module} = Cuda.compile(cuda, [ptx1, ptx2]) 73 | 74 | # x[i] = x[i] + 2 - 1 + 2 - 1 75 | batch = [[{:run, {"ptx1", {8, 1, 1}, {1, 1, 1}, [input, output]}}, 76 | {:run, {"ptx2", {8, 1, 1}, {1, 1, 1}, [output, input]}}, 77 | {:run, {"ptx1", {8, 1, 1}, {1, 1, 1}, [input, output]}}, 78 | {:run, {"ptx2", {8, 1, 1}, {1, 1, 1}, [output, input]}}]] 79 | :ok = Cuda.stream(cuda, module, batch) 80 | {:ok, result} = Cuda.memory_read(cuda, input) 81 | assert result == <<3::little-32, 4::little-32, 5::little-32, 6::little-32, 82 | 7::little-32, 8::little-32, 9::little-32, 10::little-32>> 83 | end 84 | end 85 | end 86 | -------------------------------------------------------------------------------- /lib/cuda/visualize/dot.ex: -------------------------------------------------------------------------------- 1 | defmodule Cuda.Graph.Visualize.Dot do 2 | alias Cuda.Graph.Node 3 | alias Cuda.Graph.NodeProto 4 | alias Cuda.Graph.Processing 5 | import Node, only: [input_pin_types: 0, output_pin_types: 0] 6 | 7 | def render(graph, opts \\ []) do 8 | gv = render_node(graph) 9 | n = UUID.uuid1() 10 | file = Path.join(System.tmp_dir!, "#{n}.gv") 11 | File.write(file, gv) 12 | #IO.puts(gv) 13 | out = Keyword.get(opts, :output, "#{n}.svg") 14 | System.cmd("dot", ["-Tsvg", file, "-o", out]) 15 | File.rm_rf!(file) 16 | end 17 | 18 | defp render_pin(pin, node_id) do 19 | pin_id = node_id(pin.id, node_id) 20 | label = if is_nil(pin.group), do: "#{pin.id}", else: "#{pin.id} (#{pin.group})" 21 | ~s(#{pin_id}[label="#{label}"]) 22 | end 23 | 24 | defp render_node(node, parent_id \\ nil) do 25 | id = node_id(node.id, parent_id) 26 | 27 | color = case node.type do 28 | :gpu -> "blue" 29 | :computation_graph -> "green" 30 | _ -> "lightgrey" 31 | end 32 | 33 | g = if parent_id == nil, do: "digraph", else: "subgraph" 34 | 35 | i = node 36 | |> NodeProto.pins(input_pin_types()) 37 | |> Enum.map(& render_pin(&1, id)) 38 | |> Enum.join("; ") 39 | i = "subgraph #{id}_inputs_cluster {rankdir=TB;#{i}}" 40 | 41 | o = node 42 | |> NodeProto.pins(output_pin_types()) 43 | |> Enum.map(& render_pin(&1, id)) 44 | |> Enum.join("; 
") 45 | o = case o do 46 | "" -> "" 47 | o -> o <> ";" 48 | end 49 | 50 | links = if parent_id == nil, do: render_links(node), else: [] 51 | links = links |> Enum.join("; ") 52 | 53 | children = case Map.get(node, :nodes) do 54 | nil -> "" 55 | nodes -> nodes |> Enum.map(& render_node(&1, id)) |> Enum.join("\n") 56 | end 57 | 58 | layout = if parent_id == nil do 59 | "rankdir=LR;rank=source;" 60 | else 61 | "rankdir=LR;rank=source;" 62 | end 63 | 64 | """ 65 | #{g} "cluster_#{id}" { 66 | #{layout} 67 | label="#{Node.string_id(node.id)}"; 68 | color=#{color}; 69 | shape=box; 70 | #{i}; 71 | #{children} 72 | #{links} 73 | #{o} 74 | } 75 | """ 76 | end 77 | 78 | defp render_links(%{id: gid} = node, parent_id \\ nil) do 79 | id = node_id(node.id, parent_id) 80 | result = Processing.dfs(node, fn 81 | :enter, {%{nodes: _} = src, _}, st -> 82 | if node_id(src.id, parent_id) in st.nodes do 83 | {:ok, st} 84 | else 85 | links = render_links(src, id) 86 | {:ok, %{st | nodes: [src.id | st.nodes], links: st.links ++ links}} 87 | end 88 | :enter, {src, _}, st -> 89 | if node_id(src.id, parent_id) in st.nodes do 90 | {:ok, st} 91 | else 92 | {:ok, %{st | nodes: [src.id | st.nodes]}} 93 | end 94 | :move, {{src, src_pin}, {dst, dst_pin}}, st -> 95 | sp = if src.id == gid, do: parent_id, else: node_id(gid, parent_id) 96 | dp = if dst.id == gid, do: parent_id, else: node_id(gid, parent_id) 97 | l = "#{node_id({src.id, src_pin.id}, sp)} -> #{node_id({dst.id, dst_pin.id}, dp)}" 98 | {:ok, %{st | links: [l | st.links]}} 99 | _, _, st -> 100 | {:ok, st} 101 | end, %{nodes: [id], links: [], path: nil}) 102 | with {:ok, st} <- result, do: st.links |> Enum.uniq 103 | end 104 | 105 | defp node_id(id, nil), do: Node.string_id(id) |> String.replace("-", "") 106 | defp node_id(id, parent_id), do: Node.string_id({parent_id, id}) |> String.replace("-", "") 107 | end 108 | -------------------------------------------------------------------------------- /lib/cuda/compiler/utils.ex: -------------------------------------------------------------------------------- 1 | defmodule Cuda.Compiler.Utils do 2 | alias Cuda.Graph.{Node, NodeProto, Pin} 3 | alias Cuda.Memory 4 | import Node, only: [input_pin_types: 0, output_pin_types: 0] 5 | 6 | # TODO: Fix negative skips. 
7 | # reproduced by `mix test test/network_test.exs:85`, 8 | # negative skip in input of {:back_propagation, :fc, :fc_node} 9 | def put_pins_shapes(%{assigns: %{pin_offsets: offsets}} = node) do 10 | #IO.inspect({node.id, offsets}) 11 | size = case Map.get(node.assigns, :pin_size) do 12 | nil -> node.pins |> Enum.map(&Pin.data_size/1) |> Enum.reduce(0, &+/2) 13 | size -> size 14 | end 15 | o = offsets |> Map.values() 16 | pins = if o == Enum.uniq(o) do 17 | node.pins |> pins_shape(offsets, size) 18 | else 19 | inputs = node 20 | |> NodeProto.pins(input_pin_types()) 21 | |> pins_shape(offsets, size) 22 | outputs = node 23 | |> NodeProto.pins(output_pin_types()) 24 | |> pins_shape(offsets, size) 25 | %Memory{vars: inputs.vars ++ outputs.vars} 26 | end# |> Memory.inspect_structure(label: node.id) 27 | memory = node.assigns 28 | |> Map.get(:memory, %{}) 29 | |> Map.put(:pins, pins) 30 | NodeProto.assign(node, memory: memory) 31 | end 32 | def put_pins_shapes(%{pins: _} = node) do 33 | pins = node.pins |> pins_shape()# |> Memory.inspect_structure(label: node.id) 34 | memory = node.assigns 35 | |> Map.get(:memory, %{}) 36 | |> Map.put(:pins, pins) 37 | NodeProto.assign(node, :memory, memory) 38 | end 39 | def put_pins_shapes(node), do: node 40 | 41 | def pins_shape(pins) do 42 | {vars, _} = pins |> Enum.map_reduce(0, fn pin, offset -> 43 | size = Pin.data_size(pin) 44 | {{pin.id, {offset, pin.data_type}}, offset + size} 45 | end) 46 | %Memory{vars: vars} 47 | end 48 | def pins_shape(pins, offsets, _pin_size) do 49 | pin_ids = pins |> Enum.map(& &1.id) 50 | offsets = offsets |> Enum.sort_by(fn {_, {o, _}} -> o end) 51 | pin_size = offsets 52 | |> Enum.map(& elem(elem(&1, 1), 1)) 53 | |> Enum.reduce(0, &+/2) 54 | [{_, {first_offset, _}} | _] = offsets 55 | pin_size = pin_size + first_offset 56 | {k, o} = offsets 57 | |> Enum.filter(fn {k, _} -> k in pin_ids end) 58 | |> Enum.unzip() 59 | s = o |> Enum.map(& elem(&1, 0)) 60 | s = Enum.chunk(s ++ [pin_size], 2, 1) |> Enum.map(fn [a, b] -> b - a end) 61 | vars = [k, o, s] 62 | |> Enum.zip() 63 | |> Enum.with_index() 64 | |> Enum.map(fn {{pin_id, {offset, data_size}, size}, idx} -> 65 | pin = Enum.find(pins, & &1.id == pin_id) 66 | skip_before = if idx == 0 and offset > 0, do: offset, else: 0 67 | skip_after = case size - data_size do 68 | skip when skip < 0 -> pin_size + skip 69 | skip -> skip 70 | end 71 | shape = case {skip_before, skip_after} do 72 | {0, 0} -> pin.data_type 73 | {0, a} -> %Memory.Shape{skip: a, type: pin.data_type} 74 | {b, a} -> %Memory.Shape{skip: {b, a}, type: pin.data_type} 75 | end 76 | {pin_id, {offset, shape}} 77 | end) 78 | %Memory{vars: vars} 79 | end 80 | 81 | def wrap_pins(pins) when is_map(pins) do 82 | pins |> Enum.map(fn {k, v} -> {k, [v]} end) |> Enum.into(%{}) 83 | end 84 | def wrap_pins(pins), do: pins 85 | 86 | def unwrap_pins(pins) when is_map(pins) do 87 | pins |> Enum.map(fn {k, [v | _]} -> {k, v} end) |> Enum.into(%{}) 88 | end 89 | def unwrap_pins(pins), do: pins 90 | end 91 | -------------------------------------------------------------------------------- /lib/cuda/compiler/context.ex: -------------------------------------------------------------------------------- 1 | defmodule Cuda.Compiler.Context do 2 | @moduledoc """ 3 | Compilation context 4 | """ 5 | 6 | alias Cuda.Graph.GraphProto 7 | 8 | @type t :: %__MODULE__{ 9 | env: Cuda.Env.t, 10 | assigns: map, 11 | root: Cuda.Graph.t | Cuda.Node.t, 12 | path: [Cuda.Graph.id] 13 | } 14 | 15 | defstruct [:env, :root, assigns: %{}, path: []] 16 | 17 | def 
new(opts) do 18 | struct(__MODULE__, Enum.into(opts, [])) 19 | end 20 | 21 | def assign(%__MODULE__{assigns: assigns} = ctx, values) do 22 | %{ctx | assigns: Map.merge(assigns, Enum.into(values, %{}))} 23 | end 24 | def assign(%__MODULE__{assigns: assigns} = ctx, key, value) do 25 | %{ctx | assigns: Map.put(assigns, key, value)} 26 | end 27 | 28 | defp expanded_path(path) do 29 | path 30 | |> Enum.flat_map(fn 31 | id when is_tuple(id) -> id |> Tuple.to_list |> Enum.reverse 32 | id -> [id] 33 | end) 34 | end 35 | 36 | defp find_in_expanded(_, _, []), do: nil 37 | defp find_in_expanded(assigns, path, [_ | rest] = ctx_path) do 38 | expanded = ctx_path |> Enum.reverse |> Enum.intersperse(:expanded) 39 | expanded = [:expanded | expanded] ++ path 40 | with nil <- get_in(assigns, expanded) do 41 | find_in_expanded(assigns, path, rest) 42 | end 43 | end 44 | 45 | def find_assign(ctx, path, ctx_path \\ nil, callback \\ fn _ -> true end) 46 | def find_assign(ctx, path, nil, callback) do 47 | find_assign(ctx, path, ctx.path, callback) 48 | end 49 | def find_assign(%{root: nil}, _, _, _), do: nil 50 | def find_assign(ctx, path, [], callback) do 51 | value = with nil <- get_in(ctx.root.assigns, [:expanded | path]) do 52 | get_in(ctx.root.assigns, path) 53 | end 54 | if not is_nil(value) and callback.(value), do: value, else: nil 55 | end 56 | def find_assign(ctx, path, [_ | rest] = ctx_path, callback) do 57 | with nil <- find_in_expanded(ctx.root.assigns, path, expanded_path(ctx_path)) do 58 | with %{} = child <- node(ctx, ctx_path) do 59 | with nil <- find_in_expanded(child.assigns, path, expanded_path(rest)), 60 | nil <- get_in(child.assigns, path) do 61 | find_assign(ctx, path, rest, callback) 62 | else 63 | value -> 64 | if callback.(value) do 65 | value 66 | else 67 | find_assign(ctx, path, rest, callback) 68 | end 69 | end 70 | else 71 | _ -> find_assign(ctx, path, rest, callback) 72 | end 73 | end 74 | end 75 | def find_assign(_, _, _, _), do: nil 76 | 77 | def for_node(%__MODULE__{root: nil} = ctx, child) do 78 | %{ctx | root: child, path: []} 79 | end 80 | def for_node(%__MODULE__{path: path} = ctx, %{id: id}) do 81 | {_, with_id} = Enum.split_while(path, & &1 != id) 82 | path = if with_id == [], do: [id | path], else: with_id 83 | %{ctx | path: path} 84 | end 85 | 86 | def node(ctx, node \\ nil) 87 | def node(%__MODULE__{root: root, path: []}, nil), do: root 88 | def node(%__MODULE__{path: path} = ctx, nil), do: node(ctx, path) 89 | def node(%__MODULE__{root: nil}, _), do: nil 90 | def node(%__MODULE__{root: root}, []) do 91 | root 92 | end 93 | def node(%__MODULE__{root: root}, path) do 94 | GraphProto.node(root, Enum.reverse(path)) 95 | end 96 | def node(_, _) do 97 | nil 98 | end 99 | 100 | def replace_current(%__MODULE__{path: []} = ctx, node) do 101 | %{ctx | root: node} 102 | end 103 | def replace_current(%__MODULE__{root: root, path: [_ | rest] = id} = ctx, node) do 104 | %{ctx | path: [node.id | rest], root: GraphProto.replace(root, id |> Enum.reverse, node)} 105 | end 106 | end 107 | -------------------------------------------------------------------------------- /c_src/erlang_port.cpp: -------------------------------------------------------------------------------- 1 | #include "erlang_port.h" 2 | 3 | using __gnu_cxx::stdio_filebuf; 4 | 5 | // Swaps big-endian to little-endian or opposite 6 | template <typename T> void EndianSwap(T *buffer) { 7 | unsigned char *mem = reinterpret_cast<unsigned char *>(buffer); 8 | std::reverse(mem, mem + sizeof(T)); 9 | } 10 | 11 | ErlangPort::ErlangPort() : 12 | input(new 
stdio_filebuf<char>(PORTIN_FILENO, std::ios::in)), 13 | output(new stdio_filebuf<char>(PORTOUT_FILENO, std::ios::out)) { 14 | input.exceptions(std::ifstream::failbit | std::ifstream::badbit | 15 | std::ifstream::eofbit); 16 | output.exceptions(std::ofstream::failbit | std::ofstream::badbit | 17 | std::ofstream::eofbit); 18 | erl_init(NULL, 0); 19 | DEBUG("Port initialized"); 20 | } 21 | 22 | ErlangPort::~ErlangPort() { 23 | DEBUG("Port destroyed"); 24 | if (tuple) erl_free_compound(tuple); 25 | if (funcAtom) erl_free_term(funcAtom); 26 | if (arg) erl_free_term(arg); 27 | if (result) erl_free_term(result); 28 | } 29 | 30 | uint32_t ErlangPort::ReadPacketLength() { 31 | uint32_t len; 32 | input.read(reinterpret_cast<char *>(&len), sizeof(len)); 33 | EndianSwap(&len); 34 | return len; 35 | } 36 | 37 | uint8_t ErlangPort::ReadPacketType() { 38 | uint8_t type; 39 | input.read(reinterpret_cast<char *>(&type), sizeof(type)); 40 | return type; 41 | } 42 | 43 | uint8_t ErlangPort::ReadRawFunc() { 44 | uint8_t func; 45 | input.read(reinterpret_cast<char *>(&func), sizeof(func)); 46 | return func; 47 | } 48 | 49 | ETERM *ErlangPort::ReadTermPacket(uint32_t len) { 50 | // Read packet data, len bytes 51 | std::string buf(len, 0); 52 | input.read((char *)buf.c_str(), len); 53 | // Decode packet 54 | return erl_decode((unsigned char *)buf.c_str()); 55 | } 56 | 57 | void ErlangPort::WritePacketLength(uint32_t len) { 58 | EndianSwap(&len); 59 | output.write(reinterpret_cast<char *>(&len), sizeof(len)); 60 | } 61 | 62 | #define TERM_PACKET 1 63 | #define RAW_PACKET 2 64 | 65 | void ErlangPort::WriteTermPacket(ETERM *packet) { 66 | auto len = erl_term_len(packet); 67 | uint8_t type = TERM_PACKET; 68 | std::string buf(len, 0); 69 | erl_encode(packet, (unsigned char *)buf.c_str()); 70 | WritePacketLength(len + 1); 71 | output.write((const char *)&type, 1); 72 | output.write(buf.c_str(), len); 73 | output.flush(); 74 | } 75 | 76 | void ErlangPort::WriteRawPacket(void *data, size_t size) { 77 | uint8_t type = RAW_PACKET; 78 | WritePacketLength(size + 1); 79 | output.write((const char *)&type, 1); 80 | output.write((const char *)data, size); 81 | output.flush(); 82 | } 83 | 84 | void ErlangPort::Loop() { 85 | while(true) { 86 | // Read packet length, 4 bytes 87 | auto len = ReadPacketLength(); 88 | auto type = ReadPacketType(); 89 | // ErlangHandler handler = NULL; 90 | result = NULL; 91 | if (type == TERM_PACKET) { 92 | tuple = ReadTermPacket(len - 1); 93 | if (!ERL_IS_TUPLE(tuple) || ERL_TUPLE_SIZE(tuple) != 2) continue; 94 | // Retrieve function name and argument 95 | funcAtom = erl_element(1, tuple); 96 | arg = erl_element(2, tuple); 97 | // If first element of tuple is not an atom - skip it 98 | if (!ERL_IS_ATOM(funcAtom)) continue; 99 | std::string termFunc(ERL_ATOM_PTR(funcAtom)); 100 | if (termFunc == "exit") break; 101 | // handle request 102 | try { 103 | result = HandleTermFunction(termFunc, arg); 104 | } catch (Error &error) { 105 | result = error.AsTerm(); 106 | } 107 | } else if (type == RAW_PACKET) { 108 | // read size of function name 109 | uint8_t funcSize = 0; 110 | input.read((char *)&funcSize, 1); 111 | if (funcSize == 0) continue; 112 | // read function name 113 | std::string rawFunc(funcSize, 0); 114 | input.read((char *)rawFunc.c_str(), funcSize); 115 | if (rawFunc == "exit") break; 116 | // read raw data 117 | len = len - 2 - funcSize; 118 | std::shared_ptr<char> data(new char[len]); 119 | input.read((char *)data.get(), len); 120 | // handle request 121 | try { 122 | result = HandleRawFunction(rawFunc, data, len); 123 | } catch (Error 
&error) { 124 | result = error.AsTerm(); 125 | } 126 | } 127 | 128 | if (result) { 129 | WriteTermPacket(result); 130 | } 131 | }; 132 | } 133 | -------------------------------------------------------------------------------- /lib/cuda/runner/graph.ex: -------------------------------------------------------------------------------- 1 | defimpl Cuda.Runner, for: Cuda.Graph do 2 | alias Cuda.Graph.NodeProto 3 | alias Cuda.Memory 4 | 5 | import Cuda.Graph.Node, only: [input_pin_types: 0] 6 | 7 | def load(%{type: :computation_graph, assigns: assigns} = graph, opts) do 8 | with cuda when not is_nil(cuda) <- Keyword.get(opts, :cuda) do 9 | # load cubin into GPU 10 | {:ok, module} = Cuda.module_load(cuda, assigns.cubin) 11 | # load args into GPU 12 | args = opts 13 | |> Keyword.get(:args, %{}) 14 | |> Enum.reduce(%{}, fn 15 | {k, {m, _} = loaded}, args when m in ~w(memory shared_memory)a -> 16 | Map.put(args, k, loaded) 17 | {k, {type, value}}, args -> 18 | bin = Memory.pack(type, value) 19 | with {:ok, marg} <- Cuda.memory_load(cuda, bin) do 20 | Map.put(args, k, marg) 21 | else 22 | _ -> 23 | # TODO: warning here 24 | args 25 | end 26 | _, args -> 27 | args 28 | end) 29 | {:ok, NodeProto.assign(graph, cuda_module: module, cuda_args: args)} 30 | end 31 | end 32 | def load(%{nodes: nodes} = graph, opts) do 33 | nodes = nodes |> Enum.reduce([], fn node, nodes -> 34 | with {:ok, loaded} <- Cuda.Runner.load(node, opts) do 35 | [loaded] ++ nodes 36 | else 37 | _ -> [node] ++ nodes 38 | end 39 | end) 40 | {:ok, %{graph | nodes: nodes}} 41 | end 42 | 43 | def run(%{type: :computation_graph, assigns: assigns}, inputs, opts) do 44 | with cuda when not is_nil(cuda) <- Keyword.get(opts, :cuda) do 45 | # get input and convert it to binary 46 | pins = Memory.pack(inputs, assigns.memory.pins)# |> IO.inspect) 47 | # load pins into GPU 48 | {:ok, mpins} = Cuda.memory_load(cuda, pins) 49 | # prepare arguments and batch list 50 | args = Map.merge(Map.get(assigns, :cuda_args, %{}), 51 | Keyword.get(opts, :args, %{})) 52 | batches = assigns.batches |> Enum.map(fn batch -> 53 | Enum.map(batch, fn 54 | {:run, {name, k, b, params}} -> 55 | params = Enum.map(params, & Map.get(args, &1)) 56 | {:run, {name, k, b, [mpins | params]}} 57 | {:run, {name, k, b}} -> 58 | {:run, {name, k, b, [mpins]}} 59 | x -> 60 | x 61 | end) 62 | end) 63 | # run computation on GPU 64 | #IO.inspect({assigns.cuda_module, batches}) 65 | :ok = Cuda.stream(cuda, assigns.cuda_module, batches) 66 | {:ok, pins} = Cuda.memory_read(cuda, mpins) 67 | #IO.inspect(for <>, do: x) 68 | #IO.inspect(byte_size(pins)) 69 | output = pins |> Memory.unpack(assigns.memory.pins) 70 | {:ok, output} 71 | else 72 | _ -> {:error, :no_cuda_specified} 73 | end 74 | end 75 | def run(graph, inputs, opts) do 76 | pins = graph.links |> Enum.reduce(%{}, fn 77 | {{:__self__, input}, {dst, pin}}, pins -> 78 | node_pins = Map.get(pins, dst, %{}) 79 | node_pins = Map.put(node_pins, pin, Map.get(inputs, input)) 80 | Map.put(pins, dst, node_pins) 81 | _, pins -> 82 | pins 83 | end) 84 | pins = graph.nodes |> Enum.reduce({:ok, pins}, fn 85 | %{id: id} = node, {:ok, pins} -> 86 | inputs = node 87 | |> NodeProto.pins(input_pin_types()) 88 | |> Enum.map(& &1.id) 89 | |> Enum.into(MapSet.new) 90 | data = Map.get(pins, node.id, %{}) 91 | available = data |> Map.keys() |> Enum.into(MapSet.new) 92 | #IO.inspect(MapSet.difference(inputs, available) |> MapSet.to_list(), label: CHECK) 93 | with [] <- MapSet.difference(inputs, available) |> MapSet.to_list() do 94 | inputs = data |> 
Map.take(MapSet.to_list(inputs)) 95 | with {:ok, outputs} <- Cuda.Runner.run(node, inputs, opts) do 96 | #data = Map.merge(data, outputs) 97 | pins = graph.links |> Enum.reduce(pins, fn 98 | {{^id, output}, {dst, pin}}, pins -> 99 | node_pins = Map.get(pins, dst, %{}) 100 | node_pins = Map.put(node_pins, pin, Map.get(outputs, output)) 101 | Map.put(pins, dst, node_pins) 102 | _, pins -> 103 | pins 104 | end) 105 | {:ok, pins} 106 | end 107 | else 108 | # if not all inputs are ready - skip node 109 | _ -> {:error, "Not all inputs available. Possible graph loop"} 110 | end 111 | _, error -> 112 | error 113 | end) 114 | with {:ok, pins} <- pins do 115 | {:ok, Map.get(pins, :__self__)} 116 | end 117 | end 118 | end 119 | -------------------------------------------------------------------------------- /c_src/driver.h: -------------------------------------------------------------------------------- 1 | #ifndef __DRIVER_H__ 2 | #define __DRIVER_H__ 3 | 4 | #include 5 | #include 6 | 7 | #include "common.h" 8 | 9 | #define LINKER_BUFFER_SIZE 8192 10 | 11 | struct LinkerOptions { 12 | int maxRegisters; 13 | int threadsPerBlock; 14 | int optimizationLevel; 15 | int target; 16 | int debug; 17 | int verbose; 18 | int infoSize; 19 | int errorSize; 20 | }; 21 | 22 | class Linker { 23 | private: 24 | size_t cubinSize = 0; 25 | float walltime = 0.0; 26 | unsigned int threadsPerBlock = 0; 27 | std::vector<CUjit_option> optKeys; 28 | std::vector<void *> optValues; 29 | CUlinkState state; 30 | char *infoLog; 31 | char *errorLog; 32 | bool initialized = false; 33 | public: 34 | void *cubin = NULL; 35 | Linker(LinkerOptions &options); 36 | ~Linker(); 37 | void Run(std::list<std::string> sources); 38 | size_t OptionsSize(); 39 | CUjit_option *OptionsKeys(); 40 | void **OptionsValues(); 41 | }; 42 | 43 | typedef std::tuple<CUipcMemHandle, size_t> SharedMemory; 44 | 45 | class DeviceMemory { 46 | private: 47 | CUdeviceptr ptr = (CUdeviceptr)NULL; 48 | bool initialized = false; 49 | bool shared = false; 50 | size_t size; 51 | public: 52 | DeviceMemory(const void *src, size_t srcSize): size(srcSize) { 53 | auto result = cuMemAlloc(&ptr, size); 54 | if (result != CUDA_SUCCESS) throw DriverError(result, "DeviceMemory:allocate"); 55 | result = cuMemcpyHtoD(ptr, src, size); 56 | if (result != CUDA_SUCCESS) throw DriverError(result, "DeviceMemory:copy"); 57 | initialized = true; 58 | DEBUG("Device memory initialized with size " << srcSize); 59 | } 60 | DeviceMemory(CUipcMemHandle handle, size_t memSize): size(memSize) { 61 | auto result = cuIpcOpenMemHandle(&ptr, handle, CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS); 62 | if (result != CUDA_SUCCESS) throw DriverError(result, "DeviceMemory:ipc"); 63 | shared = true; 64 | initialized = true; 65 | DEBUG("Device memory initialized from shared with size " << memSize); 66 | } 67 | 68 | ~DeviceMemory() { 69 | DEBUG("Device memory destroyed"); 70 | if (initialized) { 71 | shared ? 
cuIpcCloseMemHandle(ptr) : cuMemFree(ptr); 72 | } 73 | } 74 | 75 | void Read(void *dst, int dstSize = -1) { 76 | if (dstSize < 0) dstSize = size; 77 | auto r = cudaMemcpy(dst, (void *)ptr, dstSize, cudaMemcpyDeviceToHost); 78 | if (r != cudaSuccess) throw RuntimeError(r, "DeviceMemory:read"); 79 | } 80 | 81 | size_t GetSize() { 82 | return size; 83 | } 84 | 85 | CUdeviceptr GetPtr() { 86 | return ptr; 87 | } 88 | 89 | CUdeviceptr *GetPtrPtr() { 90 | return &ptr; 91 | } 92 | }; 93 | 94 | class RunArguments { 95 | private: 96 | std::vector<void *> values; 97 | public: 98 | ~RunArguments() { 99 | DEBUG("Run arguments destroyed"); 100 | for (auto it = values.begin(); it != values.end(); ++it) std::free(*it); 101 | values.clear(); 102 | } 103 | 104 | template <typename T> void Add(T param) { 105 | auto ptr = (T *)malloc(sizeof(T)); 106 | *ptr = param; 107 | values.push_back(ptr); 108 | } 109 | 110 | void Add(DeviceMemory &memory) { 111 | auto ptr = (CUdeviceptr *)malloc(sizeof(CUdeviceptr)); 112 | *ptr = memory.GetPtr(); 113 | values.push_back(ptr); 114 | } 115 | 116 | void **GetPtr() { 117 | if (values.empty()) return NULL; 118 | return values.data(); 119 | } 120 | }; 121 | 122 | class CompileError : public Error { 123 | public: 124 | CUresult code; 125 | std::string infoLog; 126 | std::string errorLog; 127 | CompileError(CUresult errorNo, char *info, char *error) : 128 | Error(), 129 | code(errorNo), 130 | infoLog(info, strlen(info)), 131 | errorLog(error, strlen(error)) {} 132 | virtual ETERM *AsTerm(); 133 | }; 134 | 135 | typedef std::tuple RunParameters; 136 | typedef std::tuple<RunParameters, std::shared_ptr<RunArguments>> RunEnvironment; 137 | 138 | class Driver { 139 | private: 140 | CUdevice device; 141 | CUcontext context; 142 | std::map<int, CUmodule> modules; 143 | std::map<int, DeviceMemory *> memory; 144 | public: 145 | Driver(int deviceNo); 146 | ~Driver(); 147 | CUdevice GetHandle() { return device; } 148 | int Compile(std::list<std::string> sources, LinkerOptions &options); 149 | int LoadModule(std::string cubin, LinkerOptions &options); 150 | CUmodule GetModule(int id); 151 | int LoadMemory(const void *src, size_t size); 152 | int LoadMemory(SharedMemory mem); 153 | void UnloadMemory(int id); 154 | void ReadMemory(int id, void *dst, int size = -1); 155 | int GetMemorySize(int id); 156 | DeviceMemory *GetMemory(int id); 157 | SharedMemory ShareMemory(int id); 158 | void Run(int moduleNo, RunParameters &params, std::shared_ptr<RunArguments> &args); 159 | void Stream(int moduleNo, std::vector<RunEnvironment> &batch); 160 | template <typename T> T Unpack(ETERM *value); 161 | ETERM *PackMemory(int idx); 162 | ETERM *PackMemory(SharedMemory mem); 163 | ETERM *PackModule(int idx); 164 | }; 165 | 166 | #endif // __DRIVER_H__ 167 | -------------------------------------------------------------------------------- /lib/cuda/graph/protocols.ex: -------------------------------------------------------------------------------- 1 | alias Cuda.Graph 2 | alias Cuda.Graph.Node 3 | alias Cuda.Graph.Pin 4 | 5 | defprotocol Cuda.Graph.NodeProto do 6 | @doc """ 7 | Returns pin by its id 8 | """ 9 | @spec pin(node :: Node.t, id :: Graph.id) :: Pin.t | nil 10 | def pin(node, id) 11 | 12 | @doc """ 13 | Returns a list of pins of specified type 14 | """ 15 | @spec pins(node :: Node.t, type :: Pin.type | [Pin.type]) :: [Pin.t] 16 | def pins(node, type \\ nil) 17 | 18 | @spec assign(node :: struct, key :: atom, value :: any) :: struct 19 | def assign(node, key, value) 20 | 21 | @spec assign(node :: struct, key :: map | keyword) :: struct 22 | def assign(node, assigns) 23 | end 24 | 25 | defprotocol Cuda.Graph.GraphProto do 26 | @spec add(graph :: Graph.t, node :: 
Node.t) :: Graph.t 27 | def add(graph, node) 28 | 29 | @doc """ 30 | Replaces node in the graph. 31 | 32 | If the replacement node has the same id as the node being replaced, you can 33 | call this function with two arguments - the graph and the replacement node. 34 | If the id differs, pass the id of the node to replace as the second argument 35 | and the replacement node as the third. 36 | """ 37 | @spec replace(graph :: Graph.t, node :: Node.t) :: Graph.t 38 | @spec replace(graph :: Graph.t, id :: Graph.id | [Graph.id], node :: Node.t) :: Graph.t 39 | def replace(graph, node) 40 | def replace(graph, id, node) 41 | 42 | @doc """ 43 | Returns a node in the graph by its name or path (a list of names) 44 | """ 45 | @spec node(graph :: Graph.t, id :: Graph.id | [Graph.id]) :: Node.t 46 | def node(graph, id) 47 | 48 | @doc """ 49 | Returns the pin of a link specification. It can be a pin of the graph itself 50 | or a pin of a child node 51 | """ 52 | @spec link_spec_pin(graph :: Graph.t, link_spec :: Graph.link_spec) :: Pin.t 53 | def link_spec_pin(graph, link_spec) 54 | 55 | @doc """ 56 | Returns the node of a link specification. It can be the graph itself or a child node 57 | """ 58 | @spec link_spec_node(graph :: Graph.t, link_spec :: Graph.link_spec) :: Node.t | Graph.t 59 | def link_spec_node(graph, link_spec) 60 | end 61 | 62 | defprotocol Cuda.Graph.Factory do 63 | @doc """ 64 | Creates a new evaluation node 65 | """ 66 | @spec new(node :: struct, id :: Graph.id, module :: atom, opts :: keyword, env :: Cuda.Env.t) :: struct 67 | def new(node, id, module, opts \\ [], env \\ []) 68 | end 69 | 70 | defimpl Cuda.Graph.NodeProto, for: Any do 71 | def pin(%{pins: pins}, id) do 72 | pins |> Enum.find(fn 73 | %Pin{id: ^id} -> true 74 | _ -> false 75 | end) 76 | end 77 | def pin(_, _), do: nil 78 | 79 | def pins(%{pins: pins}, nil), do: pins 80 | def pins(node, types) when is_list(types) do 81 | Enum.reduce(types, [], &(&2 ++ pins(node, &1))) 82 | end 83 | def pins(%{pins: pins}, type) do 84 | pins |> Enum.filter(fn 85 | %Pin{type: ^type} -> true 86 | _ -> false 87 | end) 88 | end 89 | 90 | def assign(%{assigns: assigns} = node, key, value) do 91 | %{node | assigns: Map.put(assigns, key, value)} 92 | end 93 | 94 | def assign(%{assigns: assigns} = node, data) do 95 | data = data |> Enum.into(%{}) 96 | %{node | assigns: Map.merge(assigns, data)} 97 | end 98 | end 99 | 100 | defimpl Cuda.Graph.GraphProto, for: Any do 101 | require Cuda 102 | import Cuda, only: [compile_error: 1] 103 | 104 | def add(%{nodes: nodes} = graph, %{id: id} = node) do 105 | with nil <- node(graph, id) do 106 | %{graph | nodes: [node | nodes]} 107 | else 108 | _ -> compile_error("Node with id `#{id}` is already in the graph") 109 | end 110 | end 111 | 112 | def replace(%{nodes: nodes} = graph, %{id: id} = node) do 113 | nodes = nodes |> Enum.map(fn 114 | %{id: ^id} -> node 115 | x -> x 116 | end) 117 | %{graph | nodes: nodes} 118 | end 119 | def replace(%{nodes: _} = graph, [], node), do: replace(graph, node) 120 | def replace(%{id: src}, [], %{id: dst} = node) when src == dst, do: node 121 | def replace(graph, [id | path], node) do 122 | with %{} = child <- Cuda.Graph.GraphProto.node(graph, id) do 123 | replace(graph, replace(child, path, node)) 124 | end 125 | end 126 | def replace(%{nodes: nodes} = graph, id, node) do 127 | nodes = nodes |> Enum.map(fn 128 | %{id: ^id} -> node 129 | x -> x 130 | end) 131 | %{graph | nodes: nodes} 132 | end 133 | 134 | def node(_, []), do: nil 135 | def node(%{nodes: _} = 
graph, [id]), do: node(graph, id) 136 | def node(%{nodes: _} = graph, [id | path]) do 137 | with %{} = child <- Cuda.Graph.GraphProto.node(graph, id) do 138 | Cuda.Graph.GraphProto.node(child, path) 139 | end 140 | end 141 | def node(%{nodes: nodes}, id) do 142 | nodes |> Enum.find(fn 143 | %{id: ^id} -> true 144 | _ -> false 145 | end) 146 | end 147 | def node(_, _), do: nil 148 | 149 | def link_spec_pin(graph, {:__self__, pin}) do 150 | Cuda.Graph.NodeProto.pin(graph, pin) 151 | end 152 | def link_spec_pin(graph, {node, pin}) do 153 | with %{} = node <- node(graph, node) do 154 | Cuda.Graph.NodeProto.pin(node, pin) 155 | end 156 | end 157 | 158 | def link_spec_node(graph, {:__self__, _}) do 159 | graph 160 | end 161 | def link_spec_node(graph, {node, _}) do 162 | node(graph, node) 163 | end 164 | end 165 | -------------------------------------------------------------------------------- /.credo.exs: -------------------------------------------------------------------------------- 1 | # This file contains the configuration for Credo and you are probably reading 2 | # this after creating it with `mix credo.gen.config`. 3 | # 4 | # If you find anything wrong or unclear in this file, please report an 5 | # issue on GitHub: https://github.com/rrrene/credo/issues 6 | # 7 | %{ 8 | # 9 | # You can have as many configs as you like in the `configs:` field. 10 | configs: [ 11 | %{ 12 | # 13 | # Run any config using `mix credo -C `. If no config name is given 14 | # "default" is used. 15 | name: "default", 16 | # 17 | # These are the files included in the analysis: 18 | files: %{ 19 | # 20 | # You can give explicit globs or simply directories. 21 | # In the latter case `**/*.{ex,exs}` will be used. 22 | included: ["lib/", "src/", "web/", "apps/"], 23 | excluded: [~r"/_build/", ~r"/deps/"] 24 | }, 25 | # 26 | # If you create your own checks, you must specify the source files for 27 | # them here, so they can be loaded by Credo before running the analysis. 28 | requires: [], 29 | # 30 | # Credo automatically checks for updates, like e.g. Hex does. 31 | # You can disable this behaviour below: 32 | check_for_updates: true, 33 | # 34 | # If you want to enforce a style guide and need a more traditional linting 35 | # experience, you can change `strict` to `true` below: 36 | strict: true, 37 | # 38 | # If you want to use uncolored output by default, you can change `color` 39 | # to `false` below: 40 | color: true, 41 | # 42 | # You can customize the parameters of any check by adding a second element 43 | # to the tuple. 44 | # 45 | # To disable a check put `false` as second element: 46 | # 47 | # {Credo.Check.Design.DuplicatedCode, false} 48 | # 49 | checks: [ 50 | {Credo.Check.Consistency.ExceptionNames}, 51 | {Credo.Check.Consistency.LineEndings}, 52 | {Credo.Check.Consistency.MultiAliasImportRequireUse}, 53 | {Credo.Check.Consistency.ParameterPatternMatching}, 54 | {Credo.Check.Consistency.SpaceAroundOperators}, 55 | {Credo.Check.Consistency.SpaceInParentheses}, 56 | {Credo.Check.Consistency.TabsOrSpaces}, 57 | 58 | # For some checks, like AliasUsage, you can only customize the priority 59 | # Priority values are: `low, normal, high, higher` 60 | {Credo.Check.Design.AliasUsage, priority: :low}, 61 | 62 | # For others you can set parameters 63 | 64 | # If you don't want the `setup` and `test` macro calls in ExUnit tests 65 | # or the `schema` macro in Ecto schemas to trigger DuplicatedCode, just 66 | # set the `excluded_macros` parameter to `[:schema, :setup, :test]`. 
67 | {Credo.Check.Design.DuplicatedCode, excluded_macros: []}, 68 | 69 | # You can also customize the exit_status of each check. 70 | # If you don't want TODO comments to cause `mix credo` to fail, just 71 | # set this value to 0 (zero). 72 | {Credo.Check.Design.TagTODO, exit_status: 2}, 73 | {Credo.Check.Design.TagFIXME}, 74 | 75 | {Credo.Check.Readability.FunctionNames}, 76 | {Credo.Check.Readability.LargeNumbers}, 77 | {Credo.Check.Readability.MaxLineLength, priority: :low, max_length: 80}, 78 | {Credo.Check.Readability.ModuleAttributeNames}, 79 | {Credo.Check.Readability.ModuleDoc}, 80 | {Credo.Check.Readability.ModuleNames}, 81 | {Credo.Check.Readability.ParenthesesOnZeroArityDefs}, 82 | {Credo.Check.Readability.ParenthesesInCondition}, 83 | {Credo.Check.Readability.PredicateFunctionNames}, 84 | {Credo.Check.Readability.PreferImplicitTry}, 85 | {Credo.Check.Readability.RedundantBlankLines}, 86 | {Credo.Check.Readability.StringSigils}, 87 | {Credo.Check.Readability.TrailingBlankLine}, 88 | {Credo.Check.Readability.TrailingWhiteSpace}, 89 | {Credo.Check.Readability.VariableNames}, 90 | {Credo.Check.Readability.Semicolons}, 91 | {Credo.Check.Readability.SpaceAfterCommas}, 92 | 93 | {Credo.Check.Refactor.DoubleBooleanNegation}, 94 | {Credo.Check.Refactor.CondStatements}, 95 | {Credo.Check.Refactor.CyclomaticComplexity}, 96 | # TODO: activate Credo.Check.Refactor.FunctionArity when replacement 97 | # of @lint attribute will be founded 98 | {Credo.Check.Refactor.FunctionArity, max_arity: 6}, 99 | {Credo.Check.Refactor.MatchInCondition}, 100 | {Credo.Check.Refactor.NegatedConditionsInUnless}, 101 | {Credo.Check.Refactor.NegatedConditionsWithElse}, 102 | {Credo.Check.Refactor.Nesting}, 103 | {Credo.Check.Refactor.PipeChainStart}, 104 | {Credo.Check.Refactor.UnlessWithElse}, 105 | 106 | {Credo.Check.Warning.BoolOperationOnSameValues}, 107 | {Credo.Check.Warning.IExPry}, 108 | {Credo.Check.Warning.IoInspect}, 109 | {Credo.Check.Warning.OperationOnSameValues}, 110 | {Credo.Check.Warning.OperationWithConstantResult}, 111 | {Credo.Check.Warning.UnusedEnumOperation}, 112 | {Credo.Check.Warning.UnusedFileOperation}, 113 | {Credo.Check.Warning.UnusedKeywordOperation}, 114 | {Credo.Check.Warning.UnusedListOperation}, 115 | {Credo.Check.Warning.UnusedPathOperation}, 116 | {Credo.Check.Warning.UnusedRegexOperation}, 117 | {Credo.Check.Warning.UnusedStringOperation}, 118 | {Credo.Check.Warning.UnusedTupleOperation}, 119 | 120 | # Controversial and experimental checks (opt-in, just remove `, false`) 121 | # 122 | {Credo.Check.Refactor.ABCSize, false}, 123 | {Credo.Check.Refactor.AppendSingleItem, false}, 124 | {Credo.Check.Refactor.VariableRebinding, false}, 125 | {Credo.Check.Warning.MapGetUnsafePass, false}, 126 | 127 | # Deprecated checks (these will be deleted after a grace period) 128 | {Credo.Check.Readability.Specs, false}, 129 | {Credo.Check.Warning.NameRedeclarationByAssignment, false}, 130 | {Credo.Check.Warning.NameRedeclarationByCase, false}, 131 | {Credo.Check.Warning.NameRedeclarationByDef, false}, 132 | {Credo.Check.Warning.NameRedeclarationByFn, false}, 133 | 134 | # Custom checks can be created using `mix credo.gen.check`. 135 | # 136 | ] 137 | } 138 | ] 139 | } 140 | -------------------------------------------------------------------------------- /lib/cuda/graph/pin.ex: -------------------------------------------------------------------------------- 1 | defmodule Cuda.Graph.Pin do 2 | @moduledoc """ 3 | Represents evaluation graph node connector (pin). 
4 | """ 5 | 6 | alias Cuda.Graph 7 | require Logger 8 | 9 | @type type :: :input | :output | :producer | :consumer | :terminator 10 | @type group :: nil | atom 11 | @type alias :: nil | {:group, atom} | Graph.id 12 | 13 | @type t :: %__MODULE__{ 14 | id: Graph.id, 15 | type: type, 16 | group: group, 17 | alias: alias, 18 | data_type: any 19 | } 20 | 21 | defstruct [:id, :type, :group, :alias, :data_type] 22 | 23 | def data_size(%__MODULE__{data_type: data_type}) do 24 | type_size(data_type) 25 | end 26 | def data_size(_), do: 0 27 | 28 | def data_arity(%__MODULE__{data_type: {_, arity}}) do 29 | type_size(arity) 30 | end 31 | def data_arity(%__MODULE__{data_type: _}), do: 1 32 | def data_arity(_), do: 0 33 | 34 | def data_type(%__MODULE__{data_type: {t, _}}), do: t 35 | def data_type(%__MODULE__{data_type: t}), do: t 36 | def data_type(_), do: nil 37 | 38 | def pack(:zero, t) when is_atom(t) or is_bitstring(t), do: pack(0, t) 39 | def pack(_, {:skip, bytes}), do: <<0::unit(8)-size(bytes)>> 40 | def pack(x, :i8), do: <> 41 | def pack(x, :i16), do: <> 42 | def pack(x, :i32), do: <> 43 | def pack(x, :i64), do: <> 44 | def pack(x, "i8"), do: pack(x, :i8) 45 | def pack(x, "i16"), do: pack(x, :i16) 46 | def pack(x, "i32"), do: pack(x, :i32) 47 | def pack(x, "i64"), do: pack(x, :i64) 48 | def pack(x, :u8), do: <> 49 | def pack(x, :u16), do: <> 50 | def pack(x, :u32), do: <> 51 | def pack(x, :u64), do: <> 52 | def pack(x, "u8"), do: pack(x, :u8) 53 | def pack(x, "u16"), do: pack(x, :u16) 54 | def pack(x, "u32"), do: pack(x, :u32) 55 | def pack(x, "u64"), do: pack(x, :u64) 56 | # TODO: pack 16-bit floats 57 | # def pack(x, :f16), do: <> 58 | def pack(x, :f32), do: <> 59 | def pack(x, :f64), do: <> 60 | # def pack(x, "f16"), do: pack(x, :f16) 61 | def pack(x, "f32"), do: pack(x, :f32) 62 | def pack(x, "f64"), do: pack(x, :f64) 63 | def pack(x, {type, arity}) when not is_tuple(arity), do: pack(x, {type, {arity}}) 64 | def pack(x, {type, arity}) when is_list(x) do 65 | arity = type_size(arity) 66 | x = List.flatten(x) 67 | if length(x) == arity do 68 | x |> Enum.map(& pack(&1, type)) |> Enum.join 69 | else 70 | raise RuntimeError, message: "Arity of array #{inspect x} should be #{arity}" 71 | end 72 | end 73 | def pack(:zero, {type, arity}) do 74 | size = type_size(arity) * type_size(type) 75 | <<0::unit(8)-size(size)>> 76 | end 77 | def pack(x, types) when is_list(types) and is_list(x) do 78 | x = case length(types) - length(x) do 79 | n when n > 0 -> x ++ List.duplicate(0, n) 80 | _ -> x 81 | end 82 | x 83 | |> Enum.zip(types) 84 | |> Enum.map(fn {x, type} -> pack(x, type) end) 85 | |> Enum.join() 86 | end 87 | def pack(:zero, types) when is_list(types) do 88 | types |> Enum.map(& pack(:zero, &1)) |> Enum.join() 89 | end 90 | def pack(x, types) when is_list(types), do: pack([x], types) 91 | def pack(x, types) when is_map(types) and is_map(x) do 92 | types 93 | |> Enum.map(fn {k, type} -> 94 | with {:ok, v} <- Map.fetch(x, k) do 95 | pack(v, type) 96 | else 97 | _ -> raise RuntimeError, message: "Coudn't find value for key #{k}" 98 | end 99 | end) 100 | |> Enum.join() 101 | end 102 | def pack(:zero, types) when is_map(types) do 103 | types |> Enum.map(fn {_, type} -> pack(:zero, type) end) |> Enum.join() 104 | end 105 | def pack(nil, type) do 106 | Logger.warn("Attempt to pack `nil` value for type #{inspect type}") 107 | end 108 | def pack(_, _), do: <<>> 109 | 110 | def unpack(<>, :i8), do: x 111 | def unpack(<>, :i16), do: x 112 | def unpack(<>, :i32), do: x 113 | def unpack(<>, :i64), do: x 114 | 
def unpack(x, "i8"), do: unpack(x, :i8) 115 | def unpack(x, "i16"), do: unpack(x, :i16) 116 | def unpack(x, "i32"), do: unpack(x, :i32) 117 | def unpack(x, "i64"), do: unpack(x, :i64) 118 | def unpack(<>, :u8), do: x 119 | def unpack(<>, :u16), do: x 120 | def unpack(<>, :u32), do: x 121 | def unpack(<>, :u64), do: x 122 | def unpack(x, "u8"), do: unpack(x, :u8) 123 | def unpack(x, "u16"), do: unpack(x, :u16) 124 | def unpack(x, "u32"), do: unpack(x, :u32) 125 | def unpack(x, "u64"), do: unpack(x, :u64) 126 | # TODO: pack 16-bit floats 127 | # def pack(x, :f16), do: <> 128 | def unpack(<>, :f32), do: x 129 | def unpack(<>, :f64), do: x 130 | # def pack(x, "f16"), do: pack(x, :f16) 131 | def unpack(x, "f32"), do: unpack(x, :f32) 132 | def unpack(x, "f64"), do: unpack(x, :f64) 133 | def unpack(x, {type, arity}) when is_tuple(arity) do 134 | arity = arity |> Tuple.to_list |> Enum.reverse 135 | {list, _} = unpack_list(x, {type, arity}) 136 | list 137 | end 138 | def unpack(x, {type, arity}) when not is_tuple(arity) do 139 | {list, _} = unpack_list(x, {type, [arity]}) 140 | list 141 | end 142 | def unpack(x, types) when is_list(types) do 143 | {list, _} = types |> Enum.reduce({[], x}, fn 144 | type, {list, rest} -> 145 | {data, rest} = unpack_list(rest, type) 146 | {list ++ data, rest} 147 | _, acc -> 148 | acc 149 | end) 150 | list 151 | end 152 | def unpack(x, types) when is_map(types) do 153 | {list, _} = types |> Enum.reduce({%{}, x}, fn 154 | {k, type}, {map, rest} -> 155 | {[data], rest} = unpack_list(rest, type) 156 | {Map.put(map, k, data), rest} 157 | _, acc -> 158 | acc 159 | end) 160 | list 161 | end 162 | def unpack(_, _), do: nil 163 | 164 | defp unpack_list(x, {type, [arity]}) do 165 | size = type_size(type) 166 | Enum.reduce(1..arity, {[], x}, fn 167 | _, {list, <>} -> 168 | data = [unpack(x, type)] 169 | {list ++ data, rest} 170 | _, acc -> 171 | acc 172 | end) 173 | end 174 | defp unpack_list(x, {type, [current | arity]}) do 175 | Enum.reduce(1..current, {[], x}, fn 176 | _, {list, rest} -> 177 | {data, rest} = unpack_list(rest, {type, arity})# |> IO.inspect 178 | {list ++ [data], rest} 179 | _, acc -> 180 | acc 181 | end) 182 | end 183 | defp unpack_list(x, {:skip, bytes}) do 184 | <<_::binary-size(bytes), rest::binary>> = x 185 | {[], rest} 186 | end 187 | defp unpack_list(x, type) do 188 | size = type_size(type) 189 | #IO.inspect({x, type, size, byte_size(x)}) 190 | <> = x 191 | {[unpack(x, type)], rest} 192 | end 193 | 194 | @type_re ~r/(\d+)/ 195 | def type_size(:i8), do: 1 196 | def type_size(:i16), do: 2 197 | def type_size(:i32), do: 4 198 | def type_size(:i64), do: 8 199 | def type_size(:u8), do: 1 200 | def type_size(:u16), do: 2 201 | def type_size(:u32), do: 4 202 | def type_size(:u64), do: 8 203 | def type_size(:f16), do: 2 204 | def type_size(:f32), do: 4 205 | def type_size(:f64), do: 8 206 | def type_size({:skip, n}), do: n 207 | def type_size(type) when is_atom(type) or is_bitstring(type) do 208 | case Regex.run(@type_re, "#{type}", capture: :all_but_first) do 209 | [n] -> div(String.to_integer(n), 8) 210 | _ -> 0 211 | end 212 | end 213 | def type_size(tuple) when is_tuple(tuple) do 214 | tuple 215 | |> Tuple.to_list 216 | |> Enum.map(&type_size/1) 217 | |> Enum.reduce(1, &Kernel.*/2) 218 | end 219 | def type_size(i) when is_integer(i), do: i 220 | def type_size(l) when is_list(l) do 221 | l |> Enum.map(&type_size/1) |> Enum.reduce(0, &+/2) 222 | end 223 | def type_size(_), do: 0 224 | end 225 | 
--------------------------------------------------------------------------------
/lib/cuda/float_16.ex:
--------------------------------------------------------------------------------
 1 | defmodule Cuda.Float16 do
 2 |   @moduledoc """
 3 |   Represents a float 16 (half-precision) value
 4 |   Source: http://www.softelectro.ru/ieee754.html
 5 |   """
 6 | 
 7 |   # TODO: Tests with the Float_32 parameters showed that for fractional values
 8 |   # the low-order mantissa bits differ from the system conversion, e.g.:
 9 |   # - mantissa of 0.1234567 after the system conversion:     11111001101011011011110
10 |   # - mantissa of 0.1234567 after conversion by this module: 11111001101011011011101
11 | 
12 |   #======Float_16======
13 |   # sign     - 1 bit
14 |   # exponent - 5 bits
15 |   # mantissa - 10 bits
16 |   # Exp bias - 15
17 |   #
18 |   #======Float_32======
19 |   # sign     - 1 bit
20 |   # exponent - 8 bits
21 |   # mantissa - 23 bits
22 |   # Exp bias - 127
23 |   #
24 |   #====================
25 |   use Bitwise
26 | 
27 |   @sign_size 1
28 |   @exponent_size 5
29 |   @mantiss_size 10
30 |   @exponent_bias 15
31 | 
32 |   @pow_2_minus_14 0.000061035
33 |   @pow_2_10 1024
34 |   @min_normal 0.000061
35 | 
36 |   @doc """
37 |   Converts a number to the binary representation of the float 16 type
38 |   """
39 |   @spec pack(number) :: binary
40 |   def pack(x) when is_integer(x), do: pack(x + 0.0)
41 |   # subnormal conversion
42 |   def pack(x) when abs(x) < @min_normal do
43 |     # encode the sign bit
44 |     {x, sign} = sign_encode(x)
45 |     # get the binary fractional number
46 |     {_, {m, s}} = float_binary(x / @pow_2_minus_14)
47 |     # cut the binary to fit the mantissa size
48 |     m = m >>> (s - @mantiss_size)
49 |     # pack the converted value into a binary
50 |     <<sign::size(@sign_size), 0::size(@exponent_size), m::size(@mantiss_size)>>
51 |   end
52 |   # normal conversion
53 |   def pack(x) do
54 |     # encode the sign bit
55 |     {x, sign} = sign_encode(x)
56 |     # get the binary fractional number
57 |     f = float_binary(x)
58 |     # get the normalized mantissa and exponent
59 |     {m, e} = normalize(f, @exponent_bias)
60 |     n = size(m) - 1
61 |     # drop the leading 1 (the most significant bit) of the mantissa
62 |     m_cutted = m - (1 <<< n)
63 |     m_cutted = if (@mantiss_size - n) >= 0 do
64 |       # pad the mantissa with low-order bits up to the required width of the type
65 |       m_cutted <<< (@mantiss_size - n)
66 |     else
67 |       # if the mantissa is wider than the type allows, its low-order bits
68 |       # are truncated to the required size
69 |       m_cutted >>> (n - @mantiss_size)
70 |     end
71 |     # pack the converted value into a binary
72 |     <<sign::size(@sign_size), e::size(@exponent_size), m_cutted::size(@mantiss_size)>>
73 |   end
74 | 
75 |   @doc """
76 |   Converts the binary representation of the float 16 type to a float value
77 |   """
78 |   @spec unpack(binary) :: float
79 |   # if exponent == 11111 and the mantissa == 0, the value represents positive
80 |   # or negative infinity
81 |   def unpack(<<sign::size(@sign_size), 31::size(@exponent_size), 0::size(@mantiss_size)>>), do: sign == 0 && :positive_infinity || :negative_infinity
82 |   # if exponent == 11111 and the mantissa != 0, the value is not a number (NaN)
83 |   def unpack(<<_::size(@sign_size), 31::size(@exponent_size), _::size(@mantiss_size)>>), do: :not_a_number
84 |   # if exponent == 0 and the mantissa == 0, the value represents 0.0
85 |   def unpack(<<_::size(@sign_size), 0::size(@exponent_size), 0::size(@mantiss_size)>>), do: 0.0
86 |   # subnormal values
87 |   def unpack(<<sign::size(@sign_size), 0::size(@exponent_size), m::size(@mantiss_size)>>) do
88 |     # decode the sign bit
89 |     sign = sign_decode(sign)
90 |     # The formula is explained at http://www.softelectro.ru/ieee754.html, paragraph 4.2
91 |     #
92 |     # If the type parameters are changed, @pow_2_10 in this formula must be
93 |     # replaced with 1 <<< @mantiss_size
94 |     sign * @pow_2_minus_14 * m/@pow_2_10
95 |   end
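  # Editorial illustration (not in the original source): a normal value decodes
  # as sign * 2^(exp - 15) * (1 + m/1024). For example <<0::1, 15::5, 512::10>>
  # (0x3E00) has sign 0, exponent 15 and mantissa 512, so it decodes to
  # 2^0 * (1 + 512/1024) = 1.5.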
 96 |   # normalized values
 97 |   def unpack(<<sign::size(@sign_size), exp::size(@exponent_size), m::size(@mantiss_size)>>) do
 98 |     # decode the sign bit
 99 |     sign = sign_decode(sign)
100 |     # The formula is explained at http://www.softelectro.ru/ieee754.html, paragraph 4.2
101 |     #
102 |     # If the type parameters are changed, @pow_2_10 in this formula must be
103 |     # replaced with 1 <<< @mantiss_size
104 |     sign * :math.pow(2, exp - @exponent_bias) * (1 + m/@pow_2_10)
105 |   end
106 | 
107 |   # Encodes the sign of the value and makes the value absolute for further conversion
108 |   defp sign_encode(x) when x < 0, do: {abs(x), 1}
109 |   defp sign_encode(x), do: {x, 0}
110 | 
111 |   # Decodes the sign bit
112 |   defp sign_decode(0), do: 1
113 |   defp sign_decode(1), do: -1
114 | 
115 |   # Converts a float value to the binary representation
116 |   # @type float_binary :: {whole_part :: integer, fractional_part :: binary}
117 |   # @spec float_binary(value :: float, precision :: integer) :: float_binary
118 |   defp float_binary(value, precision \\ 50) when is_float(value) do
119 |     # get the whole part of the value
120 |     i = round(Float.floor(value))
121 |     # get the fractional part of the value
122 |     fract = value - i
123 |     f = fb_fract_binary(fract, precision)
124 |     {i, f}
125 |   end
126 | 
127 |   # Normalizes a binary fractional number to a mantissa and an exponent
128 |   # @type normalize :: {mantiss :: binary, exponent :: integer}
129 |   # @spec normalize(value :: float_binary, bias :: integer) :: normalize
130 |   defp normalize(value, bias)
131 |   defp normalize({0, {0, _}}, _), do: {0, 0}
132 |   defp normalize({0, {fract, fract_size}}, bias) do
133 |     n = size(fract)
134 |     # get the exponent with bias
135 |     exp = -(fract_size - n + 1) + bias
136 |     {fract, exp}
137 |   end
138 |   defp normalize({int, {fract, fract_size}}, bias) do
139 |     int_n = size(int)
140 |     # get the exponent with bias
141 |     exp = int_n - 1 + bias
142 |     # the integer part is shifted left by the number of fractional bits,
143 |     # then the fractional part is appended
144 |     int = (int <<< fract_size) ||| fract
145 |     {int, exp}
146 |   end
147 | 
148 |   # Gets the binary digits count
149 |   defp size(0), do: 0
150 |   defp size(x) do
151 |     round(Float.floor(:math.log2(x))) + 1
152 |   end
153 | 
154 |   # Implements the algorithm from this source: https://otvet.mail.ru/question/46720675
155 |   # Converting a fraction from decimal to binary (or hexadecimal):
156 |   # a) the source fraction is multiplied by the base of the target numeral
157 |   #    system (2 or 16);
158 |   # b) the integer part of the product is converted into a digit of the target
159 |   #    system and dropped from the product; it becomes the most significant
160 |   #    remaining digit of the resulting fraction;
161 |   # c) the remaining fractional part (a proper fraction) is multiplied by the
162 |   #    target base again, and the product is processed according to
163 |   #    steps a) and b);
164 |   # d) the multiplication continues until the fractional part of the product
165 |   #    becomes zero or the required number of result digits is reached;
166 |   # e) the result is formed from the digits dropped in step b): they make up
167 |   #    the fractional part of the result, in order of decreasing
168 |   #    significance.
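  # Worked example (an editorial illustration): converting 0.625 with this
  # algorithm gives 0.625 * 2 = 1.25 (digit 1), 0.25 * 2 = 0.5 (digit 0),
  # 0.5 * 2 = 1.0 (digit 1, fraction exhausted), so 0.625 = 0.101 in binary
  # and fb_fract_binary(0.625, 50) returns {0b101, 3}.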
169 |   #
170 |   # fb_fract_binary returns:
171 |   # {fractional_result, overall_digits_count}
172 |   defp fb_fract_binary(fract, precision, result \\ {0, 0})
173 |   defp fb_fract_binary(0.0, _, result), do: result
174 |   defp fb_fract_binary(_, 0, result), do: result
175 |   defp fb_fract_binary(fract, precision, {result, digits}) do
176 |     fract = fract * 2
177 |     result = result <<< 1
178 |     if fract >= 1 do
179 |       fb_fract_binary(fract - 1, precision - 1, {result + 1, digits + 1})
180 |     else
181 |       fb_fract_binary(fract, precision - 1, {result, digits + 1})
182 |     end
183 |   end
184 | 
185 |   def dbg_view(data, binary \\ true)
186 |   def dbg_view(<<sign::size(@sign_size), exp::size(@exponent_size), m::size(@mantiss_size)>>, binary) do
187 |     base = binary && :binary || :decimal
188 |     IO.inspect(sign, label: :sign, base: base)
189 |     IO.inspect(exp, label: :exponent, base: base)
190 |     IO.inspect(m, label: :mantiss, base: base)
191 |     :ok
192 |   end
193 |   def dbg_view(_, _), do: raise(ArgumentError, message: "Value has wrong format")
194 | end
195 | 
--------------------------------------------------------------------------------
/lib/cuda/graph/node.ex:
--------------------------------------------------------------------------------
  1 | defmodule Cuda.Graph.Node do
  2 |   @moduledoc """
  3 |   Represents an evaluation graph node.
  4 | 
  5 |   You can use this module to define your own evaluation nodes. To do this you
  6 |   should implement callbacks that will be called with user options, specified
  7 |   at node creation time and current Cuda environment. Here is a simple example:
  8 | 
  9 |   ```
 10 |   defmodule MyNode do
 11 |     use Cuda.Graph.Node
 12 | 
 13 |     def __pins__(_assigns) do
 14 |       [input(:in), output(:out)]
 15 |     end
 16 | 
 17 |     def __type__(_assigns) do
 18 |       :host
 19 |     end
 20 |   end
 21 |   ```
 22 |   """
 23 | 
 24 |   require Cuda
 25 |   alias Cuda.Graph
 26 |   alias Cuda.Graph.Pin
 27 |   alias Cuda.Graph.NodeProto
 28 | 
 29 |   @type type :: :gpu | :host | :virtual | :graph | :computation_graph
 30 |   @type options :: keyword
 31 |   @type assigns :: %{options: options, env: Cuda.Env.t}
 32 |   @type t :: %__MODULE__{
 33 |     id: Graph.id,
 34 |     module: module,
 35 |     type: type,
 36 |     pins: [Pin.t],
 37 |     assigns: assigns
 38 |   }
 39 | 
 40 |   @callback __assigns__(id :: Graph.id, opts :: options, env :: Cuda.Env.t) :: map | keyword
 41 | 
 42 |   @doc """
 43 |   Provides the node prototype, a structure that holds node data.
 44 | 
 45 |   It can be for example `Cuda.Graph`, `Cuda.Graph.Node`, `Cuda.Graph.GPUNode`
 46 |   or any other module that implements node protocol functionality.
 47 | 
 48 |   By default it will be `Cuda.Graph.Node`.
 49 |   """
 50 |   @callback __proto__() :: atom
 51 | 
 52 |   @doc """
 53 |   Provides a complete pin list for a newly created node.
 54 | 
 55 |   You can use `pin/3`, `input/2`, `output/2`, `consumer/2` and `producer/2`
 56 |   helpers here.
 57 |   """
 58 |   @callback __pins__(assigns :: assigns) :: [Pin.t]
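  # For example (an editorial sketch, not part of the source), a node with a
  # 784-element float input and a 10-element float output could declare:
  #
  #   def __pins__(_assigns) do
  #     [input(:in, {:f32, 784}), output(:out, {:f32, 10})]
  #   end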
80 | """ 81 | @callback __compile__(node :: struct) :: {:ok, struct} | {:error, any} 82 | 83 | @derive [NodeProto] 84 | defstruct [:id, :module, :type, pins: [], assigns: %{}] 85 | 86 | @exports [consumer: 2, consumer: 3, input: 2, input: 3, output: 2, output: 3, 87 | pin: 3, pin: 4, producer: 2, producer: 3] 88 | @input_pins ~w(input consumer terminator)a 89 | @output_pins ~w(output producer)a 90 | @graph_types ~w(graph computation_graph)a 91 | 92 | defmacro __using__(_opts) do 93 | quote do 94 | import unquote(__MODULE__), only: unquote(@exports) 95 | import Cuda.Graph.NodeProto, only: [assign: 3] 96 | @behaviour unquote(__MODULE__) 97 | def __assigns__(_id, _opts, _env), do: %{} 98 | def __proto__(), do: unquote(__MODULE__) 99 | def __compile__(node), do: {:ok, node} 100 | defoverridable __assigns__: 3, __compile__: 1, __proto__: 0 101 | end 102 | end 103 | 104 | defmacro input_pin_types() do 105 | quote(do: unquote(@input_pins)) 106 | end 107 | 108 | defmacro output_pin_types() do 109 | quote(do: unquote(@output_pins)) 110 | end 111 | 112 | defmacro graph_types() do 113 | quote(do: unquote(@graph_types)) 114 | end 115 | 116 | @doc """ 117 | Creates a pin with specified parameters 118 | """ 119 | @spec pin(name :: Graph.id, type :: Pin.type, data_type :: any) :: Pin.t 120 | @spec pin(name :: Graph.id, type :: Pin.type, data_type :: any, group :: Pin.group) :: Pin.t 121 | def pin(name, type, data_type, group \\ nil) do 122 | %Pin{ 123 | id: name, 124 | type: type, 125 | data_type: data_type, 126 | group: group 127 | } 128 | end 129 | 130 | @doc """ 131 | Creates an input pin with specified parameters. 132 | 133 | Input is a pin from which the data passed inside an evaluation node. 134 | """ 135 | @spec input(name :: Graph.id, data_type :: any) :: Pin.t 136 | @spec input(name :: Graph.id, data_type :: any, group :: Pin.group) :: Pin.t 137 | def input(name, data_type, group \\ nil) do 138 | pin(name, :input, data_type, group) 139 | end 140 | 141 | @doc """ 142 | Creates an output pin with specified parameters. 143 | 144 | Ouput is a pin through which you pass data outside from your node. 145 | """ 146 | @spec output(name :: Graph.id, data_type :: any) :: Pin.t 147 | @spec output(name :: Graph.id, data_type :: any, group :: Pin.group) :: Pin.t 148 | def output(name, data_type, group \\ nil) do 149 | pin(name, :output, data_type, group) 150 | end 151 | 152 | @doc """ 153 | Creates a producer pin with specified parameters. 154 | 155 | Producers are nodes that generates some data. Data from this kind of pin can 156 | be passed to `:input` or `:consumer` pins. 157 | """ 158 | @spec producer(name :: Graph.id, data_type :: any) :: Pin.t 159 | @spec producer(name :: Graph.id, data_type :: any, group :: Pin.group) :: Pin.t 160 | def producer(name, data_type, group \\ nil) do 161 | pin(name, :producer, data_type, group) 162 | end 163 | 164 | @doc """ 165 | Creates a consumer pin with specified parameters. 166 | 167 | Consumers are nodes that takes some data. This pin is like a data flow 168 | terminator. Data for this pin can be taked from `:output` or `:producer` 169 | pins. 170 | """ 171 | @spec consumer(name :: Graph.id, data_type :: any) :: Pin.t 172 | @spec consumer(name :: Graph.id, data_type :: any, group :: Pin.group) :: Pin.t 173 | def consumer(name, data_type, group \\ nil) do 174 | pin(name, :consumer, data_type, group) 175 | end 176 | 177 | @doc """ 178 | Returns module of struct that used to store node data. 
It can be for example 179 | `Cuda.Graph`, `Cuda.Graph.Node`, `Cuda.Graph.GPUNode` or any other module, 180 | related to node type. 181 | """ 182 | @spec proto(module :: atom) :: atom 183 | def proto(module) do 184 | if function_exported?(module, :__proto__, 0) do 185 | module.__proto__() 186 | else 187 | __MODULE__ 188 | end 189 | end 190 | 191 | def string_id(id) when is_tuple(id) do 192 | id |> Tuple.to_list |> Enum.map(&string_id/1) |> Enum.join("__") 193 | end 194 | def string_id(id) do 195 | "#{id}" 196 | end 197 | end 198 | 199 | defimpl Cuda.Graph.Factory, for: Cuda.Graph.Node do 200 | require Cuda 201 | alias Cuda.Graph.Pin 202 | 203 | @types ~w(gpu host virtual graph computation_graph)a 204 | @reserved_names ~w(input output)a 205 | 206 | def new(_, id, module, opts, env) do 207 | with {:module, module} <- Code.ensure_loaded(module) do 208 | if id in @reserved_names do 209 | Cuda.compile_error("Reserved node name '#{id}' used") 210 | end 211 | 212 | assigns = case function_exported?(module, :__assigns__, 3) do 213 | true -> module.__assigns__(id, opts, env) |> Enum.into(%{}) 214 | _ -> %{} 215 | end 216 | assigns = Map.merge(assigns, %{options: opts, env: env}) 217 | 218 | type = case function_exported?(module, :__type__, 1) do 219 | true -> module.__type__(assigns) 220 | _ -> :virtual 221 | end 222 | if not type in @types do 223 | Cuda.compile_error("Unsupported type: #{inspect type}") 224 | end 225 | 226 | pins = case function_exported?(module, :__pins__, 1) do 227 | true -> module.__pins__(assigns) 228 | _ -> [] 229 | end 230 | if not is_list(pins) or not Enum.all?(pins, &valid_pin?/1) do 231 | Cuda.compile_error("Invalid pin list supplied") 232 | end 233 | 234 | struct(Cuda.Graph.Node, id: id, module: module, type: type, pins: pins, 235 | assigns: assigns) 236 | else 237 | _ -> Cuda.compile_error("Node module #{module} could not be loaded") 238 | end 239 | end 240 | 241 | defp valid_pin?(%Pin{}), do: true 242 | defp valid_pin?(_), do: false 243 | end 244 | -------------------------------------------------------------------------------- /test/memory_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Cuda.MemoryTest do 2 | use ExUnit.Case 3 | alias Cuda.Memory.Shape 4 | import Cuda.Memory 5 | 6 | describe "pack/2" do 7 | test "pack a 1d array [1, 2, 3, 4] of i8" do 8 | p = pack([1, 2, 3, 4], {:i8, 4}) 9 | assert p == <<1, 2, 3, 4>> 10 | end 11 | 12 | test "pack a 2d array [[1, 2, 3, 4], [5, 6, 7, 8]] of i8" do 13 | p = pack([[1, 2, 3, 4], [5, 6, 7, 8]], {:i8, {4, 2}}) 14 | assert p == <<1, 2, 3, 4, 5, 6, 7, 8>> 15 | end 16 | 17 | test "pack a 2d array [[1, 2, 3, 4], [5, 6, 7, 8]] of i8 with %Shape{} usage " do 18 | p = pack([[1, 2, 3, 4], [5, 6, 7, 8]], %Shape{type: {:i8, {4, 2}}}) 19 | assert p == <<1, 2, 3, 4, 5, 6, 7, 8>> 20 | end 21 | 22 | test "pack a 2d array [[1, 2, 3, 4], [5, 6, 7, 8]] of i8 with %Shape{} usage and skip 4 bytes after" do 23 | p = pack([[1, 2, 3, 4], [5, 6, 7, 8]], %Shape{type: {:i8, {4, 2}}, skip: 4}) 24 | assert p == <<1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0>> 25 | end 26 | 27 | test "pack a 2d array [[1, 2, 3, 4], [5, 6, 7, 8]] of i8 with %Shape{} usage, skip 2 bytes before, and 4 bytes after" do 28 | p = pack([[1, 2, 3, 4], [5, 6, 7, 8]], %Shape{type: {:i8, {4, 2}}, skip: {2, 4}}) 29 | assert p == <<0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0>> 30 | end 31 | 32 | test "pack a 2d array [[1, 2, 3, 4], [1.2, 13.3, 5.6, 12.1]] of i8, f32" do 33 | p = pack([[1, 2, 3, 4], [1.2, 13.3, 5.6, 12.1]], [{:i8, 4}, 
{:f32, 4}])
 34 |       assert p == <<1, 2, 3, 4, 154, 153, 153, 63, 205, 204, 84, 65, 51, 51, 179, 64, 154, 153, 65, 65>>
 35 |     end
 36 | 
 37 |     test "pack a map %{a: [1, 2, 3, 4], b: [1.2, 13.3, 5.6, 12.1]} of i8, f32" do
 38 |       p = pack(%{a: [1, 2, 3, 4], b: [1.2, 13.3, 5.6, 12.1]}, %{a: {:i8, 4}, b: {:f32, 4}})
 39 |       assert p == <<1, 2, 3, 4, 154, 153, 153, 63, 205, 204, 84, 65, 51, 51, 179, 64, 154, 153, 65, 65>>
 40 |     end
 41 | 
 42 |     test "pack a map %{a: [1, 2, 3, 4], b: [1.2, 13.3, 5.6, 12.1]} of i8, f32, with %Shape{} usage" do
 43 |       p = pack(%{a: [1, 2, 3, 4], b: [1.2, 13.3, 5.6, 12.1]}, %Shape{type: %{a: {:i8, 4}, b: {:f32, 4}}})
 44 |       assert p == <<1, 2, 3, 4, 154, 153, 153, 63, 205, 204, 84, 65, 51, 51, 179, 64, 154, 153, 65, 65>>
 45 |     end
 46 | 
 47 |     test "pack a map %{a: [1, 2, 3, 4], b: [1.2, 13.3, 5.6, 12.1]} of i8, f32, with %Shape{} usage, skip 4 bytes before for 'a' and 5 bytes before, and 2 bytes after for 'b'" do
 48 |       p = pack(%{a: [1, 2, 3, 4], b: [1.2, 13.3, 5.6, 12.1]}, %Shape{type: %{a: {:i8, 4}, b: {:f32, 4}}, skip: %{a: {4, 0}, b: {5, 2}}})
 49 |       assert p == <<0, 0, 0, 0, 1, 2, 3, 4, 0, 0, 0, 0, 0, 154, 153, 153, 63, 205, 204, 84, 65, 51, 51, 179, 64, 154, 153, 65, 65, 0, 0>>
 50 |     end
 51 | 
 52 |     test "pack a 2d array [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]] of i8" do
 53 |       p = pack([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], {:i8, {4, 3}})
 54 |       assert p == <<1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12>>
 55 |     end
 56 | 
 57 |     test "pack a f32 value and i8 1d array [65.7, [1, 2, 3, 4]]" do
 58 |       p = pack([65.7, [1, 2, 3, 4]], [:f32, {:i8, 4}])
 59 |       assert p == <<102, 102, 131, 66, 1, 2, 3, 4>>
 60 |     end
 61 | 
 62 |     test "pack a zero f32 value" do
 63 |       p = pack(:zero, :f32)
 64 |       assert p == <<0, 0, 0, 0>>
 65 |     end
 66 |   end
 67 | 
 68 |   describe "unpack/2" do
 69 |     test "unpack a 1d array [1, 2, 3, 4] of i8" do
 70 |       p = unpack(<<1, 2, 3, 4>>, {:i8, 4})
 71 |       assert p == [1, 2, 3, 4]
 72 |     end
 73 | 
 74 |     test "unpack a 2d array [[1, 2, 3, 4], [5, 6, 7, 8]] of i8" do
 75 |       p = unpack(<<1, 2, 3, 4, 5, 6, 7, 8>>, {:i8, {4, 2}})
 76 |       assert p == [[1, 2, 3, 4], [5, 6, 7, 8]]
 77 |     end
 78 | 
 79 |     test "unpack a 2d array [[1, 2, 3, 4], [5, 6, 7, 8]] of i8 with %Shape{} usage" do
 80 |       p = unpack(<<1, 2, 3, 4, 5, 6, 7, 8>>, %Shape{type: {:i8, {4, 2}}})
 81 |       assert p == [[1, 2, 3, 4], [5, 6, 7, 8]]
 82 |     end
 83 | 
 84 |     test "unpack a 2d array [[1, 2, 3, 4], [5, 6, 7, 8]] of i8 with %Shape{} usage and skip 4 bytes after" do
 85 |       p = unpack(<<1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0>>, %Shape{type: {:i8, {4, 2}}, skip: 4})
 86 |       assert p == [[1, 2, 3, 4], [5, 6, 7, 8]]
 87 |     end
 88 | 
 89 |     test "unpack a 2d array [[1, 2, 3, 4], [5, 6, 7, 8]] of i8 with %Shape{} usage, skip 2 bytes before, and 4 bytes after" do
 90 |       p = unpack(<<0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0>>, %Shape{type: {:i8, {4, 2}}, skip: {2, 4}})
 91 |       assert p == [[1, 2, 3, 4], [5, 6, 7, 8]]
 92 |     end
 93 | 
 94 |     test "unpack a 2d array [[1, 2, 3, 4], [1.2, 13.3, 5.6, 12.1]] of i8, f32" do
 95 |       p = <<1, 2, 3, 4, 154, 153, 153, 63, 205, 204, 84, 65, 51, 51, 179, 64, 154, 153, 65, 65>>
 96 |       |> unpack([{:i8, 4}, {:f32, 4}])
 97 |       |> rnd()
 98 |       assert p == [[1, 2, 3, 4], [1.2, 13.3, 5.6, 12.1]]
 99 |     end
100 | 
101 |     test "unpack a map %{a: [1, 2, 3, 4], b: [1.2, 13.3, 5.6, 12.1]} of i8, f32" do
102 |       p = <<1, 2, 3, 4, 154, 153, 153, 63, 205, 204, 84, 65, 51, 51, 179, 64, 154, 153, 65, 65>>
103 |       |> unpack(%{a: {:i8, 4}, b: {:f32, 4}})
104 |       |> rnd()
105 |       assert p == %{a: [1, 2, 3, 4], b: [1.2, 13.3, 5.6, 12.1]}
106 |     end
107 | 
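    # Editorial note: throughout these tests, skip: n in %Shape{} pads n zero
    # bytes after the packed value, while skip: {before, after} pads on both
    # sides; unpack/2 drops the same padding when reading back.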
108 |     test "unpack a map %{a: [1, 2, 3, 4], b: [1.2, 13.3, 5.6, 12.1]} of i8, f32, with %Shape{} usage" do
109 |       p = <<1, 2, 3, 4, 154, 153, 153, 63, 205, 204, 84, 65, 51, 51, 179, 64, 154, 153, 65, 65>>
110 |       |> unpack(%Shape{type: %{a: {:i8, 4}, b: {:f32, 4}}})
111 |       |> rnd()
112 |       assert p == %{a: [1, 2, 3, 4], b: [1.2, 13.3, 5.6, 12.1]}
113 |     end
114 | 
115 |     test "unpack a map %{a: [1, 2, 3, 4], b: [1.2, 13.3, 5.6, 12.1]} of i8, f32, with %Shape{} usage, skip 4 bytes before for 'a' and 5 bytes before, and 2 bytes after for 'b'" do
116 |       p = <<0, 0, 0, 0, 1, 2, 3, 4, 0, 0, 0, 0, 0, 154, 153, 153, 63, 205, 204, 84, 65, 51, 51, 179, 64, 154, 153, 65, 65, 0, 0>>
117 |       |> unpack(%Shape{type: %{a: {:i8, 4}, b: {:f32, 4}}, skip: %{a: {4, 0}, b: {5, 2}}})
118 |       |> rnd()
119 |       assert p == %{a: [1, 2, 3, 4], b: [1.2, 13.3, 5.6, 12.1]}
120 |     end
121 | 
122 |     test "unpack a 2d array [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]] of i8" do
123 |       p = unpack(<<1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12>>, {:i8, {4, 3}})
124 |       assert p == [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]
125 |     end
126 | 
127 |     test "unpack a f32 value and i8 1d array [65.7, [1, 2, 3, 4]]" do
128 |       p = <<102, 102, 131, 66, 1, 2, 3, 4>>
129 |       |> unpack([:f32, {:i8, 4}])
130 |       |> rnd()
131 |       assert p == [65.7, [1, 2, 3, 4]]
132 |     end
133 |   end
134 | 
135 |   describe "size/1" do
136 |     test "Size of single type" do
137 |       assert size(:f32) == 4
138 |     end
139 | 
140 |     test "Size of type's list" do
141 |       assert size([:f32, {:i8, 3}]) == 7
142 |     end
143 | 
144 |     test "Size of type's tuple" do
145 |       assert size({:i8, {3, 3, 2}}) == 18
146 |     end
147 | 
148 |     test "Size of type's map" do
149 |       assert size(%{a: {:i8, {3, 3, 2}}, b: :f32}) == 22
150 |     end
151 | 
152 |     test "Size of type's Shape" do
153 |       assert size(%Shape{type: %{a: {:i8, {3, 3, 2}}, b: :f32}}) == 22
154 |     end
155 | 
156 |     test "Size of type's Shape with skip" do
157 |       assert size(%Shape{type: %{a: {:i8, {3, 3, 2}}, b: :f32}, skip: %{a: 2, b: {3, 2}}}) == 29
158 |       assert size(%Shape{type: %{a: {:i8, {3, 3, 2}}, b: :f32}, skip: %{b: {3, 2}}}) == 27
159 |       assert size(%Shape{type: %{a: {:i8, {3, 3, 2}}, b: :f32}, skip: 2}) == 24
160 |       assert size(%Shape{type: %{a: {:i8, {3, 3, 2}}, b: :f32}, skip: {3, 2}}) == 27
161 |     end
162 |   end
163 | 
164 |   describe "size_equal?/2" do
165 |     test "general" do
166 |       assert size_equal?(:f32, :f32)
167 |     end
168 | 
169 |     test "different type sizes" do
170 |       refute size_equal?(:f32, :i8)
171 |     end
172 | 
173 |     test "equal types sizes" do
174 |       assert size_equal?(:f32, {:i8, 4})
175 |     end
176 |   end
177 | 
178 |   def rnd(x, precision \\ 1)
179 |   def rnd(%{} = x, precision) do
180 |     x
181 |     |> Map.to_list()
182 |     |> Enum.map(fn {key, val} -> {key, rnd(val, precision)} end)
183 |     |> Enum.into(%{})
184 |   end
185 |   def rnd(x, precision) when is_list(x) do
186 |     x
187 |     |> Enum.map(&rnd(&1, precision))
188 |   end
189 |   def rnd(x, precision) when is_float(x) do
190 |     Float.round(x, precision)
191 |   end
192 |   def rnd(x, _), do: x
193 | end
194 | 
--------------------------------------------------------------------------------
/c_src/commands.cpp:
--------------------------------------------------------------------------------
  1 | #include "commands.h"
  2 | 
  3 | namespace Commands {
  4 | 
  5 | CUevent Events::Get(std::string name) {
  6 |   CUevent event;
  7 |   CUresult result;
  8 |   auto eventIt = events.find(name);
  9 |   if (eventIt == events.end()) {
 10 |     result = cuEventCreate(&event, CU_EVENT_DISABLE_TIMING);
 11 |     if (result != CUDA_SUCCESS) throw DriverError(result, "Event creation");
 12 |     events.insert(std::pair<std::string, CUevent>(name, event));
 13 |   } else {
 14 |     event = eventIt->second;
 15 |   }
 16 |   return event;
 17 | }
 18 | 
 19 | Events::~Events() {
 20 |   DEBUG("Events destroyed");
 21 |   for (auto it = events.begin(); it != events.end(); ++it) {
 22 |     cuEventDestroy(it->second);
 23 |   }
 24 | }
 25 | 
 26 | Command *Command::Create(Driver *driver, ETERM *item) {
 27 |   if (ERL_IS_LIST(item)) {
 28 |     if (erl_length(item) == 0) return new Batch(driver, item);
 29 |     if (ERL_IS_LIST(erl_hd(item))) return new BatchList(driver, item);
 30 |     return new Batch(driver, item);
 31 |   }
 32 |   if (!ERL_IS_TUPLE(item)) throw StringError("Bad argument");
 33 |   if (erl_size(item) != 2) throw StringError("Bad argument");
 34 |   auto cmd = erl_element(1, item);
 35 |   auto args = erl_element(2, item);
 36 |   if (!ERL_IS_ATOM(cmd)) throw StringError("Bad argument");
 37 |   if (ATOM_EQ(cmd, "run")) {
 38 |     return new RunCommand(driver, args);
 39 |   } else if (ATOM_EQ(cmd, "event")) {
 40 |     return new EventCommand(driver, args);
 41 |   } else if (ATOM_EQ(cmd, "wait")) {
 42 |     return new WaitCommand(driver, args);
 43 |   } else {
 44 |     throw StringError("Bad command");
 45 |   }
 46 | }
 47 | 
 48 | Batch::Batch(Driver *driver, ETERM *args) : Command(driver) {
 49 |   if (!ERL_IS_LIST(args)) throw StringError("Bad argument");
 50 |   auto bs = erl_length(args);
 51 |   for (int j = 0; j < bs; j++) {
 52 |     commands.push_back(Command::Create(driver, erl_hd(args)));
 53 |     args = erl_tl(args);
 54 |   }
 55 | }
 56 | 
 57 | void Batch::Run(Context &ctx) {
 58 |   CUresult result;
 59 | 
 60 |   result = cuStreamCreate(&ctx.stream, CU_STREAM_NON_BLOCKING);
 61 |   if (result != CUDA_SUCCESS) throw DriverError(result, "Driver:stream_create");
 62 | 
 63 |   if (!ctx.events) ctx.events = new Events();
 64 | 
 65 |   DEBUG("Starting stream " << ctx.id);
 66 | 
 67 |   for (auto it = commands.begin(); it != commands.end(); ++it) {
 68 |     Command *cmd = *it;
 69 |     cmd->Run(ctx);
 70 |   }
 71 | 
 72 |   result = cuEventRecord(ctx.events->Get(ctx.finishEvent), ctx.stream);
 73 |   if (result != CUDA_SUCCESS) throw DriverError(result, "Event record");
 74 |   DEBUG("Finishing stream " << ctx.id);
 75 | }
 76 | 
 77 | BatchList::BatchList(Driver *driver, ETERM *args) : Command(driver) {
 78 |   if (!ERL_IS_LIST(args)) throw StringError("Bad argument");
 79 |   auto bs = erl_length(args);
 80 |   for (int j = 0; j < bs; j++) {
 81 |     batches.push_back(Command::Create(driver, erl_hd(args)));
 82 |     args = erl_tl(args);
 83 |   }
 84 | }
 85 | 
 86 | void BatchList::Run(Context &ctx) {
 87 |   DEBUG("Running batch list");
 88 |   if (!ctx.events) ctx.events = new Events();
 89 |   std::vector<Context> ctxs;
 90 |   int idx = 0;
 91 |   for (auto it = batches.begin(); it != batches.end(); ++it) {
 92 |     Command *cmd = *it;
 93 |     Context batchCtx = ctx;
 94 |     batchCtx.id = std::to_string(idx);
 95 |     batchCtx.finishEvent = std::string("finish") + batchCtx.id;
 96 |     cmd->Run(batchCtx);
 97 |     ctxs.push_back(batchCtx);
 98 |     idx++;
 99 |   }
100 |   // wait for all streams
101 |   CUstream batchStream;
102 |   CUresult result;
103 |   result = cuStreamCreate(&batchStream, CU_STREAM_NON_BLOCKING);
104 |   DEBUG("Waiting streams");
105 |   if (result != CUDA_SUCCESS) throw DriverError(result, "Driver:stream_create");
106 |   for (auto it = ctxs.begin(); it != ctxs.end(); ++it) {
107 |     result = cuStreamWaitEvent(batchStream, ctx.events->Get(it->finishEvent), 0); // the batch stream waits for each sub-stream's finish event
108 |     if (result != CUDA_SUCCESS) throw DriverError(result, "Driver:event_wait");
109 |     DEBUG("Stream " << it->id << " finished");
110 |     result = cuStreamDestroy(it->stream);
111 |     if (result != CUDA_SUCCESS) throw DriverError(result, "Driver:stream_free");
112 |   }
113 |   result = cuStreamSynchronize(batchStream);
114 |   if (result != CUDA_SUCCESS) throw DriverError(result, "Driver:stream_wait");
115 |   DEBUG("All streams finished");
116 |   result = cuStreamDestroy(batchStream);
117 |   if (result != CUDA_SUCCESS) throw DriverError(result, "Driver:stream_free");
118 | }
119 | 
120 | RunCommand::RunCommand(Driver *driver, ETERM *args) : Command(driver) {
121 |   if (!ERL_IS_TUPLE(args)) throw StringError("Bad argument");
122 |   auto argc = erl_size(args);
123 |   if (argc < 1) throw StringError("Bad argument");
124 | 
125 |   auto kernelTerm = erl_element(1, args);
126 |   if (!ERL_IS_BINARY(kernelTerm)) throw StringError("Bad argument");
127 |   kernel = std::string((char *)ERL_BIN_PTR(kernelTerm), erl_size(kernelTerm));
128 | 
129 |   gx = 1, gy = 1, gz = 1;
130 |   bx = 1, by = 1, bz = 1;
131 |   arguments = std::make_shared<RunArguments>();
132 | 
133 |   if (argc > 1) {
134 |     ETERM *grid = NULL;
135 |     ETERM *block = NULL;
136 |     ETERM *params = NULL;
137 | 
138 |     if (argc == 2) {
139 |       params = erl_element(2, args);
140 |       if (ERL_IS_TUPLE(params) || ERL_IS_INTEGER(params)) {
141 |         block = params;
142 |         params = NULL;
143 |       }
144 |     } else if (argc == 3) {
145 |       block = erl_element(2, args);
146 |       params = erl_element(3, args);
147 |       if (ERL_IS_TUPLE(params) || ERL_IS_INTEGER(params)) {
148 |         grid = params;
149 |         params = NULL;
150 |       }
151 |     } else if (argc == 4) {
152 |       block = erl_element(2, args);
153 |       grid = erl_element(3, args);
154 |       params = erl_element(4, args);
155 |     } else {
156 |       throw StringError("Bad argument");
157 |     }
158 | 
159 |     if (block) {
160 |       if (ERL_IS_INTEGER(block)) {
161 |         bx = Get(block);
162 |       } else if (ERL_IS_TUPLE(block)) {
163 |         auto s = erl_size(block);
164 |         if (s > 0) bx = Get(erl_element(1, block));
165 |         if (s > 1) by = Get(erl_element(2, block));
166 |         if (s > 2) bz = Get(erl_element(3, block));
167 |       } else {
168 |         throw StringError("Bad argument");
169 |       }
170 |     }
171 | 
172 |     if (grid) {
173 |       if (ERL_IS_INTEGER(grid)) {
174 |         gx = Get(grid); // an integer grid sets the grid x dimension
175 |       } else if (ERL_IS_TUPLE(grid)) {
176 |         auto s = erl_size(grid);
177 |         if (s > 0) gx = Get(erl_element(1, grid));
178 |         if (s > 1) gy = Get(erl_element(2, grid));
179 |         if (s > 2) gz = Get(erl_element(3, grid));
180 |       } else {
181 |         throw StringError("Bad argument");
182 |       }
183 |     }
184 | 
185 |     if (params) {
186 |       if (!ERL_IS_LIST(params)) throw StringError("Bad argument");
187 |       auto s = erl_length(params);
188 |       for (int i = 0; i < s; i++) {
189 |         auto param = erl_hd(params);
190 |         params = erl_tl(params);
191 |         if (ERL_IS_TUPLE(param)) {
192 |           auto param_type = erl_element(1, param);
193 |           auto param_value = erl_element(2, param);
194 |           if (ERL_IS_ATOM(param_type) && ATOM_EQ(param_type, "memory")) {
195 |             auto mem = driver->GetMemory(Get(param_value));
196 |             if (!mem) throw StringError("Invalid memory handle");
197 |             arguments->Add(*mem);
198 |           }
199 |         } else if (ERL_IS_INTEGER(param)) {
200 |           arguments->Add(ERL_INT_VALUE(param));
201 |         } else if (ERL_IS_FLOAT(param)) {
202 |           float f = ERL_FLOAT_VALUE(param);
203 |           arguments->Add(f);
204 |         } else {
205 |           throw StringError("Bad argument");
206 |         }
207 |       }
208 |     }
209 |   }
210 | }
211 | 
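// Editorial summary (not part of the original source): the constructor above
// accepts the following tuple shapes from the Erlang side, where Block and
// Grid are either an integer (the x dimension) or an {x, y, z} tuple, and
// Params is a list of integers, floats and {memory, Handle} tuples:
//   {run, {Kernel}}
//   {run, {Kernel, Block}}       or {run, {Kernel, Params}}
//   {run, {Kernel, Block, Grid}} or {run, {Kernel, Block, Params}}
//   {run, {Kernel, Block, Grid, Params}}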
212 | void RunCommand::Run(Context &ctx) {
213 |   CUfunction func;
214 |   CUresult result;
215 | 
216 |   result = cuModuleGetFunction(&func, ctx.module, kernel.c_str());
217 |   if (result != CUDA_SUCCESS) throw DriverError(result);
218 | 
219 |   // DEBUG("Launch DriverPort::Stream");
220 |   result = cuLaunchKernel(func, gx, gy, gz, bx, by, bz, 0, ctx.stream, arguments->GetPtr(), 0);
221 |   // DEBUG("Exit DriverPort::Stream");
222 |   if (result != CUDA_SUCCESS) throw DriverError(result, "Driver:execution");
223 |   // DEBUG("Exit 1 DriverPort::Stream");
224 | }
225 | 
226 | EventCommand::EventCommand(Driver *driver, ETERM *arg) : Command(driver) {
227 |   if (!ERL_IS_BINARY(arg)) throw StringError("Invalid argument");
228 |   name = std::string((char *)ERL_BIN_PTR(arg), erl_size(arg));
229 | }
230 | 
231 | void EventCommand::Run(Context &ctx) {
232 |   auto event = ctx.events->Get(name);
233 |   cuEventRecord(event, ctx.stream);
234 | }
235 | 
236 | WaitCommand::WaitCommand(Driver *driver, ETERM *arg) : Command(driver) {
237 |   if (!ERL_IS_BINARY(arg)) throw StringError("Invalid argument");
238 |   name = std::string((char *)ERL_BIN_PTR(arg), erl_size(arg));
239 | }
240 | 
241 | void WaitCommand::Run(Context &ctx) {
242 |   auto event = ctx.events->Get(name);
243 |   cuStreamWaitEvent(ctx.stream, event, 0);
244 | }
245 | 
246 | } // namespace Commands
247 | 
--------------------------------------------------------------------------------
/lib/cuda/template/ptx_helpers.ex:
--------------------------------------------------------------------------------
  1 | defmodule Cuda.Template.PtxHelpers do
  2 |   @moduledoc """
  3 |   Helpers available in PTX templates.
  4 |   """
  5 | 
  6 |   alias Cuda.{Compiler.Context, Graph.Node, Memory, Template}
  7 |   use Bitwise
  8 |   require Logger
  9 | 
 10 |   @doc """
 11 |   Returns true if the current node is compiled for the back-propagation graph
 12 |   """
 13 |   def back_propagation?(ctx) do
 14 |     Map.get(ctx.assigns, :back_propagation) == true ||
 15 |     Context.find_assign(ctx, [:back_propagation]) == true
 16 |   end
 17 | 
 18 |   @doc """
 19 |   Returns true if the current node is compiled for the training graph
 20 |   """
 21 |   def training?(ctx) do
 22 |     Map.get(ctx.assigns, :training) == true ||
 23 |     Context.find_assign(ctx, [:training]) == true
 24 |   end
 25 | 
 26 |   @doc """
 27 |   Includes the PTX .version directive
 28 |   """
 29 |   def version(vsn \\ "5.0")
 30 |   def version(vsn) when is_bitstring(vsn) do
 31 |     ".version #{vsn}\n"
 32 |   end
 33 |   def version({major, minor}) when is_integer(major) and is_integer(minor) do
 34 |     ".version #{major}.#{minor}\n"
 35 |   end
 36 |   def version(nil) do
 37 |     ".version 5.0\n"
 38 |   end
 39 |   def version(vsn) do
 40 |     Logger.warn("Invalid PTX version specified: `#{inspect vsn}`")
 41 |     ""
 42 |   end
 43 | 
 44 |   @doc """
 45 |   Includes the PTX .target directive
 46 |   """
 47 |   @targets ~w(sm_60 sm_61 sm_50 sm_52 sm_53 sm_30 sm_32 sm_35 sm_37 sm_20
 48 |               sm_10 sm_11 sm_12 sm_13 texmode_unified texmode_independent
 49 |               debug map_f64_to_f32)
 50 |   def target(ctx, tgt \\ nil)
 51 |   def target(ctx, tgt) when is_list(tgt) do
 52 |     tgt = tgt
 53 |           |> Enum.map(&sanitize_target(ctx, &1))
 54 |           |> Enum.reject(&is_nil/1)
 55 |           |> Enum.uniq()
 56 |           |> Enum.join(", ")
 57 |     ".target #{tgt}\n"
 58 |   end
 59 |   def target(ctx, tgt) do
 60 |     tgt = sanitize_target(ctx, tgt)
 61 |     ".target #{tgt}\n"
 62 |   end
 63 |   defp sanitize_target(_ctx, tgt) when is_bitstring(tgt) and tgt in @targets do
 64 |     tgt
 65 |   end
 66 |   defp sanitize_target(_ctx, nil) do
 67 |     # TODO: when we get the target from the GPU (sm_35 for example) the
 68 |     # following compilation error occurred:
 69 |     #
 70 |     #     SM version specified by .target is higher than default SM
 71 |     #     version assumed
 72 |     #
 73 |     # We need to find a way to compile with the detected targets
 74 |     {major, minor} = {2, 0} #ctx.env.gpu_info[:compute_capability]
 75 |     "sm_#{major}#{minor}"
 76 |   end
 77 |   defp sanitize_target(ctx, tgt) when is_atom(tgt) do
 78 |     sanitize_target(ctx, "#{tgt}")
 79 |   end
 80 |   defp sanitize_target(ctx, {major,
minor}) when is_integer(major) and is_integer(minor) do 81 | sanitize_target(ctx, "sm_#{major}#{minor}") 82 | end 83 | defp sanitize_target(ctx, tgt) when not is_nil(tgt) do 84 | default = sanitize_target(ctx, nil) 85 | Logger.warn("Invalid PTX target specified: `#{inspect tgt}`") 86 | Logger.warn("Default PTX target `#{default}` will be used") 87 | default 88 | end 89 | 90 | @doc """ 91 | Includes PTX .address_size directive 92 | """ 93 | def address_size(ctx, size \\ nil) 94 | def address_size(_ctx, size) when size in [32, 64] do 95 | ".address_size #{size}\n" 96 | end 97 | def address_size(ctx, nil) do 98 | size = ctx.env.gpu_info[:global_memory_bus_width] 99 | ".address_size #{size}\n" 100 | end 101 | def address_size(ctx, size) do 102 | default = ctx.env.gpu_info[:global_memory_bus_width] 103 | Logger.warn("Invalid PTX address size specified: `#{inspect size}`") 104 | Logger.warn("Default PTX address size `#{default}` will be used") 105 | ".address_size #{default}\n" 106 | end 107 | 108 | @doc """ 109 | Includes PTX directives header that includes .version, .target and 110 | .address_size directives 111 | """ 112 | def header(ctx, vsn \\ nil, tgt \\ nil, size \\ nil) do 113 | version(vsn) <> target(ctx, tgt) <> address_size(ctx, size) 114 | end 115 | 116 | @doc """ 117 | Returns offset of variable in specified memory block 118 | """ 119 | def offset(ctx, memory, var) do 120 | #IO.inspect({memory, var}) 121 | shape = Context.find_assign(ctx, [:memory, memory], ctx.path, &has_var?(&1, var)) 122 | shape = with nil <- shape do 123 | get_in(ctx.assigns, [:memory, memory]) 124 | end# |> IO.inspect 125 | with nil <- Memory.offset(shape, var) do 126 | Logger.warn("Can't find offset for `#{inspect var}` in memory `#{memory}`") 127 | nil 128 | end# |> IO.inspect(label: :OFFSET) 129 | end 130 | 131 | @doc """ 132 | Returns offset of variable in shared memory block 133 | """ 134 | def shared_offset(ctx, var) do 135 | #IO.inspect(ctx) 136 | case Template.Helpers.var(ctx, :layer) do 137 | nil -> raise CompileError, description: "Layer variable is not defined" 138 | layer -> offset(ctx, :shared, [var, layer]) 139 | end 140 | end 141 | 142 | @doc """ 143 | Returns offset of specified pin 144 | """ 145 | def pin_offset(ctx, var) do 146 | offset(ctx, :pins, var) 147 | end 148 | 149 | @doc """ 150 | Defines PTX kernel function 151 | """ 152 | def kernel(ctx, name, body, opts \\ []) do 153 | params = [{:pins, :u64, [ptr: true]}] ++ Keyword.get(opts, :args, []) 154 | params = params |> Enum.map(¶m/1) |> Enum.join(", ") 155 | ".visible .entry #{current_node_id(ctx)}__#{name} (#{params}) {\n" <> 156 | body <> 157 | "\n}" 158 | end 159 | 160 | @doc """ 161 | Includes specified include-module 162 | """ 163 | def include(ctx, module, part \\ :body, opts \\ []) do 164 | {part, opts} = case part do 165 | opts when is_list(opts) -> {:body, opts} 166 | part -> {part, opts} 167 | end 168 | with {:module, _} <- Code.ensure_loaded(module) do 169 | case function_exported?(module, :__ptx__, 2) do 170 | true -> module.__ptx__(part, Keyword.put(opts, :ctx, ctx)) 171 | _ -> "" 172 | end 173 | else 174 | _ -> raise CompileError, description: "Couldn't compile include module #{module}" 175 | end 176 | end 177 | 178 | def param({name, type, opts}) do 179 | space = opts 180 | |> Keyword.take(~w(const global local shared)a) 181 | |> Enum.reduce([], fn 182 | {name, true}, [] -> [".#{name}"] 183 | _, acc -> acc 184 | end) 185 | align = opts 186 | |> Keyword.take(~w(align)a) 187 | |> Enum.reduce([], fn 188 | {:align, x}, _ when 
band(x, x - 1) == 0 -> [".align #{x}"] # the alignment must be a power of two
189 |       _, acc -> acc
190 |     end)
191 |     param = [".param", ".#{type}"] ++
192 |             (if Keyword.get(opts, :ptr) == true, do: [".ptr"], else: []) ++
193 |             space ++
194 |             align ++
195 |             ["#{name}"]
196 |     param |> Enum.join(" ")
197 |   end
198 | 
199 |   defmacro back_propagation?() do
200 |     quote do
201 |       back_propagation?(var!(ctx))
202 |     end
203 |   end
204 | 
205 |   defmacro training?() do
206 |     quote do
207 |       training?(var!(ctx))
208 |     end
209 |   end
210 | 
211 |   defmacro offset(memory, var) do
212 |     quote do
213 |       offset(var!(ctx), unquote(memory), unquote(var))
214 |     end
215 |   end
216 | 
217 |   defmacro shared_offset(var) do
218 |     quote do
219 |       shared_offset(var!(ctx), unquote(var))
220 |     end
221 |   end
222 | 
223 |   defmacro pin_offset(var) do
224 |     quote do
225 |       pin_offset(var!(ctx), unquote(var))
226 |     end
227 |   end
228 | 
229 |   defmacro defkernel(ctx, name, args, opts) do
230 |     body = Keyword.get(opts, :do)
231 |     args = args
232 |     |> Enum.map(&parse_arg/1)
233 |     |> Enum.filter(&is_tuple/1)
234 |     |> Macro.escape
235 |     quote do
236 |       kernel(unquote(ctx), unquote(name), unquote(body), args: unquote(args))
237 |     end
238 |   end
239 |   defmacro defkernel(ctx, name, opts) do
240 |     body = Keyword.get(opts, :do)
241 |     quote do
242 |       kernel(unquote(ctx), unquote(name), unquote(body))
243 |     end
244 |   end
245 | 
246 |   defp has_var?(%Memory{} = memory, [key | path]) do
247 |     with %{} = map <- Memory.get(memory, key) do
248 |       get_in(map, path)
249 |     else
250 |       _ -> false
251 |     end
252 |   end
253 |   defp has_var?(%Memory{} = memory, var) do
254 |     Memory.has_key?(memory, var)
255 |   end
256 |   defp has_var?(map, path) when is_list(path) do
257 |     get_in(map, path)
258 |   end
259 |   defp has_var?(map, var) do
260 |     Map.has_key?(map, var)
261 |   end
262 | 
263 |   defp current_node_id(ctx) do
264 |     Node.string_id(Map.get(Context.node(ctx) || %{}, :id))
265 |   end
266 | 
267 |   defp parse_arg(arg, opts \\ [])
268 |   defp parse_arg({name, type}, opts) when is_atom(type) do
269 |     {name, type, opts}
270 |   end
271 |   defp parse_arg({name, {{:., _, [{type, _, x}, opt]}, _, _}}, opts) when is_atom(x) do
272 |     {name, type, [{opt, true} | opts]}
273 |   end
274 |   defp parse_arg({name, {{:., _, [nested, opt]}, _, _}}, opts) do
275 |     parse_arg({name, nested}, [{opt, true} | opts])
276 |   end
277 |   defp parse_arg({name, {:-, _, [{{:., _, [nested, opt]}, _, _}, v]}}, opts) do
278 |     parse_arg({name, nested}, [{opt, v} | opts])
279 |   end
280 |   defp parse_arg({name, {type, _, _}}, opts) when is_atom(type) do
281 |     {name, type, opts}
282 |   end
283 |   defp parse_arg(_, _) do
284 |     nil
285 |   end
286 | end
287 | 
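In a PTX template the `defkernel` macro above lets a kernel declare extra parameters with a compact `name: type` or `name: type.modifier` syntax. A hypothetical template fragment (an editorial sketch; the kernel name and PTX body are placeholders, and the EEx embedding is assumed):

    <%= defkernel ctx, "scale", factor: f32, data: u64.ptr do %>
      // PTX instructions elided
    <% end %>

`parse_arg/2` turns the argument list into `[{:factor, :f32, []}, {:data, :u64, [ptr: true]}]`, and `kernel/4` renders a `.visible .entry` declaration whose parameter list always starts with the implicit `.param .u64 .ptr pins` argument.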
4 | """ 5 | use GenServer 6 | require Logger 7 | 8 | @term_call <<1>> 9 | @raw_call <<2>> 10 | @proxied_calls ~w(call call_raw)a 11 | 12 | @type error_tuple :: {:error, String.t} 13 | 14 | defmacro __using__(_opts) do 15 | quote do 16 | use GenServer 17 | require Logger 18 | 19 | defdelegate info(pid), to: unquote(__MODULE__) 20 | defdelegate info(pid, info), to: unquote(__MODULE__) 21 | defdelegate compile(pid, src), to: unquote(__MODULE__) 22 | defdelegate compile(pid, src, opts), to: unquote(__MODULE__) 23 | defdelegate module_load(pid, src), to: unquote(__MODULE__) 24 | defdelegate module_load(pid, src, opts), to: unquote(__MODULE__) 25 | defdelegate memory_load(pid, data), to: unquote(__MODULE__) 26 | defdelegate module_read(pid, handle), to: unquote(__MODULE__) 27 | defdelegate module_unload(pid, handle), to: unquote(__MODULE__) 28 | defdelegate module_share(pid, handle), to: unquote(__MODULE__) 29 | defdelegate run(pid, module, func, params), to: unquote(__MODULE__) 30 | defdelegate run(pid, module, func, block, params), to: unquote(__MODULE__) 31 | defdelegate run(pid, module, func, block, grid, params), to: unquote(__MODULE__) 32 | defdelegate stream(pid, module, batch), to: unquote(__MODULE__) 33 | defdelegate start_link(), to: unquote(__MODULE__) 34 | defdelegate start_link(opts), to: unquote(__MODULE__) 35 | 36 | def __init__(opts), do: {:ok, opts} 37 | def __handle_call__(_msg, _from, st), do: {:noreply, st} 38 | def __handle_cast__(_msg, st), do: {:noreply, st} 39 | def __handle_into__(_msg, st), do: {:noreply, st} 40 | 41 | def init(opts) do 42 | with {:ok, proxy_st} <- unquote(__MODULE__).init(opts), 43 | {:ok, st} <- __init__(opts) do 44 | {:ok, {st, proxy_st}} 45 | end 46 | end 47 | 48 | def handle_call({x, _, _} = msg, from, {st, proxy_st}) when x in unquote(@proxied_calls) do 49 | case unquote(__MODULE__).handle_call(msg, from, proxy_st) do 50 | {:reply, reply, proxy_st} -> {:reply, reply, {st, proxy_st}} 51 | {:reply, reply, proxy_st, timeout} -> {:reply, reply, {st, proxy_st}, timeout} 52 | {:noreply, proxy_st} -> {:noreply, {st, proxy_st}} 53 | {:noreply, proxy_st, timeout} -> {:noreply, {st, proxy_st}, timeout} 54 | {:stop, reason, reply, proxy_st} -> {:stop, reason, reply, {st, proxy_st}} 55 | {:stop, reason, proxy_st} -> {:stop, reason, {st, proxy_st}} 56 | end 57 | end 58 | def handle_call(msg, from, {st, proxy_st}) do 59 | case __handle_call__(msg, from, st) do 60 | {:reply, reply, st} -> {:reply, reply, {st, proxy_st}} 61 | {:reply, reply, st, timeout} -> {:reply, reply, {st, proxy_st}, timeout} 62 | {:noreply, st} -> {:noreply, {st, proxy_st}} 63 | {:noreply, st, timeout} -> {:noreply, {st, proxy_st}, timeout} 64 | {:stop, reason, reply, st} -> {:stop, reason, reply, {st, proxy_st}} 65 | {:stop, reason, st} -> {:stop, reason, {st, proxy_st}} 66 | end 67 | end 68 | 69 | def handle_cast(msg, {st, proxy_st}) do 70 | case __handle_cast__(msg, st) do 71 | {:noreply, st} -> {:noreply, {st, proxy_st}} 72 | {:noreply, st, timeout} -> {:noreply, {st, proxy_st}, timeout} 73 | {:stop, reason, st} -> {:stop, reason, {st, proxy_st}} 74 | end 75 | end 76 | 77 | def handle_info({_port, {:data, unquote(@term_call) <> _}} = msg, {st, proxy_st}) do 78 | case unquote(__MODULE__).handle_info(msg, proxy_st) do 79 | {:noreply, proxy_st} -> {:noreply, {st, proxy_st}} 80 | {:noreply, proxy_st, timeout} -> {:noreply, {st, proxy_st}, timeout} 81 | {:stop, reason, proxy_st} -> {:stop, reason, {st, proxy_st}} 82 | end 83 | end 84 | def handle_info(msg, {st, proxy_st}) do 85 | case 
 84 |       def handle_info(msg, {st, proxy_st}) do
 85 |         case __handle_info__(msg, st) do
 86 |           {:noreply, st} -> {:noreply, {st, proxy_st}}
 87 |           {:noreply, st, timeout} -> {:noreply, {st, proxy_st}, timeout}
 88 |           {:stop, reason, st} -> {:stop, reason, {st, proxy_st}}
 89 |         end
 90 |       end
 91 | 
 92 |       defoverridable __init__: 1, __handle_call__: 3, __handle_cast__: 2,
 93 |                      __handle_info__: 2
 94 |     end
 95 |   end
 96 | 
 97 |   def start_link(opts \\ []) do
 98 |     {name, opts} = Keyword.split(opts, ~w(name)a)
 99 |     GenServer.start_link(__MODULE__, opts, name)
100 |   end
101 | 
102 |   defmacro compile_error(msg) do
103 |     quote do
104 |       raise CompileError, description: unquote(msg)
105 |     end
106 |   end
107 | 
108 |   @doc """
109 |   Returns NVIDIA driver and CUDA library info
110 | 
111 |   The param can be omitted or must be one of:
112 | 
113 |   * device_count - get GPU device count
114 |   * driver_version - get CUDA driver version
115 |   * memory - get free and total GPU memory
116 |   * runtime_version - get CUDA runtime version
117 |   """
118 |   @spec info(pid, nil | :device_count | :driver_version | :memory | :runtime_version) :: any
119 |   def info(pid, info \\ nil) do
120 |     GenServer.call(pid, {:call, :info, info})
121 |   end
122 | 
123 |   def device_info(pid) do
124 |     GenServer.call(pid, {:call, :device_info, nil})
125 |   end
126 | 
127 |   def compile(pid, sources, opts \\ nil)
128 |   def compile(pid, {:file, file}, opts) when is_binary(file) do
129 |     with {:ok, source} <- File.read(file) do
130 |       GenServer.call(pid, {:call, :compile, {[source], opts}})
131 |     end
132 |   end
133 |   def compile(pid, {:files, files}, opts) when is_list(files) do
134 |     sources = Enum.reduce(files, {:ok, []}, fn
135 |       file, {:ok, sources} ->
136 |         with {:ok, source} <- File.read(file), do: {:ok, [source | sources]}
137 |       _, error ->
138 |         error
139 |     end)
140 |     with {:ok, sources} <- sources do
141 |       GenServer.call(pid, {:call, :compile, {sources, opts}})
142 |     end
143 |   end
144 |   def compile(pid, source, opts) when is_binary(source) do
145 |     GenServer.call(pid, {:call, :compile, {[source], opts}})
146 |   end
147 |   def compile(pid, sources, opts) when is_list(sources) do
148 |     GenServer.call(pid, {:call, :compile, {sources, opts}})
149 |   end
150 | 
151 |   def module_load(pid, src, opts \\ []) do
152 |     GenServer.call(pid, {:call, :module_load, {src, opts}})
153 |   end
154 | 
155 |   def memory_load(pid, data) when is_binary(data) do
156 |     GenServer.call(pid, {:call_raw, :memory_load, data})
157 |   end
158 |   def memory_load(pid, data) do
159 |     GenServer.call(pid, {:call, :memory_load, data})
160 |   end
161 | 
162 |   def memory_read(pid, handle) do
163 |     GenServer.call(pid, {:call, :memory_read, handle})
164 |   end
165 | 
166 |   def memory_unload(pid, handle) do
167 |     GenServer.call(pid, {:call, :memory_unload, handle})
168 |   end
169 | 
170 |   def memory_share(pid, handle) do
171 |     GenServer.call(pid, {:call, :memory_share, handle})
172 |   end
173 | 
174 |   def run(pid, module, func) do
175 |     GenServer.call(pid, {:call, :run, {module, func}})
176 |   end
177 |   def run(pid, module, func, params) do
178 |     GenServer.call(pid, {:call, :run, {module, func, params}})
179 |   end
180 |   def run(pid, module, func, block, params) do
181 |     GenServer.call(pid, {:call, :run, {module, func, block, params}})
182 |   end
183 |   # NOTE: replace the @lint attribute with whatever mechanism is proposed to
184 |   # supersede it, and activate Credo.Check.Refactor.FunctionArity
185 |   # in .credo.exs with the default max_arity (5)
186 |   # @lint {Credo.Check.Refactor.FunctionArity, false}
187 |   def run(pid, module, func, block, grid, params) do
188 |     GenServer.call(pid, {:call, :run, {module, func, block, grid, params}})
189 |   end
190 | 
191 |   def stream(pid, module, batch) do
192 |     GenServer.call(pid, {:call, :stream, {module, batch}})
193 |   end
194 | 
195 |   def init(opts) do
196 |     cmd = case Keyword.get(opts, :port_bin) do
197 |       nil -> Application.app_dir(:cuda, Path.join(~w(priv cuda_driver_port)))
198 |       port -> port
199 |     end
200 |     cmd = case Keyword.get(opts, :device) do
201 |       nil -> cmd
202 |       device -> "#{cmd} #{device}"
203 |     end
204 |     port = Port.open({:spawn, cmd}, [:binary, :nouse_stdio, packet: 4])
205 |     {:ok, port}
206 |   end
207 | 
208 |   def handle_call({:call, func, arg}, _from, port) do
209 |     Port.command(port, @term_call <> :erlang.term_to_binary({func, arg}))
210 |     wait_reply(port)
211 |   end
212 | 
213 |   def handle_call({:call_raw, func, arg}, _from, port) do
214 |     func = "#{func}"
215 |     size = byte_size(func)
216 |     Port.command(port, @raw_call <> <<size::32>> <> func <> arg) # 4-byte length prefix for the function name
217 |     wait_reply(port)
218 |   end
219 | 
220 |   defp wait_reply(port) do
221 |     receive do
222 |       {^port, {:data, @term_call <> data}} ->
223 |         {:reply, :erlang.binary_to_term(data), port}
224 |       {^port, {:data, @raw_call <> data}} ->
225 |         {:reply, {:ok, data}, port}
226 |       _ ->
227 |         {:reply, {:error, "Port communication error"}, port}
228 |     end
229 |   end
230 | 
231 |   def handle_info({_port, {:data, @term_call <> data}}, port) do
232 |     msg = :erlang.binary_to_term(data)
233 |     Logger.warn("Unexpected message from CUDA port: #{inspect msg}")
234 |     {:noreply, port}
235 |   end
236 | end
237 | 
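A minimal end-to-end usage sketch of the API above (an editorial addition, not part of the source; the kernel names, file path and launch parameters are placeholders, and the exact shape of the reply terms is assumed):

    {:ok, cuda} = Cuda.start_link()
    {:ok, ptx}  = Cuda.compile(cuda, {:file, "kernels/example.cu"})
    {:ok, mod}  = Cuda.module_load(cuda, ptx)
    {:ok, mem}  = Cuda.memory_load(cuda, <<0::unit(8)-size(16)>>)
    Cuda.run(cuda, mod, "example_kernel", {4, 1, 1}, [{:memory, mem}])
    {:ok, bin}  = Cuda.memory_read(cuda, mem)
    # a stream batch mixes runs, events and waits; a list of batches (as parsed
    # by c_src/commands.cpp) is executed as parallel CUDA streams
    Cuda.stream(cuda, mod, [[{:run, {"k1", 4, [{:memory, mem}]}}, {:event, "e"}],
                            [{:wait, "e"}, {:run, {"k2", 4, [{:memory, mem}]}}]])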
--------------------------------------------------------------------------------
/lib/cuda/graph.ex:
--------------------------------------------------------------------------------
  1 | defmodule Cuda.Graph do
  2 |   @moduledoc """
  3 |   Represents an evaluation graph
  4 |   """
  5 |   require Cuda
  6 |   import Cuda, only: [compile_error: 1]
  7 | 
  8 |   alias Cuda.Graph.Pin
  9 |   alias Cuda.Graph.Node
 10 |   alias Cuda.Graph.GraphProto
 11 |   alias Cuda.Graph.NodeProto
 12 | 
 13 |   @type id :: String.t | atom | non_neg_integer
 14 |   @type link_spec :: {id, id}
 15 |   @type link :: {link_spec, link_spec}
 16 | 
 17 |   @type t :: %__MODULE__{
 18 |     id: id,
 19 |     module: module,
 20 |     type: Node.type,
 21 |     pins: [Pin.t],
 22 |     nodes: [Node.t],
 23 |     links: [link],
 24 |     assigns: map
 25 |   }
 26 | 
 27 |   @type dfs_action :: :enter | :move | :leave
 28 |   @type dfs_result :: {:ok | :error | atom, state :: any}
 29 |   @type dfs_callback :: (action :: dfs_action, arg :: any, state :: any -> dfs_result)
 30 | 
 31 |   @callback __graph__(graph :: t) :: t
 32 |   @callback __child_options__(id :: id, module :: atom, graph :: t) :: Node.options
 33 | 
 34 |   @derive [NodeProto, GraphProto]
 35 |   defstruct [:id, :module, type: :graph, pins: [], nodes: [], links: [],
 36 |              assigns: %{}]
 37 | 
 38 |   import Node, only: [input_pin_types: 0, output_pin_types: 0]
 39 | 
 40 |   @self :__self__
 41 |   @exports [add: 3, add: 4,
 42 |             chain: 3, chain: 4,
 43 |             close: 1, link: 3]
 44 | 
 45 |   defmacro __using__(_opts) do
 46 |     quote do
 47 |       use Cuda.Graph.Node
 48 |       import unquote(__MODULE__), only: unquote(@exports)
 49 |       @behaviour unquote(__MODULE__)
 50 |       def __type__(_assigns), do: :graph
 51 |       def __proto__(), do: unquote(__MODULE__)
 52 |       def __child_options__(_id, _module, _graph), do: []
 53 | 
 54 |       defoverridable __child_options__: 3, __type__: 1
 55 |     end
 56 |   end
 57 | 
 58 |   def add(%__MODULE__{} = graph, id, module, opts \\ []) do
 59 |     with {:module, module} <- Code.ensure_loaded(module) do
 60 |       if id == graph.id do
 61 |         compile_error("The id `#{id}` of newly added node is already used by graph")
 62 |       end
 63 |       if GraphProto.node(graph, id) != nil do
 64 |         compile_error("Node with id `#{id}` already exists in the graph")
 65 |       end
 66 |       proto = struct(Node.proto(module))
 67 |       opts = case function_exported?(graph.module, :__child_options__, 3) do
 68 |         true -> id
 69 |                 |> graph.module.__child_options__(module, graph)
 70 |                 |> Keyword.merge(opts)
 71 |         _ -> opts
 72 |       end
 73 |       env = graph |> Map.get(:assigns, %{}) |> Map.get(:env, %Cuda.Env{})
 74 |       GraphProto.add(graph, Cuda.Graph.Factory.new(proto, id, module, opts, env))
 75 |     else
 76 |       _ -> compile_error("Graph module #{module} could not be loaded")
 77 |     end
 78 |   end
 79 | 
 80 |   def chain(graph, id, module, opts \\ [])
 81 |   def chain(%__MODULE__{nodes: []} = graph, id, module, opts) do
 82 |     src_pin = case NodeProto.pins(graph, input_pin_types()) do
 83 |       [src_pin] -> src_pin
 84 |       _ -> compile_error("Chain allowed only for graphs with single input")
 85 |     end
 86 |     with %{nodes: [node]} = graph <- add(graph, id, module, opts) do
 87 |       dst_pin = case NodeProto.pins(node, input_pin_types()) do
 88 |         [dst_pin] -> dst_pin
 89 |         _ -> compile_error("Chain can only be applied to nodes with single input")
 90 |       end
 91 |       link(graph, src_pin.id, {id, dst_pin.id})
 92 |     end
 93 |   end
 94 |   def chain(%__MODULE__{nodes: [src_node | _]} = graph, id, module, opts) do
 95 |     src_pin = case NodeProto.pins(src_node, output_pin_types()) do
 96 |       [src_pin] -> src_pin
 97 |       _ -> compile_error("Chain can only be applied after nodes with single output")
 98 |     end
 99 |     with %{nodes: [dst_node | _]} = graph <- add(graph, id, module, opts) do
100 |       dst_pin = case NodeProto.pins(dst_node, input_pin_types()) do
101 |         [dst_pin] -> dst_pin
102 |         _ -> compile_error("Chain can only be applied to nodes with single input")
103 |       end
104 |       link(graph, {src_node.id, src_pin.id}, {dst_node.id, dst_pin.id})
105 |     end
106 |   end
107 | 
108 |   def close(%__MODULE__{nodes: []} = graph) do
109 |     src_pin = case NodeProto.pins(graph, input_pin_types()) do
110 |       [src_pin] -> src_pin
111 |       _ -> compile_error("Close allowed only for graphs with single input")
112 |     end
113 |     dst_pin = case NodeProto.pins(graph, output_pin_types()) do
114 |       [dst_pin] -> dst_pin
115 |       _ -> compile_error("Close allowed only for graphs with single output")
116 |     end
117 |     link(graph, src_pin.id, dst_pin.id)
118 |   end
119 |   def close(%__MODULE__{nodes: [node | _]} = graph) do
120 |     src_pin = case NodeProto.pins(node, output_pin_types()) do
121 |       [src_pin] -> src_pin
122 |       _ -> compile_error("Close can only be applied after nodes with single output")
123 |     end
124 |     dst_pin = case NodeProto.pins(graph, output_pin_types()) do
125 |       [dst_pin] -> dst_pin
126 |       _ -> compile_error("Close allowed only for graphs with single output")
127 |     end
128 |     link(graph, {node.id, src_pin.id}, dst_pin.id)
129 |   end
130 | 
131 |   def link(%__MODULE__{links: links} = graph, {sn, sp} = src, {dn, dp} = dst) do
132 |     # node to node connection
133 |     with {:src, %{} = src_node} <- {:src, GraphProto.node(graph, sn)},
134 |          {:dst, %{} = dst_node} <- {:dst, GraphProto.node(graph, dn)} do
135 |       #IO.inspect({src, dst})
136 |       src_pin = assert_pin_type(src_node, sp, output_pin_types())
137 |       dst_pin = assert_pin_type(dst_node, dp, input_pin_types())
138 |       assert_pin_data_type(src_pin, dst_pin)
139 |       assert_pin_layout(src_pin, dst_pin)
140 |       %{graph | links: [{src, dst} | links]}
141 |     else
142 |       {:src, _} -> compile_error("Source node `#{sn}` not found")
143 |       {:dst, _} -> compile_error("Destination node `#{dn}` not found")
144 |     end
145 |   end
146 | 
147 |   def link(%__MODULE__{links: links} = graph, src, {dn, dp} = dst) do
148 |     # input to node connection
149 |     with %{} = dst_node <- GraphProto.node(graph, dn) do
150 |       #IO.inspect({graph.id, src, dst})
151 |       src_pin = assert_pin_type(graph, src, input_pin_types())
152 |       dst_pin = assert_pin_type(dst_node, dp, input_pin_types())
153 |       assert_pin_data_type(src_pin, dst_pin)
154 |       assert_pin_layout(src_pin, dst_pin)
155 |       %{graph | links: [{{@self, src}, dst} | links]}
156 |     else
157 |       _ -> compile_error("Destination node `#{dn}` not found")
158 |     end
159 |   end
160 | 
161 |   def link(%__MODULE__{links: links} = graph, {sn, sp} = src, dst) do
162 |     # node to output connection
163 |     with %{} = src_node <- GraphProto.node(graph, sn) do
164 |       src_pin = assert_pin_type(src_node, sp, output_pin_types())
165 |       dst_pin = assert_pin_type(graph, dst, output_pin_types())
166 |       assert_pin_data_type(src_pin, dst_pin)
167 |       assert_pin_layout(src_pin, dst_pin)
168 |       %{graph | links: [{src, {@self, dst}} | links]}
169 |     else
170 |       _ -> compile_error("Source node `#{sn}` not found")
171 |     end
172 |   end
173 | 
174 |   def link(%__MODULE__{links: links} = graph, src, dst) do
175 |     # input to output connection
176 |     src_pin = assert_pin_type(graph, src, input_pin_types())
177 |     dst_pin = assert_pin_type(graph, dst, output_pin_types())
178 |     assert_pin_data_type(src_pin, dst_pin)
179 |     assert_pin_layout(src_pin, dst_pin)
180 |     %{graph | links: [{{@self, src}, {@self, dst}} | links]}
181 |   end
182 | 
183 |   defp assert_pin_type(node, pin_name, types) do
184 |     with %Pin{} = pin <- NodeProto.pin(node, pin_name) do
185 |       if pin.type not in types do
186 |         types = types |> Enum.map(& "#{&1}") |> Enum.join(" or ")
187 |         id = Node.string_id(node.id)
188 |         compile_error("Pin `#{pin_name}` of node `#{id}` has a wrong" <>
The #{types} types are expected.") 190 | end 191 | pin 192 | else 193 | _ -> 194 | id = Node.string_id(node.id) 195 | compile_error("Pin `#{pin_name}` not found in node `#{id}`") 196 | end 197 | end 198 | 199 | # TODO: move pin type checking logic into Cuda.Graph.Pin 200 | defp assert_pin_data_type(%{data_type: {t1, a1}} = p1, %{data_type: {t2, a2}} = p2) do 201 | s1 = Pin.data_size(a1) 202 | s2 = Pin.data_size(a2) 203 | if t1 != t2 or s1 != s2 do 204 | compile_error("The pins #{p1.id} and #{p2.id} has different types") 205 | end 206 | end 207 | # nil data_type assumes auto-detection 208 | defp assert_pin_data_type(%{data_type: nil}, _), do: true 209 | defp assert_pin_data_type(_, %{data_type: nil}), do: true 210 | defp assert_pin_data_type(%{data_type: t1} = p1, %{data_type: t2} = p2) do 211 | if t1 != t2 do 212 | compile_error("The pins #{p1.id} and #{p2.id} has different types") 213 | end 214 | end 215 | 216 | defp assert_pin_layout(%{id: id1, layout: l1}, %{id: id2, layout: l2}) when l1 != l2 do 217 | compile_error("The pins #{id1} and #{id2} has different layout") 218 | end 219 | defp assert_pin_layout(_, _), do: true 220 | end 221 | 222 | defimpl Cuda.Graph.Factory, for: Cuda.Graph do 223 | require Cuda 224 | alias Cuda.Graph.{Node, NodeProto} 225 | 226 | @doc """ 227 | Creates new graph node 228 | """ 229 | def new(_, id, module, opts, env) do 230 | with {:module, module} <- Code.ensure_loaded(module) do 231 | proto = Node.proto(module) 232 | graph = %Node{} 233 | |> Cuda.Graph.Factory.new(id, module, opts, env) 234 | |> Map.from_struct 235 | graph = struct(proto, graph) 236 | graph = case function_exported?(module, :__graph__, 1) do 237 | true -> module.__graph__(graph) 238 | _ -> graph 239 | end 240 | graph |> set_pin_shapes() 241 | else 242 | _ -> Cuda.compile_error("Node module #{module} could not be loaded") 243 | end 244 | end 245 | 246 | defp set_pin_shapes(%{pins: pins} = graph) do 247 | %{graph | pins: pins |> Enum.map(& set_pin_shape(graph, &1))} 248 | end 249 | 250 | defp set_pin_shape(graph, %{alias: {:group, a}, data_type: nil} = pin) when not is_nil(a) do 251 | pins = graph.nodes 252 | |> Enum.map(fn node -> 253 | pins = node.pins 254 | |> Enum.filter(& &1.group == a) 255 | |> Enum.map(& {&1.id, &1.data_type}) 256 | |> Enum.into(%{}) 257 | {node.id, pins} 258 | end) 259 | |> Enum.filter(fn 260 | {_, []} -> false 261 | _ -> true 262 | end) 263 | |> Enum.into(%{}) 264 | case pins do 265 | [] -> Cuda.compile_error("Invalid pin alias group: #{inspect a}") 266 | pins -> %{pin | data_type: pins} 267 | end 268 | end 269 | defp set_pin_shape(graph, %{alias: a, data_type: nil} = pin) when not is_nil(a) do 270 | case NodeProto.pin(graph, a) do 271 | nil -> Cuda.compile_error("Invalid pin alias: #{inspect a}") 272 | aliases -> %{pin | data_type: aliases.data_type} 273 | end 274 | end 275 | defp set_pin_shape(_, pin), do: pin 276 | end 277 | -------------------------------------------------------------------------------- /c_src/driver.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "common.h" 3 | #include "driver.h" 4 | 5 | Driver::Driver(int deviceNo) { 6 | DEBUG("Enter Driver constructor for device: " << deviceNo); 7 | CUresult result = CUDA_SUCCESS; 8 | result = cuDeviceGet(&device, deviceNo); 9 | if (result != CUDA_SUCCESS) throw DriverError(result); 10 | result = cuCtxCreate(&context, 0, device); 11 | DEBUG("Context created: " << result); 12 | if (result != CUDA_SUCCESS) throw DriverError(result); 13 | 
DEBUG("Driver initialized for device #" << deviceNo); 14 | } 15 | 16 | Driver::~Driver() { 17 | DEBUG("Driver destroyed"); 18 | for (auto module = modules.begin(); module != modules.end(); ++module) { 19 | cuModuleUnload(module->second); 20 | } 21 | for (auto mem = memory.begin(); mem != memory.end(); ++mem) { 22 | delete mem->second; 23 | } 24 | cuCtxDestroy(context); 25 | } 26 | 27 | int Driver::Compile(std::list sources, LinkerOptions &options) { 28 | Linker linker(options); 29 | linker.Run(sources); 30 | 31 | CUmodule module; 32 | auto result = cuModuleLoadData(&module, linker.cubin); 33 | if (result != CUDA_SUCCESS) throw DriverError(result); 34 | int moduleNo = modules.size() + 1; 35 | modules.insert(std::pair(moduleNo, module)); 36 | 37 | return moduleNo; 38 | } 39 | 40 | int Driver::LoadModule(std::string cubin, LinkerOptions &options) { 41 | Linker linker(options); 42 | CUmodule module; 43 | auto result = cuModuleLoadDataEx(&module, cubin.c_str(), linker.OptionsSize(), 44 | linker.OptionsKeys(), linker.OptionsValues()); 45 | if (result != CUDA_SUCCESS) throw DriverError(result); 46 | int moduleNo = modules.size() + 1; 47 | modules.insert(std::pair(moduleNo, module)); 48 | DEBUG("Load module: " << result); 49 | return moduleNo; 50 | } 51 | 52 | CUmodule Driver::GetModule(int id) { 53 | auto module = modules.find(id); 54 | if (module == modules.end()) return NULL; 55 | return module->second; 56 | } 57 | 58 | int Driver::LoadMemory(const void *src, size_t size) { 59 | DeviceMemory *mem = new DeviceMemory(src, size); 60 | int memNo = memory.size() + 1; 61 | memory.insert(std::pair(memNo, mem)); 62 | return memNo; 63 | } 64 | 65 | int Driver::LoadMemory(SharedMemory sharedMemory) { 66 | CUipcMemHandle handle; 67 | size_t size; 68 | std::tie(handle, size) = sharedMemory; 69 | auto mem = new DeviceMemory(handle, size); 70 | int memNo = memory.size() + 1; 71 | memory.insert(std::pair(memNo, mem)); 72 | return memNo; 73 | } 74 | 75 | void Driver::UnloadMemory(int id) { 76 | auto mem = memory.find(id); 77 | if (mem == memory.end()) throw StringError("Invalid memory handle"); 78 | delete mem->second; 79 | memory.erase(id); 80 | } 81 | 82 | void Driver::ReadMemory(int id, void *dst, int size) { 83 | auto mem = memory.find(id); 84 | if (mem == memory.end()) throw StringError("Invalid memory handle"); 85 | mem->second->Read(dst, size); 86 | } 87 | 88 | SharedMemory Driver::ShareMemory(int id) { 89 | auto mem = memory.find(id); 90 | if (mem == memory.end()) throw StringError("Invalid memory handle"); 91 | CUipcMemHandle handle; 92 | auto result = cuIpcGetMemHandle(&handle, mem->second->GetPtr()); 93 | if (result != CUDA_SUCCESS) throw DriverError(result); 94 | return std::make_tuple(handle, mem->second->GetSize()); 95 | } 96 | 97 | int Driver::GetMemorySize(int id) { 98 | auto mem = memory.find(id); 99 | if (mem == memory.end()) return -1; 100 | return mem->second->GetSize(); 101 | } 102 | 103 | DeviceMemory *Driver::GetMemory(int id) { 104 | auto mem = memory.find(id); 105 | if (mem == memory.end()) return NULL; 106 | return mem->second; 107 | } 108 | 109 | void Driver::Run(int moduleNo, RunParameters ¶ms, std::shared_ptr &args) { 110 | auto module = modules.find(moduleNo); 111 | if (module == modules.end()) throw StringError("Invalid module handle"); 112 | 113 | CUfunction func; 114 | std::string funcName; 115 | int gx, gy, gz, bx, by, bz; 116 | 117 | std::tie(funcName, gx, gy, gz, bx, by, bz) = params; 118 | auto result = cuModuleGetFunction(&func, module->second, funcName.c_str()); 119 | if 
119 |   if (result != CUDA_SUCCESS) throw DriverError(result);
120 |   // void **paramsPtr = args.empty() ? NULL : args.data();
121 | 
122 |   result = cuLaunchKernel(func, gx, gy, gz, bx, by, bz, 0, 0, args->GetPtr(), 0);
123 |   if (result != CUDA_SUCCESS) throw DriverError(result, "Driver:execution");
124 | }
125 | 
126 | void Driver::Stream(int moduleNo, std::vector<std::tuple<RunParameters, std::shared_ptr<RunArguments>>> &batch) {
127 |   auto module = modules.find(moduleNo);
128 |   if (module == modules.end()) throw StringError("Invalid module handle");
129 | 
130 |   CUstream stream;
131 |   CUfunction func;
132 |   CUresult result;
133 |   std::string funcName;
134 |   int gx, gy, gz, bx, by, bz;
135 | 
136 |   result = cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING);
137 |   if (result != CUDA_SUCCESS) throw DriverError(result, "Driver:stream_create");
138 | 
139 |   for (auto it = batch.begin(); it != batch.end(); ++it) {
140 |     std::shared_ptr<RunArguments> args;
141 |     RunParameters params;
142 | 
143 |     std::tie(params, args) = *it;
144 |     std::tie(funcName, gx, gy, gz, bx, by, bz) = params;
145 | 
146 |     result = cuModuleGetFunction(&func, module->second, funcName.c_str());
147 |     if (result != CUDA_SUCCESS) throw DriverError(result);
148 | 
149 |     DEBUG("Driver::Stream: launching " << funcName);
150 |     result = cuLaunchKernel(func, gx, gy, gz, bx, by, bz, 0, stream, args->GetPtr(), 0);
151 |     DEBUG("Driver::Stream: cuLaunchKernel returned " << result);
152 |     if (result != CUDA_SUCCESS) throw DriverError(result, "Driver:execution");
153 |     DEBUG("Driver::Stream: kernel queued");
154 |   }
155 | 
156 |   DEBUG("Driver::Stream: waiting for stream to finish");
157 |   result = cuStreamSynchronize(stream);
158 |   if (result != CUDA_SUCCESS) throw DriverError(result, "Driver:stream_wait");
159 |   result = cuStreamDestroy(stream);
160 |   if (result != CUDA_SUCCESS) throw DriverError(result, "Driver:stream_free");
161 | }
162 | 
163 | template <> DeviceMemory *Driver::Unpack<DeviceMemory *>(ETERM *value) {
164 |   if (!ERL_IS_TUPLE(value) || erl_size(value) != 2) {
165 |     throw StringError("Invalid memory handle");
166 |   }
167 |   auto a = erl_element(1, value);
168 |   auto v = erl_element(2, value);
169 |   if (!ERL_IS_ATOM(a) || !ATOM_EQ(a, "memory")) {
170 |     throw StringError("Invalid memory handle");
171 |   }
172 |   auto mem = GetMemory(Get<int>(v));
173 |   if (!mem) throw StringError("Invalid memory handle");
174 |   return mem;
175 | }
176 | 
177 | template <> SharedMemory Driver::Unpack<SharedMemory>(ETERM *value) {
178 |   if (!ERL_IS_TUPLE(value) || erl_size(value) != 2) {
179 |     throw StringError("Invalid shared memory handle");
180 |   }
181 |   auto a = erl_element(1, value);
182 |   auto v = erl_element(2, value);
183 |   if (!ERL_IS_ATOM(a) || !ATOM_EQ(a, "shared_memory")) {
184 |     throw StringError("Invalid shared memory handle");
185 |   }
186 |   if (!ERL_IS_TUPLE(v)) throw StringError("Invalid shared memory handle");
187 |   auto h = erl_element(1, v);
188 |   auto size = Get<size_t>(erl_element(2, v));
189 |   if (!ERL_IS_BINARY(h) || ERL_BIN_SIZE(h) != sizeof(CUipcMemHandle)) {
190 |     throw StringError("Invalid shared memory handle");
191 |   }
192 |   CUipcMemHandle *handle = (CUipcMemHandle *)ERL_BIN_PTR(h);
193 |   return std::make_tuple(*handle, size);
194 | }
195 | 
196 | template <> CUmodule Driver::Unpack<CUmodule>(ETERM *value) {
197 |   if (!ERL_IS_TUPLE(value) || erl_size(value) != 2) {
198 |     throw StringError("Invalid module handle");
199 |   }
200 |   auto a = erl_element(1, value);
201 |   auto v = erl_element(2, value);
202 |   if (!ERL_IS_ATOM(a) || !ATOM_EQ(a, "module")) {
203 |     throw StringError("Invalid module handle");
204 |   }
205 |   auto module = modules.find(Get<int>(v));
206 |   if (module == modules.end()) throw StringError("Invalid module handle");
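  // a {module, N} handle term resolves to the CUmodule registered earlier by
  // Compile() / LoadModule()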
207 |   return module->second;
208 | }
209 | 
210 | ETERM *Driver::PackMemory(int idx) {
211 |   return FORMAT("{~a,~i}", C_STR("memory"), idx);
212 | }
213 | 
214 | ETERM *Driver::PackMemory(SharedMemory mem) {
215 |   CUipcMemHandle handle;
216 |   size_t size;
217 |   std::tie(handle, size) = mem;
218 |   return FORMAT("{~a,{~w,~i}}", C_STR("shared_memory"),
219 |                 erl_mk_binary((char *)&handle, sizeof(CUipcMemHandle)),
220 |                 size);
221 | }
222 | 
223 | ETERM *Driver::PackModule(int idx) {
224 |   return FORMAT("{~a,~i}", C_STR("module"), idx);
225 | }
226 | 
227 | ETERM *CompileError::AsTerm() {
228 |   const char *buf;
229 |   auto err = cuGetErrorString(code, &buf);
230 |   ETERM *errStr = err == CUDA_SUCCESS ?
231 |                   erl_mk_binary(buf, strlen(buf)) :
232 |                   MAKE_BINARY("");
233 |   return FORMAT("{~a,~w,~w,~w}", ERROR_STR,
234 |                 errStr,
235 |                 erl_mk_binary(infoLog.c_str(), infoLog.size()),
236 |                 erl_mk_binary(errorLog.c_str(), errorLog.size())
237 |   );
238 | }
239 | 
240 | Linker::Linker(LinkerOptions &options) {
241 |   if (options.threadsPerBlock >= 0 && options.target >= 0) {
242 |     throw StringError("The threads_per_block linker option cannot be used together with the target option");
243 |   }
244 | 
245 |   int infoSize = options.infoSize > 0 ? options.infoSize : LINKER_BUFFER_SIZE;
246 |   int errorSize = options.errorSize > 0 ? options.errorSize : LINKER_BUFFER_SIZE;
247 |   infoLog = new char[infoSize];
248 |   errorLog = new char[errorSize];
249 | 
250 |   optKeys.push_back(CU_JIT_WALL_TIME);
251 |   optValues.push_back((void *)&walltime);
252 |   optKeys.push_back(CU_JIT_INFO_LOG_BUFFER);
253 |   optValues.push_back((void *)infoLog);
254 |   optKeys.push_back(CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES);
255 |   optValues.push_back((void *)(intptr_t)infoSize);
256 |   optKeys.push_back(CU_JIT_ERROR_LOG_BUFFER);
257 |   optValues.push_back((void *)errorLog);
258 |   optKeys.push_back(CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES);
259 |   optValues.push_back((void *)(intptr_t)errorSize);
260 |   if (options.maxRegisters >= 0) {
261 |     optKeys.push_back(CU_JIT_MAX_REGISTERS);
262 |     optValues.push_back((void *)&options.maxRegisters);
263 |   }
264 |   if (options.optimizationLevel >= 0) {
265 |     optKeys.push_back(CU_JIT_OPTIMIZATION_LEVEL);
266 |     optValues.push_back((void *)&options.optimizationLevel);
267 |   }
268 |   if (options.threadsPerBlock >= 0) {
269 |     threadsPerBlock = options.threadsPerBlock;
270 |     optKeys.push_back(CU_JIT_THREADS_PER_BLOCK);
271 |     optValues.push_back((void *)&threadsPerBlock);
272 |   } else if (options.target >= 0) {
273 |     optKeys.push_back(CU_JIT_TARGET);
274 |     optValues.push_back((void *)&options.target);
275 |   }
276 |   if (options.debug >= 0) {
277 |     optKeys.push_back(CU_JIT_GENERATE_DEBUG_INFO);
278 |     optValues.push_back((void *)&options.debug);
279 |   }
280 |   if (options.verbose >= 0) {
281 |     optKeys.push_back(CU_JIT_LOG_VERBOSE);
282 |     optValues.push_back((void *)&options.verbose);
283 |   }
284 |   DEBUG("Linker initialized");
285 | }
286 | 
287 | Linker::~Linker() {
288 |   DEBUG("Linker destroyed");
289 |   if (infoLog) delete[] infoLog;
290 |   if (errorLog) delete[] errorLog;
291 |   if (initialized) {
292 |     auto result = cuLinkDestroy(state);
293 |     if (result != CUDA_SUCCESS) throw DriverError(result);
294 |   }
295 | }
296 | 
297 | size_t Linker::OptionsSize() {
298 |   // if (!initialized) throw StringError("Uninitialized linker used");
299 |   return optKeys.size();
300 | }
301 | 
302 | CUjit_option *Linker::OptionsKeys() {
303 |   // if (!initialized) throw StringError("Uninitialized linker used");
304 |   return optKeys.data();
305 | }
306 | 
307 | void **Linker::OptionsValues() {
308 |   // if (!initialized) throw StringError("Uninitialized linker used");
309 |   return optValues.data();
310 | }
311 | 
312 | void Linker::Run(std::list<std::string> sources) {
313 |   // if (!initialized) throw StringError("Uninitialized linker used");
314 |   CUresult result;
315 | 
316 |   result = cuLinkCreate(optKeys.size(), optKeys.data(), optValues.data(), &state);
317 |   if (result != CUDA_SUCCESS) throw DriverError(result);
318 |   initialized = true;
319 | 
320 |   for (auto it = std::begin(sources); it != std::end(sources); ++it) {
321 |     result = cuLinkAddData(state, CU_JIT_INPUT_PTX, (void *)it->c_str(), it->size() + 1, 0, 0, 0, 0);
322 |     if (result != CUDA_SUCCESS) throw CompileError(result, infoLog, errorLog);
323 |   }
324 | 
325 |   result = cuLinkComplete(state, &cubin, &cubinSize);
326 |   if (result != CUDA_SUCCESS) throw CompileError(result, infoLog, errorLog);
327 | }
--------------------------------------------------------------------------------
/c_src/driver_port.cpp:
--------------------------------------------------------------------------------
1 | #include "driver_port.h"
2 | #include "utils.h"
3 | #include "driver.h"
4 | #include "commands.h"
5 | 
6 | template <> LinkerOptions DriverPort::Unpack<LinkerOptions>(ETERM *term) {
7 |   LinkerOptions options;
8 |   options.maxRegisters = -1;
9 |   options.threadsPerBlock = -1;
10 |   options.optimizationLevel = -1;
11 |   options.target = -1;
12 |   options.debug = -1;
13 |   options.verbose = -1;
14 |   options.infoSize = -1;
15 |   options.errorSize = -1;
16 |   if (!IS_NIL(term)) {
17 |     auto opts = GetKeywords(term);
18 |     auto it = opts.find("max_registers");
19 |     if (it != opts.end()) {
20 |       options.maxRegisters = Get<int>(it->second);
21 |       if (options.maxRegisters < 0) throw StringError("Bad argument");
22 |     }
23 |     it = opts.find("threads_per_block");
24 |     if (it != opts.end()) {
25 |       options.threadsPerBlock = Get<int>(it->second);
26 |       if (options.threadsPerBlock < 0) throw StringError("Bad argument");
27 |     }
28 |     it = opts.find("optimization_level");
29 |     if (it != opts.end()) {
30 |       options.optimizationLevel = Get<int>(it->second);
31 |       if (options.optimizationLevel < 0 || options.optimizationLevel > 4) {
32 |         throw StringError("Bad argument");
33 |       }
34 |     }
35 |     // TODO: target parsing here
36 |     // it = opts.find("target");
37 |     it = opts.find("debug");
38 |     if (it != opts.end()) options.debug = Get<bool>(it->second) ? 1 : 0;
39 |     it = opts.find("verbose");
40 |     if (it != opts.end()) options.verbose = Get<bool>(it->second) ? 1 : 0;
41 |   }
42 |   return options;
43 | }
44 | 
45 | DriverPort::DriverPort(int device): ErlangPort() {
46 |   try {
47 |     auto result = cuInit(0);
48 |     if (result != CUDA_SUCCESS) throw DriverError(result, "DriverPort initialize");
49 |     if (device < 0) device = BestDevice();
50 |     if (device < 0) throw StringError("There are no GPU devices to initialize");
51 |     driver = new Driver(device);
52 |   } catch (Error &error) {
53 |     WriteTermPacket(error.AsTerm());
54 |     throw StringError("Initialization error");
55 |   }
56 |   DEBUG("DriverPort initialized");
57 | }
58 | 
59 | DriverPort::~DriverPort() {
60 |   DEBUG("DriverPort destroyed");
61 |   if (driver) delete driver;
62 | }
63 | 
64 | ETERM *DriverPort::HandleTermFunction(std::string name, ETERM *arg) {
65 |   if (name == "compile") return Compile(arg);
66 |   if (name == "memory_read") return MemoryRead(arg);
67 |   if (name == "memory_unload") return MemoryUnload(arg);
68 |   if (name == "memory_share") return MemoryShare(arg);
69 |   if (name == "memory_load") return MemoryLoad(arg);
70 |   if (name == "module_load") return ModuleLoad(arg);
71 |   if (name == "run") return Run(arg);
72 |   if (name == "stream") return Stream(arg);
73 |   if (name == "device_info") return DeviceInfo();
74 |   return NULL;
75 | }
76 | 
77 | ETERM *DriverPort::HandleRawFunction(std::string name, RawData &data, size_t size) {
78 |   if (name == "memory_load") return MemoryLoad(data, size);
79 |   return NULL;
80 | }
81 | 
82 | ETERM *DriverPort::Compile(ETERM *arg) {
83 |   DEBUG("Enter DriverPort::Compile");
84 |   if (!ERL_IS_TUPLE(arg) || erl_size(arg) != 2) throw StringError("Bad argument");
85 | 
86 |   auto sourcesArg = erl_element(1, arg);
87 |   if (!ERL_IS_LIST(sourcesArg)) throw StringError("Bad argument");
88 |   std::list<std::string> sources;
89 |   auto size = erl_length(sourcesArg);
90 |   if (size < 1) throw StringError("Bad argument");
91 |   for (int i = 0; i < size; i++) {
92 |     auto srcBin = erl_hd(sourcesArg);
93 |     sourcesArg = erl_tl(sourcesArg);
94 |     if (!ERL_IS_BINARY(srcBin)) throw StringError("Bad argument");
95 |     std::string src((char *)ERL_BIN_PTR(srcBin), erl_size(srcBin));
96 |     sources.push_back(src);
97 |   }
98 | 
99 |   auto options = Unpack<LinkerOptions>(erl_element(2, arg));
100 |   auto module = driver->Compile(sources, options);
101 |   return FORMAT("{ok,~w}", driver->PackModule(module));
102 | }
103 | 
104 | ETERM *DriverPort::MemoryRead(ETERM *arg) {
105 |   DEBUG("Enter DriverPort::MemoryRead");
106 |   auto n = GetMemoryIndex(arg);
107 |   auto size = driver->GetMemorySize(n);
108 |   if (size < 0) throw StringError("Invalid memory handle");
109 |   char *data = NULL;
110 |   cuMemAllocHost((void **)&data, size);
111 |   driver->ReadMemory(n, (void *)data, size);
112 |   WriteRawPacket((void *)data, size);
113 |   return NULL;
114 | }
115 | 
116 | ETERM *DriverPort::MemoryUnload(ETERM *arg) {
117 |   DEBUG("Enter DriverPort::MemoryUnload");
118 |   auto n = GetMemoryIndex(arg);
119 |   driver->UnloadMemory(n);
120 |   return erl_mk_atom(OK_STR);
121 | }
122 | 
123 | ETERM *DriverPort::MemoryShare(ETERM *arg) {
124 |   DEBUG("Enter DriverPort::MemoryShare");
125 |   auto n = GetMemoryIndex(arg);
126 |   auto mem = driver->ShareMemory(n);
127 |   return FORMAT("{ok,~w}", driver->PackMemory(mem));
128 | }
129 | 
130 | ETERM *DriverPort::MemoryLoad(ETERM *arg) {
131 |   DEBUG("Enter DriverPort::MemoryLoad");
132 |   auto mem = driver->Unpack<SharedMemory>(arg);
133 |   int n = driver->LoadMemory(mem);
134 |   return FORMAT("{ok,~w}", driver->PackMemory(n));
135 | }
136 | 
137 | ETERM *DriverPort::ModuleLoad(ETERM *arg) {
138 |   DEBUG("Enter DriverPort::ModuleLoad");
139 |   if (!ERL_IS_TUPLE(arg) || erl_size(arg) != 2) throw StringError("Bad argument");
140 | 
141 |   auto srcArg = erl_element(1, arg);
142 |   if (!ERL_IS_BINARY(srcArg)) throw StringError("Bad argument");
143 |   std::string src((char *)ERL_BIN_PTR(srcArg), erl_size(srcArg));
144 | 
145 |   auto options = Unpack<LinkerOptions>(erl_element(2, arg));
146 |   auto module = driver->LoadModule(src, options);
147 |   return FORMAT("{ok,~w}", driver->PackModule(module));
148 | }
149 | 
150 | ETERM *DriverPort::Run(ETERM *arg) {
151 |   DEBUG("Enter DriverPort::Run");
152 |   if (!ERL_IS_TUPLE(arg)) throw StringError("Bad argument");
153 |   auto argc = erl_size(arg);
154 |   if (argc < 2) throw StringError("Bad argument");
155 |   auto moduleTerm = erl_element(1, arg);
156 |   auto commandTerm = erl_element(2, arg);
157 | 
158 |   auto module = GetModuleIndex(moduleTerm);
159 |   Commands::Context ctx;
160 |   ctx.module = driver->GetModule(module);
161 |   auto cmd = Commands::Command::Create(driver, commandTerm);
162 |   cmd->Run(ctx);
163 |   DEBUG("Leave DriverPort::Run");
164 |   return erl_mk_atom(OK_STR);
165 | }
166 | 
167 | ETERM *DriverPort::Stream(ETERM *arg) {
168 |   DEBUG("Enter DriverPort::Stream");
169 |   if (!ERL_IS_TUPLE(arg)) throw StringError("Bad argument");
170 |   auto argc = erl_size(arg);
171 |   if (argc < 2) throw StringError("Bad argument");
172 |   auto moduleTerm = erl_element(1, arg);
173 |   auto commandTerm = erl_element(2, arg);
174 |   if (!ERL_IS_LIST(commandTerm)) throw StringError("Bad argument");
175 | 
176 |   auto module = GetModuleIndex(moduleTerm);
177 |   Commands::Context ctx;
178 |   ctx.module = driver->GetModule(module);
179 |   auto cmd = Commands::Command::Create(driver, commandTerm);
180 |   cmd->Run(ctx);
181 | 
182 |   return erl_mk_atom(OK_STR);
183 | }
184 | 
185 | ETERM *DriverPort::MemoryLoad(RawData &data, size_t size) {
186 |   DEBUG("Enter raw DriverPort::MemoryLoad");
187 |   int n = driver->LoadMemory(data.get(), size);
188 |   return FORMAT("{ok,~w}", driver->PackMemory(n));
189 | }
190 | 
191 | std::shared_ptr<RunArguments> DriverPort::UnpackRunArguments(ETERM *term) {
192 |   std::shared_ptr<RunArguments> args = std::make_shared<RunArguments>();
193 | 
194 |   if (!ERL_IS_LIST(term)) throw StringError("Bad argument");
195 |   auto s = erl_length(term);
196 |   for (int i = 0; i < s; i++) {
197 |     auto param = erl_hd(term);
198 |     term = erl_tl(term);
199 |     if (ERL_IS_TUPLE(param)) {
200 |       auto param_type = erl_element(1, param);
201 |       auto param_value = erl_element(2, param);
202 |       if (ERL_IS_ATOM(param_type) && ATOM_EQ(param_type, "memory")) {
203 |         auto mem = driver->GetMemory(Get<int>(param_value));
204 |         if (!mem) throw StringError("Invalid memory handle");
205 |         args->Add(*mem);
206 |       }
207 |     } else if (ERL_IS_INTEGER(param)) {
208 |       args->Add(ERL_INT_VALUE(param));
209 |     } else if (ERL_IS_FLOAT(param)) {
210 |       float f = ERL_FLOAT_VALUE(param);
211 |       args->Add(f);
212 |     } else {
213 |       throw StringError("Bad argument");
214 |     }
215 |   }
216 |   return args;
217 | }
218 | 
219 | ETERM *DriverPort::DeviceInfo() {
220 |   int v[44];
221 |   CUdevice_attribute c[44] = {
222 |     CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK,
223 |     CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X,
224 |     CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y,
225 |     CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z,
226 |     CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X,
227 |     CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y,
228 |     CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z,
229 |     CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK,
230 |     CU_DEVICE_ATTRIBUTE_TOTAL_CONSTANT_MEMORY,
231 |     CU_DEVICE_ATTRIBUTE_WARP_SIZE,
232 |     CU_DEVICE_ATTRIBUTE_MAX_PITCH,
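    // NOTE: the order of this list is significant: cuDeviceGetAttribute()
    // fills v[i] in the same order, and the FORMAT reply below reads v[] and
    // b[] positionally.
233 | 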
CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK, 234 | CU_DEVICE_ATTRIBUTE_CLOCK_RATE, 235 | CU_DEVICE_ATTRIBUTE_GPU_OVERLAP, 236 | CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, 237 | CU_DEVICE_ATTRIBUTE_KERNEL_EXEC_TIMEOUT, 238 | CU_DEVICE_ATTRIBUTE_INTEGRATED, 239 | CU_DEVICE_ATTRIBUTE_CAN_MAP_HOST_MEMORY, 240 | CU_DEVICE_ATTRIBUTE_COMPUTE_MODE, 241 | CU_DEVICE_ATTRIBUTE_CONCURRENT_KERNELS, 242 | CU_DEVICE_ATTRIBUTE_ECC_ENABLED, 243 | CU_DEVICE_ATTRIBUTE_PCI_BUS_ID, 244 | CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID, 245 | CU_DEVICE_ATTRIBUTE_TCC_DRIVER, 246 | CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, 247 | CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, 248 | CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE, 249 | CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR, 250 | CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING, 251 | CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, 252 | CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, 253 | CU_DEVICE_ATTRIBUTE_GLOBAL_L1_CACHE_SUPPORTED, 254 | CU_DEVICE_ATTRIBUTE_LOCAL_L1_CACHE_SUPPORTED, 255 | CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, 256 | CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_MULTIPROCESSOR, 257 | CU_DEVICE_ATTRIBUTE_MANAGED_MEMORY, 258 | CU_DEVICE_ATTRIBUTE_MULTI_GPU_BOARD, 259 | CU_DEVICE_ATTRIBUTE_MULTI_GPU_BOARD_GROUP_ID, 260 | CU_DEVICE_ATTRIBUTE_HOST_NATIVE_ATOMIC_SUPPORTED, 261 | CU_DEVICE_ATTRIBUTE_SINGLE_TO_DOUBLE_PRECISION_PERF_RATIO, 262 | CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS, 263 | CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS, 264 | CU_DEVICE_ATTRIBUTE_COMPUTE_PREEMPTION_SUPPORTED, 265 | CU_DEVICE_ATTRIBUTE_CAN_USE_HOST_POINTER_FOR_REGISTERED_MEM, 266 | }; 267 | const char *b[18]; 268 | CUresult x; 269 | CUdevice d = driver->GetHandle(); 270 | for (int i = 0; i < 44; i++) { 271 | x = cuDeviceGetAttribute(&v[i], c[i], d); 272 | if (x != CUDA_SUCCESS) throw DriverError(x); 273 | } 274 | b[0] = v[13] == 1 ? "true" : "false"; 275 | b[1] = v[15] == 1 ? "true" : "false"; 276 | b[2] = v[16] == 1 ? "true" : "false"; 277 | b[3] = v[17] == 1 ? "true" : "false"; 278 | b[4] = "unknown"; 279 | switch (v[18]) { 280 | case CU_COMPUTEMODE_DEFAULT: b[4] = "default"; break; 281 | case CU_COMPUTEMODE_PROHIBITED: b[4] = "prohibited"; break; 282 | case CU_COMPUTEMODE_EXCLUSIVE_PROCESS: b[4] = "exclusive_process"; break; 283 | } 284 | b[5] = v[19] == 1 ? "true" : "false"; 285 | b[6] = v[20] == 1 ? "true" : "false"; 286 | b[7] = v[23] == 1 ? "true" : "false"; 287 | b[8] = v[28] == 1 ? "true" : "false"; 288 | b[9] = v[31] == 1 ? "true" : "false"; 289 | b[10] = v[32] == 1 ? "true" : "false"; 290 | b[11] = v[35] == 1 ? "true" : "false"; 291 | b[12] = v[36] == 1 ? "true" : "false"; 292 | b[13] = v[38] == 1 ? "true" : "false"; 293 | b[14] = v[40] == 1 ? "true" : "false"; 294 | b[15] = v[41] == 1 ? "true" : "false"; 295 | b[16] = v[42] == 1 ? "true" : "false"; 296 | b[17] = v[43] == 1 ? 
"true" : "false"; 297 | return FORMAT( 298 | "{ok,[" 299 | "{max_threads_per_block,~i}," 300 | "{max_block,{~i,~i,~i}}," 301 | "{max_grid,{~i,~i,~i}}," 302 | "{max_shared_memory_per_block,~i}," 303 | "{total_constant_memory,~i}," 304 | "{warp_size,~i}," 305 | "{max_pitch,~i}," 306 | "{max_registers_per_block,~i}," 307 | "{clock_rate,~i}," 308 | "{gpu_overlap,~a}," 309 | "{miltiprocessor_count,~i}," 310 | "{kernel_exec_timeout,~a}," 311 | "{integrated,~a}," 312 | "{can_map_host_memory,~a}," 313 | "{compute_mode,~a}," 314 | "{concurrent_kernels,~a}," 315 | "{ecc_enabled,~a}," 316 | "{pci_bus_id,~i}," 317 | "{pci_device_id,~i}," 318 | "{tcc_driver,~a}," 319 | "{memory_clock_rate,~i}," 320 | "{global_memory_bus_width,~i}," 321 | "{l2_cache_size,~i}," 322 | "{max_threads_per_multiprocessor,~i}," 323 | "{unified_arressing,~a}," 324 | "{compute_capability,{~i,~i}}," 325 | "{global_l1_cache_supported,~a}," 326 | "{glocal_l1_cache_supported,~a}," 327 | "{max_shared_memory_per_multiprocessor,~i}," 328 | "{max_registers_per_multiprocessor,~i}," 329 | "{managed_memory,~a}," 330 | "{multi_gpu_board,~a}," 331 | "{multi_gpu_board_group_id,~i}," 332 | "{host_native_atomic_supported,~a}," 333 | "{single_to_double_precision_perf_ratio,~i}," 334 | "{pageable_memory_access,~a}," 335 | "{concurrent_managed_access,~a}," 336 | "{compute_preemption_supported,~a}," 337 | "{can_use_host_pointer_for_registered_mem,~a}" 338 | "]}", 339 | v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7], v[8], v[9], v[10], v[11], 340 | v[12], b[0], v[14], b[1], b[2], b[3], b[4], b[5], b[6], v[21], v[22], b[7], 341 | v[24], v[25], v[26], v[27], b[8], v[29], v[30], b[9], b[10], v[33], v[34], 342 | b[11], b[12], v[37], b[13], v[39], b[14], b[15], b[16], b[17]); 343 | } 344 | -------------------------------------------------------------------------------- /lib/cuda/memory.ex: -------------------------------------------------------------------------------- 1 | defmodule Cuda.Memory do 2 | require Logger 3 | alias Cuda.Float16 4 | 5 | defmodule Shape do 6 | defstruct [:type, skip: 0] 7 | def new(x), do: x 8 | end 9 | 10 | defstruct type: :owned, vars: [] 11 | 12 | def new(vars, type \\ :owned) do 13 | {vars, _} = Enum.reduce(vars, {[], 0}, fn 14 | {k, {o, _} = v}, {vars, offset} when is_integer(o) -> 15 | {vars ++ [{k, v}], offset} 16 | {k, type}, {vars, offset} -> 17 | {vars ++ [{k, {offset, type}}], offset + size(type)} 18 | end) 19 | %__MODULE__{vars: vars, type: type} 20 | end 21 | 22 | def get(%__MODULE__{vars: vars}, key, default \\ nil) do 23 | with {_, v} <- get_var(vars, key, default), do: v 24 | end 25 | 26 | def has_key?(%__MODULE__{vars: vars}, key) do 27 | Enum.any?(vars, fn 28 | {^key, _} -> true 29 | _ -> false 30 | end) 31 | end 32 | 33 | defp get_var(vars, key, default) do 34 | with {_, v} <- Enum.find(vars, fn {k, _} -> k == key end) do 35 | v 36 | else 37 | _ -> default 38 | end 39 | end 40 | 41 | def offset(nil, _), do: nil 42 | def offset(_, nil), do: nil 43 | def offset(%__MODULE__{vars: vars}, [field | rest]) do 44 | case Keyword.get(vars, field) do 45 | {offset, type} -> 46 | with n when is_integer(n) <- offset(type, rest) do 47 | offset + n 48 | end 49 | _ -> 50 | nil 51 | end 52 | end 53 | def offset(%__MODULE__{vars: vars}, field) do 54 | case Keyword.get(vars, field) do 55 | {offset, _} -> offset 56 | _ -> nil 57 | end 58 | end 59 | def offset(shape, [field | rest]) when is_map(shape) do 60 | result = shape |> Enum.reduce({:not_found, 0}, fn 61 | {^field, _}, {:not_found, offset} -> {:ok, offset} 62 | {_, type}, 
62 |       {_, type}, {:not_found, offset} -> {:not_found, offset + offset(type, rest)}
63 |       _, result -> result
64 |     end)
65 |     with {:ok, offset} <- result do
66 |       offset
67 |     else
68 |       _ -> nil
69 |     end
70 |   end
71 |   def offset(shape, []), do: size(shape)
72 |   def offset(shape, field), do: offset(shape, [field])
73 | 
74 |   def inspect_structure(%__MODULE__{} = mem, opts \\ []) do
75 |     vars = mem.vars |> Enum.map(fn {k, v} -> {Cuda.Graph.Node.string_id(k), v} end)
76 |     w1 = vars |> Enum.map(& elem(&1, 0)) |> Enum.map(&String.length/1) |> Enum.max()
77 |     w2 = vars |> Enum.map(& elem(elem(&1, 1), 0)) |> Enum.map(&String.length("#{&1}")) |> Enum.max()
78 |     vars = vars
79 |     |> Enum.map(fn {k, {o, t}} ->
80 |       name = String.pad_trailing(k, w1, " ")
81 |       offset = String.pad_leading("#{o}", w2, " ")
82 |       "#{name} | #{offset} | #{inspect t}"
83 |     end)
84 |     |> Enum.join("\n")
85 |     label = case Keyword.get(opts, :label) do
86 |       nil -> ""
87 |       label -> " #{inspect(label)}"
88 |     end
89 |     IO.puts("\nMemory#{label}. Type: #{mem.type}\n#{vars}")
90 |     mem
91 |   end
92 | 
93 |   def arity({_, arity}) do
94 |     size(arity)
95 |   end
96 |   def arity(_), do: 1
97 | 
98 |   def type({t, _}), do: t
99 |   def type(t), do: t
100 | 
101 |   def pack(:zero, t) when is_atom(t) or is_bitstring(t), do: pack(0, t)
102 |   def pack(_, {:skip, bytes}) when bytes < 0, do: <<>>
103 |   def pack(_, {:skip, bytes}), do: <<0::unit(8)-size(bytes)>>
104 |   def pack(x, :i8), do: <<x::signed-size(8)>> # multi-byte scalars below are little-endian (assumed; original annotations were garbled)
105 |   def pack(x, :i16), do: <<x::signed-little-size(16)>>
106 |   def pack(x, :i32), do: <<x::signed-little-size(32)>>
107 |   def pack(x, :i64), do: <<x::signed-little-size(64)>>
108 |   def pack(x, "i8"), do: pack(x, :i8)
109 |   def pack(x, "i16"), do: pack(x, :i16)
110 |   def pack(x, "i32"), do: pack(x, :i32)
111 |   def pack(x, "i64"), do: pack(x, :i64)
112 |   def pack(x, :u8), do: <<x::unsigned-size(8)>>
113 |   def pack(x, :u16), do: <<x::unsigned-little-size(16)>>
114 |   def pack(x, :u32), do: <<x::unsigned-little-size(32)>>
115 |   def pack(x, :u64), do: <<x::unsigned-little-size(64)>>
116 |   def pack(x, "u8"), do: pack(x, :u8)
117 |   def pack(x, "u16"), do: pack(x, :u16)
118 |   def pack(x, "u32"), do: pack(x, :u32)
119 |   def pack(x, "u64"), do: pack(x, :u64)
120 |   def pack(x, :f16), do: Float16.pack(x)
121 |   def pack(x, :f32), do: <<x::float-little-size(32)>>
122 |   def pack(x, :f64), do: <<x::float-little-size(64)>>
123 |   def pack(x, "f16"), do: pack(x, :f16)
124 |   def pack(x, "f32"), do: pack(x, :f32)
125 |   def pack(x, "f64"), do: pack(x, :f64)
126 |   def pack(x, {type, arity}) when not is_tuple(arity), do: pack(x, {type, {arity}})
127 |   def pack(x, {type, arity}) when is_list(x) do
128 |     arity = size(arity)
129 |     x = List.flatten(x)
130 |     if length(x) == arity do
131 |       x |> Enum.map(& pack(&1, type)) |> Enum.join
132 |     else
133 |       raise RuntimeError, message: "Arity of array #{inspect x} should be #{arity}"
134 |     end
135 |   end
136 |   def pack(:zero, {type, arity}) do
137 |     size = size(arity) * size(type)
138 |     <<0::unit(8)-size(size)>>
139 |   end
140 |   def pack(x, types) when is_list(types) and is_list(x) do
141 |     x = case length(types) - length(x) do
142 |       n when n > 0 -> x ++ List.duplicate(0, n)
143 |       _ -> x
144 |     end
145 |     x
146 |     |> Enum.zip(types)
147 |     |> Enum.map(fn {x, type} -> pack(x, type) end)
148 |     |> Enum.join()
149 |   end
150 |   def pack(:zero, types) when is_list(types) do
151 |     types |> Enum.map(& pack(:zero, &1)) |> Enum.join()
152 |   end
153 |   def pack(x, types) when is_list(types), do: pack([x], types)
154 |   def pack(x, %__MODULE__{vars: vars}) when is_map(x) do
155 |     vars
156 |     |> Enum.map(fn {k, {_offset, type}} ->
157 |       #IO.inspect({k, type, Map.get(x, k, :zero)})
158 |       pack(Map.get(x, k, :zero), type)
159 |     end)
160 |     |> Enum.join()
161 |   end
162 |   def pack(%{} = x, %Shape{type: type, skip: %{} = skip}) do
163 |     x
164 |     |> Map.to_list()
165 |     |> Enum.reduce("", fn {key, val}, acc ->
166 |       t = Map.get(type, key)
167 |       s = Map.get(skip, key)
168 |       acc <> pack(val, %Shape{type: t, skip: s})
169 |     end)
170 |   end
171 |   def pack(x, %Shape{type: type, skip: {sbefore, safter}}) do
172 |     pack(0, {:skip, sbefore}) <> pack(x, type) <> pack(0, {:skip, safter})
173 |   end
174 |   def pack(x, %Shape{type: type, skip: skip}) when is_integer(skip) do
175 |     pack(x, type) <> pack(0, {:skip, skip})
176 |   end
177 |   def pack(x, types) when is_map(types) and is_map(x) do
178 |     types
179 |     |> Enum.map(fn {k, type} ->
180 |       with {:ok, v} <- Map.fetch(x, k) do
181 |         pack(v, type)
182 |       else
183 |         _ -> raise RuntimeError, message: "Couldn't find value for key #{k}"
184 |       end
185 |     end)
186 |     |> Enum.join()
187 |   end
188 |   def pack(:zero, types) when is_map(types) do
189 |     types |> Enum.map(fn {_, type} -> pack(:zero, type) end) |> Enum.join()
190 |   end
191 |   def pack(nil, type) do
192 |     Logger.warn("Attempt to pack `nil` value for type #{inspect type}"); <<>>
193 |   end
194 |   def pack(_, _), do: <<>>
195 | 
196 |   def unpack(<<x::signed-size(8)>>, :i8), do: x
197 |   def unpack(<<x::signed-little-size(16)>>, :i16), do: x
198 |   def unpack(<<x::signed-little-size(32)>>, :i32), do: x
199 |   def unpack(<<x::signed-little-size(64)>>, :i64), do: x
200 |   def unpack(x, "i8"), do: unpack(x, :i8)
201 |   def unpack(x, "i16"), do: unpack(x, :i16)
202 |   def unpack(x, "i32"), do: unpack(x, :i32)
203 |   def unpack(x, "i64"), do: unpack(x, :i64)
204 |   def unpack(<<x::unsigned-size(8)>>, :u8), do: x
205 |   def unpack(<<x::unsigned-little-size(16)>>, :u16), do: x
206 |   def unpack(<<x::unsigned-little-size(32)>>, :u32), do: x
207 |   def unpack(<<x::unsigned-little-size(64)>>, :u64), do: x
208 |   def unpack(x, "u8"), do: unpack(x, :u8)
209 |   def unpack(x, "u16"), do: unpack(x, :u16)
210 |   def unpack(x, "u32"), do: unpack(x, :u32)
211 |   def unpack(x, "u64"), do: unpack(x, :u64)
212 |   def unpack(x, :f16), do: Float16.unpack(x)
213 |   def unpack(<<x::float-little-size(32)>>, :f32), do: x
214 |   def unpack(<<x::float-little-size(64)>>, :f64), do: x
215 |   def unpack(x, "f16"), do: unpack(x, :f16)
216 |   def unpack(x, "f32"), do: unpack(x, :f32)
217 |   def unpack(x, "f64"), do: unpack(x, :f64)
218 |   def unpack(x, {type, arity}) when is_tuple(arity) do
219 |     arity = arity |> Tuple.to_list |> Enum.reverse
220 |     {list, _} = unpack_list(x, {type, arity})
221 |     list
222 |     |> unp_flat()
223 |   end
224 |   def unpack(x, {type, arity}) when not is_tuple(arity) do
225 |     {list, _} = unpack_list(x, {type, [arity]})
226 |     list
227 |     |> unp_flat()
228 |   end
229 |   def unpack(x, types) when is_list(types) do
230 |     {list, _} = types |> Enum.reduce({[], x}, fn
231 |       type, {list, rest} ->
232 |         {data, rest} = unpack_list(rest, type)
233 |         {list ++ data, rest}
234 |       _, acc ->
235 |         acc
236 |     end)
237 |     list
238 |     |> unp_flat()
239 |   end
240 |   def unpack(x, %__MODULE__{vars: vars}) do
241 |     vars
242 |     |> Enum.map(fn {k, {offset, type}} ->
243 |       x = x
244 |           |> binary_part(offset, size(type))
245 |           |> unpack(type)
246 |       {k, x}
247 |     end)
248 |     |> Enum.into(%{})
249 |   end
250 |   def unpack(x, %Shape{type: %{} = type, skip: %{} = skip}) do
251 |     type = type
252 |     |> Map.to_list()
253 |     |> Enum.map(fn {key, val} ->
254 |       val = is_list(val) && val || [val]
255 |       case Map.get(skip, key, 0) do
256 |         {sbefore, safter} -> {key, [{:skip, sbefore} | val] ++ [{:skip, safter}]}
257 |         0 -> {key, val}
258 |         s -> {key, val ++ [{:skip, s}]}
259 |       end
260 |     end)
261 |     |> Enum.into(%{})
262 |     x
263 |     |> unpack(type)
264 |   end
265 |   def unpack(x, %Shape{type: type, skip: {sbefore, safter}}) do
266 |     size = byte_size(x) - sbefore - safter
267 |     x
268 |     |> binary_part(sbefore, size)
269 |     |> unpack(type)
270 |   end
271 |   def unpack(x, %Shape{type: type, skip: skip}) when is_integer(skip) and skip < 0 do
272 |     unpack(x, type)
273 |   end
274 |   def unpack(x, %Shape{type: type, skip: skip}) when is_integer(skip) do
275 |     size = byte_size(x) - skip
276 |     <<data::binary-size(size), _::binary>> = x
277 |     unpack(data, type)
278 |   end
279 |   def unpack(x, types) when is_map(types) do
280 |     {list, _} = types |> Enum.reduce({%{}, x}, fn
281 |       {k, type}, {map, rest} ->
282 |         {[data], rest} = unpack_list(rest, type)
283 |         {Map.put(map, k, data), rest}
284 |       _, acc ->
285 |         acc
286 |     end)
287 |     list
288 |     |> unp_flat()
289 |   end
290 |   def unpack(_, _), do: nil
291 | 
292 |   defp unpack_list(x, {type, [arity]}) do
293 |     size = size(type)
294 |     Enum.reduce(1..arity, {[], x}, fn
295 |       _, {list, <<x::binary-size(size), rest::binary>>} ->
296 |         data = [unpack(x, type)]
297 |         {list ++ data, rest}
298 |       _, acc ->
299 |         acc
300 |     end)
301 |   end
302 |   defp unpack_list(x, {type, [current | arity]}) do
303 |     Enum.reduce(1..current, {[], x}, fn
304 |       _, {list, rest} ->
305 |         {data, rest} = unpack_list(rest, {type, arity}) # |> IO.inspect
306 |         {list ++ [data], rest}
307 |       _, acc ->
308 |         acc
309 |     end)
310 |   end
311 |   defp unpack_list(x, {:skip, bytes}) when bytes < 0 do
312 |     {[], x}
313 |   end
314 |   defp unpack_list(x, {:skip, bytes}) do
315 |     <<_::binary-size(bytes), rest::binary>> = x
316 |     {[], rest}
317 |   end
318 |   defp unpack_list(x, type) do
319 |     size = size(type)
320 |     #IO.inspect({x, type, size, byte_size(x)})
321 |     <<x::binary-size(size), rest::binary>> = x
322 |     {[unpack(x, type)], rest}
323 |   end
324 | 
325 |   defp unp_flat(%{} = value) do
326 |     value
327 |     |> Map.to_list()
328 |     |> Enum.map(fn {k, v} -> {k, unp_flat(v)} end)
329 |     |> Enum.into(%{})
330 |   end
331 |   defp unp_flat([value]) when is_list(value), do: unp_flat(value)
332 |   defp unp_flat(value), do: value
333 | 
334 |   @type_re ~r/(\d+)/
335 |   def size(:i8), do: 1
336 |   def size(:i16), do: 2
337 |   def size(:i32), do: 4
338 |   def size(:i64), do: 8
339 |   def size(:u8), do: 1
340 |   def size(:u16), do: 2
341 |   def size(:u32), do: 4
342 |   def size(:u64), do: 8
343 |   def size(:f16), do: 2
344 |   def size(:f32), do: 4
345 |   def size(:f64), do: 8
346 |   def size({:skip, n}) when n < 0, do: 0
347 |   def size({:skip, n}), do: n
348 |   def size(type) when is_atom(type) or is_bitstring(type) do
349 |     case Regex.run(@type_re, "#{type}", capture: :all_but_first) do
350 |       [n] -> div(String.to_integer(n), 8)
351 |       _ -> 0
352 |     end
353 |   end
354 |   def size(tuple) when is_tuple(tuple) do
355 |     tuple
356 |     |> Tuple.to_list
357 |     |> Enum.map(&size/1)
358 |     |> Enum.reduce(1, &Kernel.*/2)
359 |   end
360 |   def size(i) when is_integer(i), do: i
361 |   def size(l) when is_list(l) do
362 |     l |> Enum.map(&size/1) |> Enum.reduce(0, &+/2)
363 |   end
364 |   def size(%__MODULE__{vars: vars}) do
365 |     vars |> Enum.reduce(0, fn {_, {_, t}}, a -> a + size(t) end)
366 |   end
367 |   def size(%Shape{type: nil}), do: 0
368 |   def size(%Shape{type: type, skip: %{} = skip}) do
369 |     ssize = skip
370 |     |> Map.to_list()
371 |     |> Enum.reduce(0, fn
372 |       {_, {v1, v2}}, acc -> acc + v1 + v2
373 |       {_, v}, acc -> acc + v
374 |     end)
375 |     size(type) + ssize
376 |   end
377 |   def size(%Shape{type: type, skip: {sbefore, safter}}) do
378 |     size({:skip, sbefore + safter}) + size(type)
379 |   end
380 |   def size(%Shape{type: type, skip: skip}) when is_integer(skip) do
381 |     size({:skip, skip}) + size(type)
382 |   end
383 |   def size(m) when is_map(m) do
384 |     m
385 |     |> Enum.map(fn {_, v} -> size(v) end)
386 |     |> Enum.reduce(0, &+/2)
387 |   end
388 |   def size(_), do: 0
389 | 
390 |   def size_equal?(x, y) do
391 |     size(x) == size(y)
392 |   end
393 | end
394 | 
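# Round-trip sketch for pack/2 and unpack/2 (illustrative only; the layout,
# field names and values below are hypothetical):
#
#     mem = Cuda.Memory.new(weights: {:f32, 4}, bias: :f32)
#     bin = Cuda.Memory.pack(%{weights: [1.0, 2.0, 3.0, 4.0], bias: 0.5}, mem)
#     %{weights: [1.0, 2.0, 3.0, 4.0], bias: 0.5} = Cuda.Memory.unpack(bin, mem)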
--------------------------------------------------------------------------------